Generative Adversarial Networks

This is an introductory blog post about Generative Adversarial Networks with a TensorFlow tutorial.

Generative Adversarial Networks (GANs), pioneered by Ian Goodfellow, are a class of algorithms that allow us to generate data. A good chunk of research in deep learning has been in supervised learning, where we classify and detect objects in an image. GANs, on the other hand, are an unsupervised technique that lets us generate data similar to the dataset we trained them on. Imagine having a dataset of face images. A GAN will generate face images just like those in the dataset (but not exactly the same) without any labeled data. Check out the result of one of these GANs below.

Samples from a Boundary Equilibrium Generative Adversarial Network (BEGAN)

Amazing, eh? Lately, GANs have taken the deep learning community by storm. After all, they are something really powerful. This entire paradigm is exciting because it allows us to create data, giving us a wee bit more insight into neural networks. As of now, there are scores of GAN variants in the literature, but this blog post will confine itself to the original Generative Adversarial Network.

What I cannot create, I do not understand.

Okay, that was a loose Feynman reference. The idea is that if we can generate data, we gain some understanding of how these networks work internally (although we are still far away from that).

But what are GANs exactly? And how do they work?

GANs are a pair of neural networks, a generator and a discriminator, competing against each other in a zero-sum game. Here is the typical analogy given by experts in this field: consider a counterfeiter (the generator) trying to create fake money and the police (the discriminator) trying to identify whether the money is fake or real. These two players play against each other. In order to win, the counterfeiter must learn to create fake money that looks as real as possible, while the police must become really good at distinguishing fake money from real money.
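
In the notation used later in this post, with D the discriminator (the police), G the generator (the counterfeiter), X the real data and z a noise vector, the original GAN paper formalizes this zero-sum game as the minimax objective

\min_G \max_D \; \mathbb{E}_{X \sim p_{\mathrm{data}}}[\log D(X)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]

where D tries to push D(X) towards 1 and D(G(z)) towards 0, while G tries to make D(G(z)) as large as possible.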

GANs in Practice

In practice, each of the generator and the discriminator is modeled as a neural network, the simplest being a fully connected network. The generator takes a noise vector as input and outputs the target data (e.g. an image). The discriminator takes the generated data and samples from the real dataset as input and classifies each as fake or real. In theory, over time, the generator becomes really good at generating realistic data while the discriminator becomes good at telling fake from real.

In practice this hardly happens: GANs are notoriously hard to train. Some people have questioned the fully connected architecture used in the first publication, while a few others have pointed out probable issues in the loss formulation. Several publications subsequently appeared on arXiv trying to solve different issues, such as the Wasserstein GAN and the Boundary Equilibrium GAN.
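
To give a flavour of how such follow-ups change the objective, below is a minimal, illustrative sketch of the Wasserstein GAN losses. The tensors real_score and fake_score are hypothetical stand-ins for a critic that outputs raw (unsquashed) scores rather than the sigmoid probability used in the code later in this post.

import tensorflow as tf

'''hypothetical raw critic scores for a batch of real and of generated samples'''
real_score = tf.placeholder(tf.float32, shape=[None, 1])
fake_score = tf.placeholder(tf.float32, shape=[None, 1])

'''the WGAN critic maximizes E[score(real)] - E[score(fake)], written here as a loss to minimize;
the WGAN paper additionally clips the critic weights (e.g. to [-0.01, 0.01]) to enforce a Lipschitz constraint'''
wgan_critic_loss = tf.reduce_mean(fake_score) - tf.reduce_mean(real_score)

'''the WGAN generator maximizes E[score(fake)]'''
wgan_generator_loss = -tf.reduce_mean(fake_score)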

Implementing Generative Adversarial Networks

We’ll implement the simplest generative adversarial network in TensorFlow to get the intuition behind the idea. Let’s begin by coding up the generator and discriminator networks.


import tensorflow as tf
import numpy as np

'''weight and bias functions'''
def weight(shape, name):
    return tf.Variable(tf.truncated_normal(shape=shape, mean=0., stddev=0.01), dtype=tf.float32, name=name)

def bias(shape, name):
    return tf.Variable(tf.zeros(shape), dtype=tf.float32, name=name)

'''Sampling function to generate noise which will be the input to generator'''
def sample_z(shape):
    return np.random.normal(0., 1., shape)

'''defining the weights and biases for generator'''
gW1 = weight([100,128],'gW1')
gb1 = bias([128], 'gb1')
gW2 = weight([128, 784], 'gW2')
gb2 = bias([784], 'gb2')
theta_G = [gW1, gb1, gW2, gb2]


def generator(noise):
    #computational graph for generator
    h1 = tf.nn.relu(tf.matmul(noise, gW1) + gb1)
    gen_image = tf.nn.sigmoid(tf.matmul(h1, gW2) + gb2)
    return gen_image

'''defining the weights and biases for discriminator'''
dW1 = weight([784, 128], 'dW1')
db1 = bias([128], 'db1')
dW2 = weight([128, 1], 'dW2')
db2 = bias([1], 'db2')
theta_D = [dW1, db1, dW2, db2]

def critic(X):
    #computation graph for discriminator
    h1 = tf.nn.relu(tf.matmul(X, dW1) + db1)
    logit = tf.matmul(h1, dW2) + db2
    prob = tf.nn.sigmoid(logit)
    return prob
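
As a quick sanity check (this snippet is illustrative and not part of the original post), the generator should map a (batch, 100) noise tensor to a (batch, 784) image tensor, and the critic should map that to a single probability per sample:

'''shape check on a temporary placeholder'''
z_check = tf.placeholder(tf.float32, shape=[None, 100])
print(generator(z_check).shape)           # expect (?, 784)
print(critic(generator(z_check)).shape)   # expect (?, 1)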

Next we’ll define the two placeholders. One placeholder is for the noise that is the input to the generator, and the second is for real images from the dataset, which are the input to the discriminator.


'''Placeholders for noise vector and true images'''
z = tf.placeholder(shape=[None, 100], dtype=tf.float32, name='input_noise')
true_images = tf.placeholder(shape=[None, 784], dtype=tf.float32, name='true_images')

Both the generator and the discriminator have their own loss function, and each needs to be optimized. With D and G the discriminator and generator respectively, and z and X the noise vector and the true data, the losses from the original paper are

L_D = -\mathbb{E}[\log D(X) + \log(1 - D(G(z)))]
L_G = \mathbb{E}[\log(1 - D(G(z)))]

which the code below estimates by averaging over a batch.


'''build the graph: generated images, critic outputs on fake and real data'''
G_out = generator(z)
D_fake = critic(G_out)
D_true = critic(true_images)

'''original GAN losses: the critic maximizes log D(X) + log(1 - D(G(z))),
the generator minimizes log(1 - D(G(z)))'''
critic_loss = -tf.reduce_mean(tf.log(D_true) + tf.log(1 - D_fake))
generator_loss = tf.reduce_mean(tf.log(1 - D_fake))

'''each optimizer only updates its own network's variables'''
D_opt = tf.train.AdamOptimizer().minimize(critic_loss, var_list=theta_D)
G_opt = tf.train.AdamOptimizer().minimize(generator_loss, var_list=theta_G)
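
A practical aside (not part of the original code): tf.log applied to a sigmoid output blows up once the critic saturates at 0 or 1. A common, more stable formulation, sketched here with a hypothetical critic_logit helper that returns the pre-sigmoid value, uses tf.nn.sigmoid_cross_entropy_with_logits; the generator loss below is the non-saturating variant discussed in Goodfellow's tutorial (reference 1).

def critic_logit(X):
    '''same architecture as critic(), but returns the raw logit instead of a probability'''
    h1 = tf.nn.relu(tf.matmul(X, dW1) + db1)
    return tf.matmul(h1, dW2) + db2

logit_true = critic_logit(true_images)
logit_fake = critic_logit(G_out)

'''cross-entropy against labels 1 for real images and 0 for generated ones'''
stable_critic_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=logit_true, labels=tf.ones_like(logit_true)) +
    tf.nn.sigmoid_cross_entropy_with_logits(logits=logit_fake, labels=tf.zeros_like(logit_fake)))

'''non-saturating generator loss: push D(G(z)) towards 1'''
stable_generator_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=logit_fake, labels=tf.ones_like(logit_fake)))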

The training procedure for this network is straightforward: in each iteration we alternate one discriminator update with one generator update.


import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

'''load MNIST; images come flattened to 784-dim vectors with values in [0, 1]'''
mnist_dataset = input_data.read_data_sets('MNIST_data')

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
G_history = []
D_history = []

'''alternate one discriminator update and one generator update per iteration'''
for i in range(100000):
    _X, _ = mnist_dataset.train.next_batch(32)
    _, D_loss = sess.run([D_opt, critic_loss], feed_dict={true_images: _X, z: sample_z((32, 100))})
    _, G_loss = sess.run([G_opt, generator_loss], feed_dict={z: sample_z((32, 100))})
    G_history.append(G_loss)
    D_history.append(D_loss)
    if i % 1000 == 0:
        print("G_loss: {}, D_loss: {}".format(G_loss, D_loss))
        '''sample a single image from the generator and display it'''
        generated_image = sess.run(G_out, feed_dict={z: sample_z([1, 100])})
        plt.imshow(generated_image.reshape(28, 28), cmap='gray')
        plt.show()
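
G_history and D_history are collected in the loop above but never used; here is a minimal follow-up sketch to plot them once training finishes:

plt.plot(G_history, label='generator loss')
plt.plot(D_history, label='discriminator loss')
plt.xlabel('iteration')
plt.legend()
plt.show()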

References:

  1. Ian Goodfellow, "NIPS 2016 Tutorial: Generative Adversarial Networks": https://arxiv.org/pdf/1701.00160.pdf
  2. Ian Goodfellow et al., "Generative Adversarial Networks": https://arxiv.org/abs/1406.2661

Shashwat Verma

Software Engineer at Singapore-MIT Alliance for Research and Technology working on self driving vehicles.

Singapore