Relativistic AnimeGAN

computervision
deeplearning
python
pytorch
Generating Anime Faces using Relativistic GAN
Author
Published

July 22, 2019

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch, os
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm.auto import tqdm  # tqdm_notebook is deprecated; tqdm.auto picks the right frontend

class Generator(nn.Module):
    def __init__(self, nz=128, channels=3):
        super(Generator, self).__init__()
        
        self.nz = nz
        self.channels = channels
        
        def convlayer(n_input, n_output, k_size = 4, stride = 2, padding = 0):
            block = [
                nn.ConvTranspose2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False),
                nn.BatchNorm2d(n_output),
                nn.ReLU(inplace=True),
            ]
            return block

        self.model = nn.Sequential(
            *convlayer(self.nz, 1024, 4, 1, 0),              # 1x1 -> 4x4
            *convlayer(1024, 512, 4, 2, 1),                  # 4x4 -> 8x8
            *convlayer(512, 256, 4, 2, 1),                   # 8x8 -> 16x16
            *convlayer(256, 128, 4, 2, 1),                   # 16x16 -> 32x32
            *convlayer(128, 64, 4, 2, 1),                    # 32x32 -> 64x64
            nn.ConvTranspose2d(64, self.channels, 3, 1, 1),  # keep 64x64, map to RGB
            nn.Tanh()                                        # outputs in [-1, 1]
        )

    def forward(self, z):
        z = z.view(-1, self.nz, 1, 1)
        img = self.model(z)
        return img

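A quick way to check the layer arithmetic above is to push a dummy latent batch through the generator; with the default nz=128, the output should be a batch of 3×64×64 images. A small sanity check (not part of the original notebook):

G = Generator()
z = torch.randn(4, 128)   # four latent vectors; forward() reshapes them to (4, 128, 1, 1)
print(G(z).shape)         # expected: torch.Size([4, 3, 64, 64])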

class Discriminator(nn.Module):
    def __init__(self, channels=3):
        super(Discriminator, self).__init__()
        
        self.channels = channels

        def convlayer(n_input, n_output, k_size = 4, stride = 2, padding = 0, bn = False):
            block = [nn.Conv2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False)]
            if bn:
                block.append(nn.BatchNorm2d(n_output))
            block.append(nn.LeakyReLU(0.2, inplace=True))
            return block

        self.model = nn.Sequential(
            *convlayer(self.channels, 32, 4, 2, 1),    # 64x64 -> 32x32
            *convlayer(32, 64, 4, 2, 1),               # 32x32 -> 16x16
            *convlayer(64, 128, 4, 2, 1, bn = True),   # 16x16 -> 8x8
            *convlayer(128, 256, 4, 2, 1, bn = True),  # 8x8 -> 4x4
            nn.Conv2d(256, 1, 4, 1, 0, bias = False),  # 4x4 -> 1x1 raw score (no sigmoid)
        )

    def forward(self, imgs):
        out = self.model(imgs)
        return out.view(-1, 1)
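
The discriminator mirrors this path, reducing a 3×64×64 image to a single raw score per sample; no sigmoid is applied, since the relativistic least-squares loss used later works on unbounded scores. A matching sanity check, under the same assumptions as above:

D = Discriminator()
x = torch.randn(4, 3, 64, 64)  # dummy image batch
print(D(x).shape)              # expected: torch.Size([4, 1])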

batch_size = 32
lr = 0.001
beta1 = 0.5       # Adam beta1; 0.5 is a common choice for GAN training
epochs = 300

real_label = 0.5  # target value used in the relativistic loss below
fake_label = 0    # defined for symmetry; not referenced in the training loop
nz = 128          # dimensionality of the latent vector

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class AnimeFacesDataset(Dataset):
    def __init__(self, img_dir, transform1=None, transform2=None):
    
        self.img_dir = img_dir
        self.img_names = os.listdir(img_dir)
        self.transform1 = transform1
        self.transform2 = transform2
        
        # Decode, resize, and crop every image once up front and keep the results in memory;
        # the random augmentations in transform2 are applied per access in __getitem__.
        self.imgs = []
        for img_name in self.img_names:
            img = Image.open(os.path.join(img_dir, img_name))

            if self.transform1 is not None:
                img = self.transform1(img)

            self.imgs.append(img)

    def __getitem__(self, index):
        img = self.imgs[index]
        
        if self.transform2 is not None:
            img = self.transform2(img)
        
        return img

    def __len__(self):
        return len(self.imgs)
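
The dataset above decodes and preprocesses all images once and keeps them in memory, which is reasonable here (the training log implies roughly 674 × 32 ≈ 21k small images). If the data did not fit in RAM, a lazy variant could open each file on demand instead; a minimal sketch, assuming the same directory layout, with a single transform that chains the resize/crop and augmentation steps:

class LazyAnimeFacesDataset(Dataset):
    """Sketch of an on-demand variant: images are decoded inside __getitem__."""
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.img_names = os.listdir(img_dir)
        self.transform = transform

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.img_dir, self.img_names[index])).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        return len(self.img_names)
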
# transform1: deterministic resize/crop, applied once when the dataset is built
transform1 = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64)
])

random_transforms = [transforms.RandomRotation(degrees = 5)]

# transform2: random augmentation, tensor conversion, and normalization to [-1, 1],
# applied every time an item is fetched
transform2 = transforms.Compose([
    transforms.RandomHorizontalFlip(p = 0.5),
    transforms.RandomApply(random_transforms, p = 0.3),
    transforms.ToTensor(),
    transforms.Normalize(
        (0.5, 0.5, 0.5),
        (0.5, 0.5, 0.5)
    )
])
                                 
train_dataset = AnimeFacesDataset(
    img_dir = '../input/data/data/',
    transform1 = transform1,
    transform2 = transform2
)

train_loader = DataLoader(
    dataset = train_dataset,
    batch_size = batch_size,
    shuffle = True,
    num_workers = 4
)
                                           
# Preview one batch of training images (denormalized from [-1, 1] back to [0, 1] for display)
imgs = next(iter(train_loader))
imgs = imgs.numpy().transpose(0, 2, 3, 1)
fig = plt.figure(figsize = (25, 16))
for ii, img in enumerate(imgs):
    ax = fig.add_subplot(4, 8, ii + 1, xticks = [], yticks = [])
    plt.imshow((img + 1) / 2)

netG = Generator(nz).to(device)
netD = Discriminator().to(device)

criterion = nn.BCELoss()  # not used below; the relativistic losses are written out directly as squared errors

optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

fixed_noise = torch.randn(25, nz, 1, 1, device=device)  # fixed latent batch (not used in the loop below)

# Generate and display a single sample from the current generator
def show_generated_img():
    noise = torch.randn(1, nz, 1, 1, device=device)
    gen_image = netG(noise).to("cpu").clone().detach().squeeze(0)
    gen_image = gen_image.numpy().transpose(1, 2, 0)
    plt.imshow((gen_image+1)/2)
    plt.show()
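
The loop below implements the relativistic average least-squares GAN objective: each score is compared against the mean score of the opposite class rather than against an absolute target. Writing out the errD and errG expressions, with D the raw discriminator output, b the value held in labels (real_label = 0.5 here), and batch means standing in for the expectations:

$$
L_D = \tfrac{1}{2}\Big(\mathbb{E}_{x}\big[(D(x) - \mathbb{E}_{z}[D(G(z))] - b)^2\big] + \mathbb{E}_{z}\big[(D(G(z)) - \mathbb{E}_{x}[D(x)] + b)^2\big]\Big)
$$

$$
L_G = \tfrac{1}{2}\Big(\mathbb{E}_{x}\big[(D(x) - \mathbb{E}_{z}[D(G(z))] + b)^2\big] + \mathbb{E}_{z}\big[(D(G(z)) - \mathbb{E}_{x}[D(x)] - b)^2\big]\Big)
$$
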
for epoch in range(epochs):
    for ii, real_images in tqdm(enumerate(train_loader), total=len(train_loader)):
        # ---- Discriminator update ----
        netD.zero_grad()
        real_images = real_images.to(device)
        batch_size = real_images.size(0)
        labels = torch.full((batch_size, 1), real_label, device=device)
        outputR = netD(real_images)                  # scores for real images
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        outputF = netD(fake.detach())                # scores for fakes, generator detached
        # Relativistic average LSGAN loss for D
        errD = (torch.mean((outputR - torch.mean(outputF) - labels) ** 2) +
                torch.mean((outputF - torch.mean(outputR) + labels) ** 2)) / 2
        errD.backward(retain_graph=True)             # the graph is reused for the generator update
        optimizerD.step()

        # ---- Generator update ----
        netG.zero_grad()
        outputF = netD(fake)                         # re-score fakes with the updated D, gradients flow into G
        # Relativistic average LSGAN loss for G (outputR still comes from the pre-update D)
        errG = (torch.mean((outputR - torch.mean(outputF) + labels) ** 2) +
                torch.mean((outputF - torch.mean(outputR) - labels) ** 2)) / 2
        errG.backward()
        optimizerG.step()

        if (ii + 1) % (len(train_loader) // 2) == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f'
                  % (epoch + 1, epochs, ii + 1, len(train_loader),
                     errD.item(), errG.item()))
    show_generated_img()
[1/300][337/674] Loss_D: 0.2228 Loss_G: 1.2557
[1/300][674/674] Loss_D: 0.4220 Loss_G: 2.8054

[2/300][337/674] Loss_D: 0.1080 Loss_G: 1.9196
[2/300][674/674] Loss_D: 0.1211 Loss_G: 1.6818

[3/300][337/674] Loss_D: 0.1240 Loss_G: 1.4734
[3/300][674/674] Loss_D: 0.0977 Loss_G: 1.9045

[4/300][337/674] Loss_D: 0.0401 Loss_G: 1.3443
[4/300][674/674] Loss_D: 0.0962 Loss_G: 1.3753

[5/300][337/674] Loss_D: 0.0621 Loss_G: 1.3746
[5/300][674/674] Loss_D: 0.1293 Loss_G: 2.1275

[6/300][337/674] Loss_D: 0.0967 Loss_G: 1.2415
[6/300][674/674] Loss_D: 0.0916 Loss_G: 1.2541

[7/300][337/674] Loss_D: 0.0688 Loss_G: 1.3309
[7/300][674/674] Loss_D: 0.0830 Loss_G: 1.0978

[8/300][337/674] Loss_D: 0.1210 Loss_G: 1.5478
[8/300][674/674] Loss_D: 0.1955 Loss_G: 1.7606

[9/300][337/674] Loss_D: 0.1035 Loss_G: 1.9064
[9/300][674/674] Loss_D: 0.1157 Loss_G: 0.9881

[10/300][337/674] Loss_D: 0.1077 Loss_G: 1.5257
[10/300][674/674] Loss_D: 0.1450 Loss_G: 2.0090

[11/300][337/674] Loss_D: 0.1498 Loss_G: 1.5121
[11/300][674/674] Loss_D: 0.3383 Loss_G: 2.4605

[12/300][337/674] Loss_D: 0.0704 Loss_G: 1.3353
[12/300][674/674] Loss_D: 0.1374 Loss_G: 1.5307

[13/300][337/674] Loss_D: 0.0670 Loss_G: 1.2464
[13/300][674/674] Loss_D: 0.1050 Loss_G: 1.3630

[14/300][337/674] Loss_D: 0.0757 Loss_G: 1.1511
[14/300][674/674] Loss_D: 0.3176 Loss_G: 1.4190

[15/300][337/674] Loss_D: 0.1194 Loss_G: 1.6059
[15/300][674/674] Loss_D: 0.0716 Loss_G: 1.3875

[16/300][337/674] Loss_D: 0.0843 Loss_G: 1.7008
[16/300][674/674] Loss_D: 0.0603 Loss_G: 0.9027

[17/300][337/674] Loss_D: 0.0414 Loss_G: 0.9865
[17/300][674/674] Loss_D: 0.2900 Loss_G: 1.0561

[18/300][337/674] Loss_D: 0.0567 Loss_G: 1.0921
[18/300][674/674] Loss_D: 0.0596 Loss_G: 1.2612

[19/300][337/674] Loss_D: 0.0700 Loss_G: 1.2655
[19/300][674/674] Loss_D: 0.0848 Loss_G: 1.5331

[20/300][337/674] Loss_D: 0.0707 Loss_G: 1.1320
[20/300][674/674] Loss_D: 0.0767 Loss_G: 1.6901

[21/300][337/674] Loss_D: 0.1091 Loss_G: 1.2790
[21/300][674/674] Loss_D: 0.1408 Loss_G: 1.6403

[22/300][337/674] Loss_D: 0.0439 Loss_G: 1.0915
[22/300][674/674] Loss_D: 0.0587 Loss_G: 1.3391

[23/300][337/674] Loss_D: 0.0383 Loss_G: 1.0457
[23/300][674/674] Loss_D: 0.0360 Loss_G: 1.2573

[24/300][337/674] Loss_D: 0.0391 Loss_G: 1.2543
[24/300][674/674] Loss_D: 0.0529 Loss_G: 1.3485

[25/300][337/674] Loss_D: 0.0862 Loss_G: 1.1176

# Sample latent vectors and display a grid of generated faces
gen_z = torch.randn(32, nz, 1, 1, device=device)
gen_images = (netG(gen_z).to("cpu").clone().detach() + 1)/2
gen_images = gen_images.numpy().transpose(0, 2, 3, 1)
fig = plt.figure(figsize=(25, 16))
for ii, img in enumerate(gen_images):
    ax = fig.add_subplot(4, 8, ii + 1, xticks=[], yticks=[])
    plt.imshow(img)
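
save_image is imported at the top but never used; to persist samples to disk, it could be applied to the generator output after the same denormalization used for plotting. A brief sketch (the file name is arbitrary):

with torch.no_grad():
    samples = netG(torch.randn(64, nz, 1, 1, device=device))
save_image((samples + 1) / 2, 'generated_samples.png', nrow=8)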