⏱️ 80 min

Generative Adversarial Networks (GANs)

Generate synthetic images and data with GANs

Understanding GANs

GANs consist of two neural networks trained against each other in a two-player, zero-sum (minimax) game.

**The Setup:**
1. **Generator (G)**: Creates fake data from random noise
2. **Discriminator (D)**: Tries to distinguish real data from fake data

**The Game:**
- The generator tries to fool the discriminator
- The discriminator tries to correctly identify fakes
- Both improve through competition
- Ideally, the generator eventually produces data the discriminator cannot tell apart from real data

**Applications:**
- Image generation (faces, art, scenes)
- Image-to-image translation (day→night, sketch→photo)
- Super-resolution (enhancing image resolution)
- Data augmentation
- Style transfer
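The game above is usually written as a minimax problem over a value function V(D, G). In the standard formulation below, p_data is the data distribution, p_z is the noise prior that z is sampled from, and p_g is the distribution of generated samples (symbols introduced here for reference). The second line is the optimal discriminator for a fixed generator:

```latex
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]

D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}
```

In practice the generator is commonly trained to minimize -E[log D(G(z))] instead of E[log(1 - D(G(z)))], because the latter gives weak gradients early in training when the discriminator easily rejects fakes. Binary cross-entropy with "real" labels on generated images computes exactly this non-saturating loss.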

Simple GAN Implementation

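Build a basic fully connected GAN for 28×28 MNIST digits. The snippet runs a single training step on a random stand-in batch; a real training run loops over MNIST batches scaled to [-1, 1] to match the generator's Tanh output: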

```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Generator Network: maps a latent vector z to an image
class Generator(nn.Module):
    def __init__(self, latent_dim=100, img_shape=(1, 28, 28)):
        super(Generator, self).__init__()
        self.img_shape = img_shape

        # Fully connected block: Linear -> (optional) BatchNorm -> LeakyReLU
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()  # Output in [-1, 1]
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)  # Flat vector -> (C, H, W)
        return img

# Discriminator Network: maps an image to the probability that it is real
class Discriminator(nn.Module):
    def __init__(self, img_shape=(1, 28, 28)):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()  # Output probability [0, 1]
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

# Hyperparameters
latent_dim = 100
img_shape = (1, 28, 28)
lr = 0.0002
batch_size = 64

# Initialize models
generator = Generator(latent_dim, img_shape)
discriminator = Discriminator(img_shape)

# Loss function
adversarial_loss = nn.BCELoss()

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

print(f"Generator parameters: {sum(p.numel() for p in generator.parameters()):,}")
print(f"Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}")

# Training loop (one batch example)
def train_step(real_imgs):
    batch_size = real_imgs.size(0)

    # Labels
    valid = torch.ones(batch_size, 1)   # Real images
    fake = torch.zeros(batch_size, 1)   # Fake images

    # ---------------------
    #  Train Discriminator
    # ---------------------
    optimizer_D.zero_grad()

    # Loss on real images
    real_loss = adversarial_loss(discriminator(real_imgs), valid)

    # Generate fake images
    z = torch.randn(batch_size, latent_dim)
    gen_imgs = generator(z)

    # Loss on fake images (detach so this step does not update the generator)
    fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)

    # Total discriminator loss
    d_loss = (real_loss + fake_loss) / 2
    d_loss.backward()
    optimizer_D.step()

    # -----------------
    #  Train Generator
    # -----------------
    optimizer_G.zero_grad()

    # Generate images
    z = torch.randn(batch_size, latent_dim)
    gen_imgs = generator(z)

    # Generator wants discriminator to think images are real
    # (non-saturating loss: minimizes -log D(G(z)))
    g_loss = adversarial_loss(discriminator(gen_imgs), valid)
    g_loss.backward()
    optimizer_G.step()

    return d_loss.item(), g_loss.item()

# Example training
print("\nSimulating training batch...")
# Random stand-in for a batch of MNIST images scaled to [-1, 1]
dummy_real_imgs = torch.rand(batch_size, 1, 28, 28) * 2 - 1
d_loss, g_loss = train_step(dummy_real_imgs)

print(f"Discriminator loss: {d_loss:.4f}")
print(f"Generator loss: {g_loss:.4f}")

# Generate sample images
generator.eval()  # BatchNorm uses running statistics in eval mode
with torch.no_grad():
    z = torch.randn(16, latent_dim)
    generated_imgs = generator(z)

print(f"\nGenerated images shape: {generated_imgs.shape}")
print(f"Value range: [{generated_imgs.min():.2f}, {generated_imgs.max():.2f}]")
print("\n✓ GAN ready for training!")
```

Output (loss values and the value range vary from run to run):

```
Generator parameters: 1,510,032
Discriminator parameters: 533,505

Simulating training batch...
Discriminator loss: 0.6847
Generator loss: 0.7123

Generated images shape: torch.Size([16, 1, 28, 28])
Value range: [-0.98, 0.96]

✓ GAN ready for training!
```
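The `train_step` above handles a single batch. Below is a minimal sketch of a full loop over a `DataLoader`, assuming images are already scaled to [-1, 1]. Random `uint8` tensors stand in for MNIST, and the smaller networks and the names `G`, `D`, and `history` are illustrative choices, not part of the original code:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
latent_dim = 16
img_dim = 28 * 28

# Hypothetical stand-in for MNIST: random uint8 "images" with pixel values in [0, 255]
pixels = torch.randint(0, 256, (256, 1, 28, 28), dtype=torch.uint8)
# Scale to [-1, 1] so real images match the generator's Tanh output range
real_data = pixels.float() / 127.5 - 1.0
loader = DataLoader(TensorDataset(real_data), batch_size=64, shuffle=True, drop_last=True)

# Small MLPs (smaller than the models above) so the sketch runs quickly
G = nn.Sequential(nn.Linear(latent_dim, 128), nn.LeakyReLU(0.2), nn.Linear(128, img_dim), nn.Tanh())
D = nn.Sequential(nn.Linear(img_dim, 128), nn.LeakyReLU(0.2), nn.Linear(128, 1), nn.Sigmoid())

criterion = nn.BCELoss()
opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))

history = []
for epoch in range(3):
    for (real,) in loader:
        real = real.view(real.size(0), -1)  # Flatten to (batch, 784)
        b = real.size(0)
        ones, zeros = torch.ones(b, 1), torch.zeros(b, 1)

        # Discriminator step: real -> 1, fake -> 0 (fakes detached so G gets no gradient)
        fake = G(torch.randn(b, latent_dim))
        d_loss = (criterion(D(real), ones) + criterion(D(fake.detach()), zeros)) / 2
        opt_D.zero_grad()
        d_loss.backward()
        opt_D.step()

        # Generator step: reuse the same fakes, now scored by the updated D
        g_loss = criterion(D(fake), ones)  # Non-saturating generator loss
        opt_G.zero_grad()
        g_loss.backward()
        opt_G.step()

        history.append((d_loss.item(), g_loss.item()))
    print(f"epoch {epoch}: d_loss={history[-1][0]:.3f} g_loss={history[-1][1]:.3f}")
```

Replacing `real_data` with real MNIST images, scaled the same way, is the main change needed to train on actual digits. The losses printed for random stand-in data carry no meaning beyond showing that the loop runs.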

DCGAN - Deep Convolutional GAN

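DCGAN replaces the fully connected layers with convolutions: transposed convolutions upsample the latent vector in the generator, and strided convolutions downsample the image in the discriminator. This usually produces better images than the fully connected version: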

```python
import torch
import torch.nn as nn

class DCGANGenerator(nn.Module):
    def __init__(self, latent_dim=100, channels=1):
        super(DCGANGenerator, self).__init__()

        self.main = nn.Sequential(
            # Input: latent_dim x 1 x 1
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # State: 512 x 4 x 4

            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # State: 256 x 8 x 8

            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # State: 128 x 16 x 16

            nn.ConvTranspose2d(128, channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # Output: channels x 32 x 32 (28x28 MNIST must be padded or resized to match)
        )

    def forward(self, z):
        # Reshape z to a 4D tensor: (batch, latent_dim, 1, 1)
        z = z.view(z.size(0), z.size(1), 1, 1)
        return self.main(z)

class DCGANDiscriminator(nn.Module):
    def __init__(self, channels=1):
        super(DCGANDiscriminator, self).__init__()

        self.main = nn.Sequential(
            # Input: channels x 32 x 32 (no BatchNorm on the input layer)
            nn.Conv2d(channels, 128, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 128 x 16 x 16

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 256 x 8 x 8

            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 512 x 4 x 4

            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
            # Output: 1 x 1 x 1
        )

    def forward(self, img):
        output = self.main(img)
        return output.view(-1, 1)  # (batch, 1, 1, 1) -> (batch, 1)

# Create DCGAN models
latent_dim = 100
gen = DCGANGenerator(latent_dim)
disc = DCGANDiscriminator()

print("DCGAN Architecture:")
print(f"Generator parameters: {sum(p.numel() for p in gen.parameters()):,}")
print(f"Discriminator parameters: {sum(p.numel() for p in disc.parameters()):,}")

# Test forward pass
z = torch.randn(4, latent_dim)
fake_imgs = gen(z)
validity = disc(fake_imgs)

print(f"\nGenerated image shape: {fake_imgs.shape}")
print(f"Discriminator output shape: {validity.shape}")
print("\n✓ DCGAN ready for training on image datasets!")
```

Output:

```
DCGAN Architecture:
Generator parameters: 3,444,480
Discriminator parameters: 2,633,216

Generated image shape: torch.Size([4, 1, 32, 32])
Discriminator output shape: torch.Size([4, 1])

✓ DCGAN ready for training on image datasets!
```
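Two practical details the DCGAN code above leaves out, sketched under stated assumptions. First, the DCGAN paper initializes weights from a zero-centered normal distribution with standard deviation 0.02; the hypothetical helper `weights_init` applies this through `Module.apply`, and also sets BatchNorm scales to N(1, 0.02) and shifts to 0, following the convention used in PyTorch's DCGAN tutorial. Second, the generator emits 32×32 images while MNIST digits are 28×28; padding 2 pixels on each side with the background value -1 is one option (resizing is another). The tensors `demo` and `mnist_like` are illustrative stand-ins:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def weights_init(m):
    """DCGAN-style init: conv weights ~ N(0, 0.02); BatchNorm scale ~ N(1, 0.02), shift = 0."""
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if m.bias is not None:  # The DCGAN layers above use bias=False
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.zeros_(m.bias)

# Usage: Module.apply visits every submodule recursively, e.g. gen.apply(weights_init)
demo = nn.Sequential(nn.ConvTranspose2d(8, 4, 4, 2, 1, bias=False), nn.BatchNorm2d(4), nn.ReLU(True))
demo.apply(weights_init)

# Stand-in for a batch of MNIST images already scaled to [-1, 1]
mnist_like = torch.rand(8, 1, 28, 28) * 2 - 1
# Pad (left, right, top, bottom) of the last two dims by 2 with -1 (black background after scaling)
padded = F.pad(mnist_like, (2, 2, 2, 2), value=-1.0)
print(padded.shape)  # torch.Size([8, 1, 32, 32])
```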

GAN Training Tips

**Common Challenges:**
1. **Mode Collapse**: The generator produces only a limited variety of samples
   - Solution: Use mini-batch discrimination or feature matching
2. **Training Instability**: Losses oscillate wildly
   - Solution: Lower the learning rates; use label smoothing
3. **Vanishing Gradients**: The generator stops learning once the discriminator becomes too confident
   - Solution: Use the non-saturating generator loss (as in `train_step` above) or a Wasserstein GAN with gradient penalty (WGAN-GP)

**Best Practices:**
- Use LeakyReLU in the discriminator
- Use BatchNorm in both networks, except the generator's output layer and the discriminator's input layer
- Use the Adam optimizer with β1 = 0.5
- Train the discriminator more than the generator early in training
- Use one-sided label smoothing (0.9 instead of 1.0 for real labels)

**Advanced GAN Variants:**
- **StyleGAN**: High-quality image synthesis
- **CycleGAN**: Unpaired image-to-image translation
- **Pix2Pix**: Paired image-to-image translation
- **BigGAN**: Large-scale image generation
- **Progressive GAN**: Gradually increases resolution during training
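A minimal sketch of two remedies from the list above; both function names are hypothetical. `smoothed_real_labels` builds one-sided smoothed targets for a BCE discriminator (fake targets stay 0). `gradient_penalty` computes the WGAN-GP term `lambda * E[(||grad_x D(x_hat)||_2 - 1)^2]` on random interpolates `x_hat` between real and fake samples, with lambda = 10 as in the WGAN-GP paper. Note that a WGAN critic outputs unbounded scores (no Sigmoid) and is trained on `E[D(fake)] - E[D(real)]` plus this penalty rather than with BCE; the WGAN-GP authors also advise against BatchNorm in the critic, because the penalty is defined per sample:

```python
import torch
import torch.nn as nn

def smoothed_real_labels(batch_size, value=0.9):
    """One-sided label smoothing: real targets become 0.9 instead of 1.0."""
    return torch.full((batch_size, 1), value)

def gradient_penalty(critic, real, fake, lambda_gp=10.0):
    """WGAN-GP penalty on random interpolates between real and fake batches."""
    b = real.size(0)
    # One interpolation coefficient per sample, broadcast over the remaining dims
    eps = torch.rand(b, *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real + (1 - eps) * fake.detach()).requires_grad_(True)
    scores = critic(x_hat)
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # Keep the graph so the penalty can be backpropagated
    )
    grad_norm = grads.reshape(b, -1).norm(2, dim=1)  # Per-sample L2 norm
    return lambda_gp * ((grad_norm - 1) ** 2).mean()

# Usage with a tiny linear critic (no Sigmoid: WGAN critics output raw scores)
critic = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1))
real = torch.rand(8, 1, 28, 28) * 2 - 1
fake = torch.rand(8, 1, 28, 28) * 2 - 1
gp = gradient_penalty(critic, real, fake)
# For a linear critic the input gradient is the weight vector w for every sample,
# so the penalty should equal 10 * (||w|| - 1)^2
w_norm = critic[1].weight.norm().item()
print(f"penalty={gp.item():.4f}, expected={10 * (w_norm - 1) ** 2:.4f}")
```

The linear critic is only a sanity check: because its input gradient is constant, the penalty has a closed form to compare against. In a real WGAN-GP critic step, `gp` is added to the critic loss before calling `backward()`.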

Sharan Initiatives - Making a Difference Together