# Vanilla GAN
# Import all the necessary libraries
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm import notebook
# Define our simple vanilla generator
class Generator(nn.Module):
    """
    Map a latent noise tensor to an image with a simple MLP.

    Architecture
    ------------
    Latent Input: latent_shape (everything after the batch dim is flattened,
    so it must contain prod(latent_shape) elements per sample)
    Flattened
    Linear MLP(256, 512, 1024, channels * prod(img_shape))
    Leaky Relu (slope 0.2) activation after every layer except last. (Important!)
    Tanh activation after last layer, so every output value lies in [-1, 1]

    Parameters
    ----------
    latent_shape : tuple of int
        Per-sample shape of the latent noise fed to ``forward``.
    img_shape : tuple of int
        Per-channel spatial shape of the generated image, e.g. ``(28, 28)``.
    channels : int, default 1
        Number of image channels to generate. The default reproduces the
        original single-channel behaviour.

    Output
    ------
    ``forward`` returns a tensor of shape ``(batch, channels, *img_shape)``.
    """

    def __init__(self, latent_shape, img_shape, channels=1):
        super(Generator, self).__init__()
        self.img_shape = img_shape
        self.channels = channels
        # int(...) so nn.Linear gets plain Python ints instead of NumPy scalars
        latent_dim = int(np.prod(latent_shape))
        out_dim = channels * int(np.prod(img_shape))
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, out_dim),
            nn.Tanh()
        )

    def forward(self, x):
        batch_size = x.shape[0]
        # reshape the flat MLP output into a (batch, channels, *img_shape) image
        return self.mlp(x).reshape(batch_size, self.channels, *self.img_shape)
# Define our simple vanilla discriminator
class Discriminator(nn.Module):
    """
    Score images with an MLP; each score is the estimated probability the image is real.

    Layout
    ------
    Flatten the image (prod(img_shape) values per sample), then
    Linear -> 1024, LeakyReLU(0.2)
    Linear -> 512,  LeakyReLU(0.2)
    Linear -> 256,  LeakyReLU(0.2)
    Linear -> 1,    Sigmoid (one score in [0, 1] per sample)

    ``forward`` therefore returns a tensor of shape ``(batch, 1)``.
    """

    def __init__(self, img_shape):
        super().__init__()
        hidden_widths = (1024, 512, 256)
        layers = [nn.Flatten()]
        fan_in = np.prod(img_shape)
        for fan_out in hidden_widths:
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.LeakyReLU(0.2))
            fan_in = fan_out
        layers.append(nn.Linear(fan_in, 1))
        layers.append(nn.Sigmoid())
        # Layers are appended in the same order as a hand-written Sequential,
        # so module indices (and state_dict keys mlp.1, mlp.3, ...) are unchanged.
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.mlp(x)
        return scores
# load our data
latent_shape = (28, 28)
img_shape = (28, 28)
batch_size = 64
# ToTensor() yields pixel values in [0, 1]; Normalize((0.5,), (0.5,)) maps them to
# [-1, 1] so real images share the range of the generator's final Tanh. Without it,
# any negative pixel would give the discriminator a trivial "this is fake" cue.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # for gpu usage if possible
# .to(device) moves the networks / models to that device, which is either CPU or the GPU depending on what was detected
# if moved to GPU, then the networks can make use of the GPU for computations which is much faster!
# Each model is placed on its device before its parameters are handed to an optimizer.
generator = Generator(latent_shape, img_shape).to(device)
discriminator = Discriminator(img_shape).to(device)
gen_optim = torch.optim.Adam(generator.parameters(), lr=2e-4)
disc_optim = torch.optim.Adam(discriminator.parameters(), lr=2e-4)
def train(generator, discriminator, generator_optim: torch.optim.Optimizer,
          discriminator_optim: torch.optim.Optimizer, epochs=100, dataloader=None):
    """Train a GAN with alternating generator / discriminator updates.

    For every batch the generator is updated first, so that the discriminator
    labels its samples as real (1). The discriminator is then updated to label
    real images 1 and the (detached) generated images 0. Both losses are binary
    cross-entropy (``torch.nn.BCELoss``), so the discriminator must output
    probabilities in [0, 1].

    Parameters
    ----------
    generator : nn.Module
        Called on latent noise of shape ``(batch, 1, *latent_shape)``.
    discriminator : nn.Module
        Called on real and generated images; returns probabilities.
    generator_optim, discriminator_optim : torch.optim.Optimizer
        Optimizers over the generator's / discriminator's parameters.
    epochs : int
        Number of passes over ``dataloader``.
    dataloader : sized iterable, optional
        Yields batches whose first element is used as the real images (any
        other elements, e.g. labels, are ignored). Defaults to the module-level
        ``train_dataloader``.

    Notes
    -----
    Uses the module-level ``device`` and ``latent_shape`` globals. Prints
    per-epoch average losses and returns nothing.
    """
    if dataloader is None:
        dataloader = train_dataloader
    adversarial_loss = torch.nn.BCELoss()

    for epoch in range(1, epochs + 1):
        print("Epoch {}".format(epoch))
        total_g_loss = 0.0
        total_d_loss = 0.0
        num_batches = 0

        # notebook.tqdm is a nice way of displaying progress on a jupyter or colab notebook while we loop over the data
        pbar = notebook.tqdm(dataloader, total=len(dataloader))
        for data in pbar:
            num_batches += 1
            real_images = data[0].to(device)

            ### Train Generator ###
            # .zero_grad() is important in PyTorch. Don't forget it. If you do, gradients accumulate across steps.
            generator_optim.zero_grad()

            latent_input = torch.randn((len(real_images), 1, *latent_shape), device=device)
            fake_images = generator(latent_input)
            fake_res = discriminator(fake_images)

            # we penalize the generator for being unable to make the discriminator predict 1s for generated fake images
            generator_loss = adversarial_loss(fake_res, torch.ones_like(fake_res))
            # .backward() computes gradients for the loss function with respect to anything that is not detached
            # (this also fills the discriminator's .grad, which is cleared by its zero_grad() below)
            generator_loss.backward()
            # .step() uses an optimizer to apply the gradients to the model parameters, updating the model to reduce the loss
            generator_optim.step()

            ### Train Discriminator ###
            discriminator_optim.zero_grad()

            real_res = discriminator(real_images)
            # .detach() removes fake_images from gradient computation, so this
            # loss produces no gradients for the generator
            fake_res = discriminator(fake_images.detach())
            # we penalize the discriminator for not predicting 1s for real images
            discriminator_real_loss = adversarial_loss(real_res, torch.ones_like(real_res))
            # we penalize the discriminator for not predicting 0s for generated, fake images
            discriminator_fake_loss = adversarial_loss(fake_res, torch.zeros_like(fake_res))

            discriminator_loss = (discriminator_real_loss + discriminator_fake_loss) / 2

            discriminator_loss.backward()
            discriminator_optim.step()

            # read each scalar once (.item() syncs with the device)
            g_loss = generator_loss.item()
            d_loss = discriminator_loss.item()
            total_g_loss += g_loss
            total_d_loss += d_loss
            pbar.set_postfix({"G_loss": g_loss, "D_loss": d_loss})

        if num_batches == 0:
            # an empty loader would otherwise raise ZeroDivisionError below
            print("No batches in dataloader; skipping epoch averages")
            continue
        print("Avg G_loss {} - Avg D_loss {}".format(total_g_loss / num_batches, total_d_loss / num_batches))
# train our generator and discriminator
# Note: don't expect the losses of both models to go down at the same time. They are competing against each other, so
# at times one model will be ahead of the other.
train(generator=generator, discriminator=discriminator, generator_optim=gen_optim, discriminator_optim=disc_optim)
# test it out!
# Sampling only: no_grad() avoids building an autograd graph for the generated batch.
with torch.no_grad():
    latent_input = torch.randn((batch_size, 1, *latent_shape), device=device)
    test = generator(latent_input)
# Show the first sample; MNIST digits are single-channel, so render in grayscale.
plt.imshow(test[0].reshape(*img_shape).cpu().numpy(), cmap="gray")
plt.show()