DCGAN

computervision
deeplearning
keras
python
tensorflow
Implementation of a Deep Convolutional GAN (DCGAN) using Keras and TensorFlow
Author
Published

July 20, 2019

Project Repository: https://github.com/soumik12345/Adventures-with-GANS

::: {.cell _cell_guid="b1076dfc-b9ad-4769-8c92-a6c4dae69d19" _uuid="8f2839f25d086af736a60e9eeb907d3b93b6e0e5" execution_count=1}

import warnings

# Silence every Python-level warning so the notebook output stays readable.
# (TensorFlow's own "WARNING:tensorflow:..." log lines still appear in the
# captured output further down, so they evidently do not go through this filter.)
warnings.simplefilter('ignore')

:::

::: {.cell _cell_guid="79c7e3d0-c299-4dcb-8224-4455121ee9b0" _uuid="d629ff2d2480ee46fbb7e2d37f6b5fab8052498a" execution_count=2}

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import *
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from keras.utils.vis_utils import model_to_dot
from IPython.display import SVG
from tqdm import tqdm

:::

# Geometry of one training sample: 28x28 pixels with a single (grayscale) channel.
IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_CHANNELS = 28, 28, 1
IMAGE_SHAPE = (IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_CHANNELS)

# Training hyperparameters.
BATCH_SIZE = 128        # real images per step; the discriminator sees this many real + this many fake
LATENT_DIMENSION = 100  # length of the noise vector fed to the generator
EPOCHS = 8000           # NOTE: counts single mini-batch training steps, not passes over the dataset
def load_data():
    """Load the MNIST training images, scaled to [-1, 1] for a tanh-output generator.

    Labels and the test split are discarded. Pixel values (assumed uint8 in
    0-255, per the tf.keras ``mnist.load_data`` docs) are mapped linearly to
    [-1, 1], and a trailing channel axis is appended.

    Returns:
        float32 array of shape (num_images, 28, 28, 1); (60000, 28, 28, 1)
        for the standard MNIST training split.
    """
    (x_train, _), (_, _) = mnist.load_data()
    # Cast before scaling: the default float64 result would take twice the
    # memory (~376 MB vs ~188 MB for 60k images) and Keras computes in
    # float32 (its default floatx) anyway.
    x_train = x_train.astype('float32') / 127.5 - 1.
    # Append a channel axis: (N, 28, 28) -> (N, 28, 28, 1).
    x_train = np.expand_dims(x_train, axis = -1)
    return x_train
# Load the normalised MNIST training images via load_data (defined above).
x_train = load_data()
# Notebook display of the array shape; the captured output is (60000, 28, 28, 1).
x_train.shape
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
(60000, 28, 28, 1)
def build_generator(latent_dimension, optimizer):
    """Build and compile the generator: a noise vector -> one 28x28x1 image.

    The noise is projected by two dense layers into a 7x7x128 feature map,
    which is upsampled twice (7 -> 14 -> 28) with a 5x5 convolution after each
    upsampling. The final layer uses tanh, so pixels lie in [-1, 1].

    Args:
        latent_dimension: length of the input noise vector.
        optimizer: optimizer passed to ``compile``.

    Returns:
        The compiled ``Sequential`` model (loss: binary cross-entropy).
    """
    model = Sequential()

    # Project the noise vector and reshape it into a 7x7x128 feature map.
    model.add(Dense(256, input_dim = latent_dimension, activation = 'tanh'))
    model.add(Dense(128 * 7 * 7))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    model.add(Reshape((7, 7, 128)))

    # Two upsample + conv stages; the last one emits a single image channel.
    for filters in (64, 1):
        model.add(UpSampling2D(size = (2, 2)))
        model.add(Conv2D(filters, (5, 5), padding = 'same', activation = 'tanh'))

    model.compile(loss = 'binary_crossentropy', optimizer = optimizer)
    return model
def build_discriminator(image_shape, optimizer):
    """Build and compile the discriminator: an image -> probability that it is real.

    Two 5x5 convolution + 2x2 max-pool stages (28 -> 14, then a valid
    convolution 14 -> 10 and a pool to 5) feed a 1024-unit dense layer and a
    single sigmoid output.

    Args:
        image_shape: shape of one input image, e.g. ``(28, 28, 1)``.
        optimizer: optimizer passed to ``compile``.

    Returns:
        The compiled ``Sequential`` model (loss: binary cross-entropy).
    """
    feature_extractor = [
        Conv2D(64, (5, 5), padding = 'same', input_shape = image_shape, activation = 'tanh'),
        MaxPooling2D(pool_size = (2, 2)),
        Conv2D(128, (5, 5), activation = 'tanh'),  # default 'valid' padding: 14x14 -> 10x10
        MaxPooling2D(pool_size = (2, 2)),
    ]
    classifier = [
        Flatten(),
        Dense(1024, activation = 'tanh'),
        Dense(1, activation = 'sigmoid'),
    ]
    model = Sequential(feature_extractor + classifier)
    model.compile(loss = 'binary_crossentropy', optimizer = optimizer)
    return model
def build_gan(generator, discriminator, latent_dimension, optimizer):
    """Chain generator -> discriminator into the combined model that trains the generator.

    The discriminator is frozen (``trainable = False``) before this model is
    compiled, so training the combined model updates only the generator's
    weights.

    Args:
        generator: model mapping a ``(latent_dimension,)`` noise vector to an image.
        discriminator: model mapping an image to a single probability.
        latent_dimension: length of the noise vector fed to ``generator``.
        optimizer: optimizer to compile with. An optimizer object is cloned
            (same class and config) rather than used directly; any other value,
            e.g. a string name, is passed to ``compile`` unchanged.

    Returns:
        The compiled ``Model`` named ``'GAN'``. It is compiled with
        ``metrics=['accuracy']``, so ``train_on_batch`` returns
        ``[loss, accuracy]``.

    Side effects:
        Sets ``discriminator.trainable = False``.
    """
    discriminator.trainable = False
    gan_input = Input(shape = (latent_dimension, ))
    x = generator(gan_input)
    gan_output = discriminator(x)
    gan = Model(gan_input, gan_output, name = 'GAN')

    # Give the combined model its own optimizer instance. In this notebook the
    # same object also compiles the discriminator. One optimizer stepping two
    # models with disjoint variable sets shares its internal state (e.g. the
    # step counter used for Adam's bias correction), and newer tf.keras
    # optimizers (TF >= 2.11) raise an error when one instance is applied to
    # different variable sets. Cloning from the config keeps the hyperparameters.
    gan_optimizer = optimizer
    if hasattr(optimizer, 'get_config'):
        gan_optimizer = type(optimizer).from_config(optimizer.get_config())

    gan.compile(loss = 'binary_crossentropy', optimizer = gan_optimizer, metrics = ['accuracy'])
    return gan
# Adam with learning rate 2e-4; the second positional argument (0.5) is beta_1
# in the tf.keras Adam signature. This one instance is passed to all three
# builders below.
optimizer = Adam(2e-4, 0.5)
WARNING:tensorflow:From /opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
# Build the generator (noise of length LATENT_DIMENSION -> 28x28x1 image) and print its layer table.
generator = build_generator(LATENT_DIMENSION, optimizer)
generator.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 256)               25856     
_________________________________________________________________
dense_1 (Dense)              (None, 6272)              1611904   
_________________________________________________________________
batch_normalization_v1 (Batc (None, 6272)              25088     
_________________________________________________________________
activation (Activation)      (None, 6272)              0         
_________________________________________________________________
reshape (Reshape)            (None, 7, 7, 128)         0         
_________________________________________________________________
up_sampling2d (UpSampling2D) (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 14, 14, 64)        204864    
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 28, 28, 1)         1601      
=================================================================
Total params: 1,869,313
Trainable params: 1,856,769
Non-trainable params: 12,544
_________________________________________________________________
# Render the generator's layer graph inline as SVG; create(prog='dot') invokes Graphviz's `dot`.
# NOTE(review): model_to_dot is imported from standalone `keras` (keras.utils.vis_utils) while the
# models are built with tensorflow.keras — confirm they are compatible in this environment, or
# switch to tensorflow.keras.utils.model_to_dot.
SVG(model_to_dot(generator, show_shapes = True, show_layer_names = True).create(prog = 'dot', format = 'svg'))

# Build the discriminator (28x28x1 image -> single sigmoid probability) and print its layer table.
# It is compiled here, while still trainable, with the same `optimizer` object passed to the generator.
discriminator = build_discriminator(IMAGE_SHAPE, optimizer)
discriminator.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_2 (Conv2D)            (None, 28, 28, 64)        1664      
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 10, 10, 128)       204928    
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 5, 5, 128)         0         
_________________________________________________________________
flatten (Flatten)            (None, 3200)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 1024)              3277824   
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 1025      
=================================================================
Total params: 3,485,441
Trainable params: 3,485,441
Non-trainable params: 0
_________________________________________________________________
# Render the discriminator's layer graph inline as SVG (uses Graphviz's `dot` via prog='dot').
SVG(model_to_dot(discriminator, show_shapes = True, show_layer_names = True).create(prog = 'dot', format = 'svg'))

# Build the combined noise -> generator -> discriminator model. build_gan freezes the discriminator
# before compiling, which is why the summary reports all of the discriminator's parameters as
# non-trainable: 3,485,441 discriminator + 12,544 generator BatchNorm statistics = 3,497,985.
gan = build_gan(generator, discriminator, LATENT_DIMENSION, optimizer)
gan.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 100)               0         
_________________________________________________________________
sequential (Sequential)      (None, 28, 28, 1)         1869313   
_________________________________________________________________
sequential_1 (Sequential)    (None, 1)                 3485441   
=================================================================
Total params: 5,354,754
Trainable params: 1,856,769
Non-trainable params: 3,497,985
_________________________________________________________________
# Render the combined GAN's graph inline as SVG (uses Graphviz's `dot` via prog='dot').
SVG(model_to_dot(gan, show_shapes = True, show_layer_names = True).create(prog = 'dot', format = 'svg'))

def plot_images(nrows, ncols, figsize, generator):
    """Display an ``nrows x ncols`` grid of images sampled from ``generator``.

    Draws ``nrows * ncols`` standard-normal noise vectors of length
    ``LATENT_DIMENSION``, runs them through the generator and shows each
    result as a grayscale image with axis ticks hidden.

    Args:
        nrows, ncols: grid dimensions; any size, including 1x1, is supported.
        figsize: matplotlib figure size ``(width, height)`` in inches.
        generator: model whose ``predict`` output can be reshaped to
            ``(nrows * ncols, IMAGE_WIDTH, IMAGE_HEIGHT)``.
    """
    # squeeze=False keeps `axes` a 2-D array for every grid size. With the
    # default squeeze=True, a 1x1 grid returns a bare Axes, which has no
    # `.flat` attribute, so both calls below would fail.
    _, axes = plt.subplots(nrows = nrows, ncols = ncols, figsize = figsize, squeeze = False)
    plt.setp(axes.flat, xticks = [], yticks = [])
    noise = np.random.normal(0, 1, (nrows * ncols, LATENT_DIMENSION))
    # Drop the trailing channel axis so imshow receives 2-D grayscale images.
    generated_images = generator.predict(noise).reshape(nrows * ncols, IMAGE_WIDTH, IMAGE_HEIGHT)
    for ax, image in zip(axes.flat, generated_images):
        ax.imshow(image, cmap = 'gray')
    plt.show()
# Per-step losses. Each discriminator entry is a scalar loss. Each generator
# entry is whatever gan.train_on_batch returns; the gan built by build_gan
# above is compiled with metrics=['accuracy'], so that is [loss, accuracy].
generator_loss_history, discriminator_loss_history = [], []

# NOTE: despite the name, each iteration is ONE mini-batch update (a
# discriminator step followed by a generator step), not a full pass over x_train.
for epoch in tqdm(range(1, EPOCHS + 1)):

    # Select a random batch of real images. Indices are drawn independently,
    # so a batch may contain duplicates.
    index = np.random.randint(0, x_train.shape[0], BATCH_SIZE)
    batch_images = x_train[index]

    # Noise for the fake half of the discriminator's batch.
    noise = np.random.normal(0, 1, (BATCH_SIZE, LATENT_DIMENSION))

    # Generate fake images. predict_on_batch runs one forward pass on this
    # batch; Model.predict is meant for large inputs and adds per-call
    # overhead when invoked once per step inside a loop.
    generated_images = generator.predict_on_batch(noise)

    # Construct one batch of real images followed by fake images.
    x = np.concatenate([batch_images, generated_images])

    # Labels for the discriminator: 0.9 for real (one-sided label smoothing), 0 for fake.
    y_discriminator = np.zeros(2 * BATCH_SIZE)
    y_discriminator[: BATCH_SIZE] = 0.9

    # Train the discriminator to distinguish between real and fake data.
    discriminator.trainable = True
    discriminator_loss = discriminator.train_on_batch(x, y_discriminator)
    discriminator_loss_history.append(discriminator_loss)
    discriminator.trainable = False

    # Train the generator through the frozen discriminator. Fresh noise is
    # sampled so the generator is not updated on the exact batch the
    # discriminator just trained on. Targets are all 1 ("real"), so the
    # generator is rewarded for fooling the discriminator.
    noise = np.random.normal(0, 1, (BATCH_SIZE, LATENT_DIMENSION))
    generator_loss = gan.train_on_batch(noise, np.ones(BATCH_SIZE))
    generator_loss_history.append(generator_loss)

    # Periodically show a row of samples to monitor progress.
    if epoch % 1000 == 0:
        plot_images(1, 8, (16, 4), generator)
WARNING:tensorflow:From /opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.

# Final sample grid (2 rows x 8 columns) from the trained generator.
plot_images(2, 8, (16, 6), generator)

# Save the generator to an HDF5 file in the working directory. Only the generator is saved here;
# the discriminator and the combined GAN model are not.
generator.save('./generator.h5')