Generative Adversarial Networks

Implementation of "Generative adversarial networks".

Introduction

Generative adversarial networks are one of the most influential generative models. A generative model replicates data with a statistical model; formally, this amounts to estimating the joint probability distribution of the outputs and inputs. Maximizing the joint likelihood directly is one option, but it is often impractical because the computational cost grows as the dimension of the input space increases. Many methods for estimating generative models have been developed, for example Restricted Boltzmann Machines, Variational Auto-Encoders and Generative Adversarial Networks. This tutorial deals with generative adversarial networks (henceforth referred to as GAN). GAN was first published by Goodfellow et al. (2014)[1], and several extensions and many applications have been proposed since. For instance, GANs can generate high-resolution images[2] and translate an image into a corresponding image, e.g. a satellite photograph into a map[3].

Required libraries

In [1]:
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
import time
import numpy as np
import renom as rm
from renom.cuda.cuda import set_cuda_active
from renom.optimizer import Adam
from renom.utility.initializer import Uniform

We strongly recommend using a GPU, because GANs require careful hyperparameter tuning, which means many training runs, and usually deal with images.

In [2]:
set_cuda_active(True)

Downloading MNIST data

In [3]:
# fetch_mldata no longer works (mldata.org is offline and the function was removed
# in scikit-learn 0.22); fetch_openml downloads the same 70,000 MNIST images.
data_path = "."
mnist = fetch_openml('mnist_784', version=1, data_home=data_path, as_frame=False)

Preprocessing and splitting data

In [4]:
X = mnist.data
y = mnist.target

npx = 28  # the number of pixels on a side of the image
ntest = 20000
ntrain = len(X) - ntest

X1 = X.reshape(len(X), 1, npx, npx) / 255
Y = OneHotEncoder().fit_transform(y.reshape(-1, 1)).toarray()  # one-hot encode the 10 digit labels
Y = Y.astype(np.float32)
X_train, X_test, y_train, y_test = train_test_split(X1, Y, test_size=ntest)

Generative adversarial networks (GAN)

GAN consists of two networks: a generator and a discriminator. The generator plays the role of a forger that produces fake data, whereas the discriminator plays the role of a detective that distinguishes actual data from the fake data produced by the generator. We regard the generator and the discriminator as players in the following two-player zero-sum game. Within the game their relationship is adversarial, but from the viewpoint of learning a generative model it can be regarded as cooperative. Let D(\cdot) and G(\cdot) be the discriminator and the generator, respectively. D(\mathbf{x}) is the probability that \mathbf{x} is actual data. The input of G is a random vector \mathbf{z} sampled from a uniform or normal distribution. We assume the discriminator's payoff is the following value function.

\begin{equation} V(D, G) = \mathbb{E}_{\mathbf{x}\sim p_{data}(\mathbf{x})}[ \log D(\mathbf{x}) ] + \mathbb{E}_{\mathbf{z}\sim p_{\mathbf{z}}(\mathbf{z})}[ \log (1-D(G(\mathbf{z}))) ] \end{equation}

, where p_{data}(\mathbf{x}) and p_\mathbf{z}(\mathbf{z}) are the probability density functions of the actual data and of \mathbf{z} , respectively. We also assume that the generator's payoff is -V(D,G) . Multiplied by -1, the value function is the cross-entropy loss that the discriminator minimizes, while V(D,G) itself can serve as the loss that the generator minimizes. Each player tries to maximize its own payoff, so the equilibrium of this game is given by the optimization problem

\begin{equation} \min_G \max_D V(D, G). \end{equation}
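In practice, the expectations in V(D, G) are approximated by averages over samples. The NumPy sketch below does this for a hypothetical one-dimensional problem, assuming actual data from N(0, 1), noise from U(0, 1) and a fixed generator G(z) = 4z - 2 (so that p_g is U(-2, 2)); the names estimate_value, D_chance and D_opt are illustrative and not part of ReNom. It compares a discriminator that always outputs 1/2 with the optimal discriminator derived later in this section, which attains a larger value for this fixed G.

```python
import numpy as np

def estimate_value(D, G, x_real, z, eps=1e-12):
    """Monte Carlo estimate of V(D, G) from real samples x_real and noise samples z."""
    real_term = np.mean(np.log(D(x_real) + eps))      # approximates E_x[log D(x)]
    fake_term = np.mean(np.log(1.0 - D(G(z)) + eps))  # approximates E_z[log(1 - D(G(z)))]
    return real_term + fake_term

rng = np.random.default_rng(0)
x_real = rng.normal(0.0, 1.0, size=10_000)  # toy "actual data" drawn from N(0, 1)
z = rng.uniform(0.0, 1.0, size=10_000)      # noise drawn from U(0, 1)

G = lambda z: 4.0 * z - 2.0                 # hypothetical generator: U(0, 1) -> U(-2, 2)
D_chance = lambda x: np.full_like(x, 0.5)   # a discriminator that always answers 1/2

def D_opt(x):
    """Optimal discriminator p_data / (p_data + p_g) for this toy setup."""
    p_data = np.exp(-x**2 / 2) / np.sqrt(2 * np.pi)  # N(0, 1) density
    p_g = np.where(np.abs(x) <= 2.0, 0.25, 0.0)      # U(-2, 2) density
    return p_data / (p_data + p_g)

v_chance = estimate_value(D_chance, G, x_real, z)  # equals -log(4) up to eps
v_opt = estimate_value(D_opt, G, x_real, z)        # larger: D_opt maximizes V for this G
print(v_chance, v_opt)
```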

Goodfellow et al. (2014)[1] showed that, at the equilibrium, the distribution of the generated data equals that of the actual data, and D(\mathbf{x}) = 1/2 for all \mathbf{x} . In other words, even the optimal discriminator cannot tell whether data produced by the optimal generator are fake. We sketch the proof below.

We first consider the discriminator, because in this setting it moves before the generator: for any given G , the discriminator chooses D to maximize V(G,D) . Using the change of variables \mathbf{x} = G(\mathbf{z}) in the second term, V(G,D) can be rewritten as

\begin{align} V(G,D) &= \int_{\mathbf{x}} p_{data}(\mathbf{x}) \log (D(\mathbf{x}))d\mathbf{x} + \int_{\mathbf{z}}p_\mathbf{z} (\mathbf{z}) \log (1 - D(G(\mathbf{z}))) d\mathbf{z} \\ &= \int_{\mathbf{x}} \left[ p_{data}(\mathbf{x}) \log (D(\mathbf{x})) + p_g (\mathbf{x}) \log (1 - D(\mathbf{x})) \right] d\mathbf{x} \end{align}

, where p_g(\mathbf{x}) is the probability density function of the generated data G(\mathbf{z}) on the space of \mathbf{x} . For any a, b \ge 0 with a + b > 0 , the function f(y) = a\log (y) + b\log(1-y) attains its maximum on [0,1] at y = \frac{a}{a+b} . Note that D(\mathbf{x}) takes values in [0,1] because the discriminator returns the probability that its input is actual data. Applying this result pointwise with a = p_{data}(\mathbf{x}) and b = p_g(\mathbf{x}) , the optimal discriminator D^\ast is

\begin{equation} D^\ast(\mathbf{x}) = \frac{p_{data} (\mathbf{x})}{p_{data} (\mathbf{x}) + p_{g} (\mathbf{x})}. \end{equation}
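The location of the maximum of f(y) used above follows from elementary calculus, assuming a, b > 0 :

\begin{align} f'(y) &= \frac{a}{y} - \frac{b}{1-y} = 0 \iff a(1-y) = b\,y \iff y = \frac{a}{a+b}, \\ f''(y) &= -\frac{a}{y^2} - \frac{b}{(1-y)^2} < 0 \quad \text{for } 0 < y < 1. \end{align}

Hence f is strictly concave on (0,1) and y = \frac{a}{a+b} is its unique maximizer.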

Next, we consider the generator. Given D^\ast , the generator's problem reduces to minimizing the following function C(G) :

\begin{align} C(G) &= \max_{D} V(G, D) \\ &= \mathbb{E}_{\mathbf{x}\sim p_{data}(\mathbf{x})}[ \log D^\ast(\mathbf{x}) ] + \mathbb{E}_{\mathbf{z}\sim p_{\mathbf{z}}(\mathbf{z})}[ \log (1-D^\ast(G(\mathbf{z}))) ] \\ &= \mathbb{E}_{\mathbf{x}\sim p_{data}(\mathbf{x})}[ \log D^\ast(\mathbf{x}) ] + \mathbb{E}_{\mathbf{x}\sim p_{g}(\mathbf{x})}[ \log (1-D^\ast(\mathbf{x})) ] \\ &= \mathbb{E}_{\mathbf{x}\sim p_{data}(\mathbf{x})}\left[ \log \frac{p_{data} (\mathbf{x})}{p_{data} (\mathbf{x}) + p_{g} (\mathbf{x})} \right] + \mathbb{E}_{\mathbf{x}\sim p_{g}(\mathbf{x})}\left[ \log\frac{p_{g} (\mathbf{x})}{p_{data} (\mathbf{x}) + p_{g} (\mathbf{x})} \right] \\ &= -\log(4) + \mathbb{E}_{\mathbf{x}\sim p_{data}(\mathbf{x})}\left[ \log\frac{p_{data} (\mathbf{x})}{\left\{p_{data} (\mathbf{x}) + p_{g} (\mathbf{x})\right\}/2} \right] + \mathbb{E}_{\mathbf{x}\sim p_{g}(\mathbf{x})}\left[ \log\frac{p_{g} (\mathbf{x})}{\left\{p_{data} (\mathbf{x}) + p_{g} (\mathbf{x})\right\}/2} \right] \\ &= -\log(4) + KL\left( p_{data} \| \frac{p_{data} (\mathbf{x}) + p_{g} (\mathbf{x})}{2} \right) + KL\left( p_{g} \| \frac{p_{data} (\mathbf{x}) + p_{g} (\mathbf{x})}{2} \right) \\ &= -\log(4) + 2 \cdot JSD(p_{data} \| p_g) \\ \end{align}

, where KL is the Kullback-Leibler divergence and JSD is the Jensen-Shannon divergence, JSD(p \| q) = \frac{1}{2} KL(p \| m) + \frac{1}{2} KL(q \| m) with m = (p + q)/2 . Both divergences measure the discrepancy between two distributions. The Jensen-Shannon divergence is non-negative and equals zero only when the two distributions are equal. Hence the minimum of C(G) , namely -\log(4) , is attained if and only if p_{g} = p_{data} , which also gives D^\ast(\mathbf{x}) = 1/2 . In other words, in this idealized setting (networks with enough capacity and a discriminator trained to optimality), the GAN replicates the actual data-generating process at convergence.
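The identity C(G) = -\log(4) + 2 \cdot JSD(p_{data} \| p_g) can be checked numerically. The sketch below assumes discrete distributions on four points (the probability vectors are arbitrary illustrative values, and the helper names kl, jsd and c_of_g are hypothetical), evaluates C(G) directly with D^\ast , and compares it with the right-hand side.

```python
import numpy as np

def kl(p, q):
    """KL(p || q) for discrete distributions; terms with p = 0 contribute 0."""
    mask = p > 0
    return np.sum(p[mask] * np.log(p[mask] / q[mask]))

def jsd(p, q):
    """Jensen-Shannon divergence: average KL of p and q to their mixture m."""
    m = (p + q) / 2
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def c_of_g(p_data, p_g):
    """C(G) = V(D*, G), evaluated with the optimal discriminator D* = p_data / (p_data + p_g)."""
    d_star = p_data / (p_data + p_g)
    real = np.sum(p_data[p_data > 0] * np.log(d_star[p_data > 0]))  # E_{p_data}[log D*]
    fake = np.sum(p_g[p_g > 0] * np.log(1 - d_star[p_g > 0]))       # E_{p_g}[log(1 - D*)]
    return real + fake

p_data = np.array([0.1, 0.2, 0.3, 0.4])
p_g = np.array([0.25, 0.25, 0.25, 0.25])

print(c_of_g(p_data, p_g), -np.log(4) + 2 * jsd(p_data, p_g))  # the two values agree
print(c_of_g(p_data, p_data))  # -log(4), about -1.386, when p_g = p_data
```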

Problems

Convergence at equilibrium is NOT guaranteed.

  • Oscillation
    • The GAN losses may keep oscillating even after many epochs.
  • Mode collapse
    • The generator may produce only part of the data distribution (e.g. only a few kinds of digits) and remain stuck at such a local optimum if the discriminator does not learn enough.
  • Evaluation of learning
    • A decrease in either of the two loss functions does not directly indicate that learning has improved,
    • because what matters is how much the generator and the discriminator learn relative to each other.
  • Gradient vanishing on generator
    • Given D, minimizing the generator's loss function is equivalent to minimizing \mathbb{E}_{\mathbf{z}\sim p_{\mathbf{z}}(\mathbf{z})}[ \log (1-D(G(\mathbf{z}))) ] , the part of V(D,G) that depends on G.
    • If the discriminator learns perfectly, D(G(\mathbf{z})) \approx 0 holds for every \mathbf{z} , and the gradient of \mathbb{E}_{\mathbf{z}\sim p_{\mathbf{z}}(\mathbf{z})}[ \log (1-D(G(\mathbf{z}))) ] becomes very small.
    • However, the discriminator has to learn first to avoid mode collapse.
    • Therefore, \mathbb{E}_{\mathbf{z}\sim p_{\mathbf{z}}(\mathbf{z})}[ -\log D(G(\mathbf{z})) ] (called the "non-saturating" loss below) is usually minimized instead[4].
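The saturation is easiest to see in terms of the discriminator's logit l for a generated sample, so that D(G(\mathbf{z})) = \sigma(l) . The derivative of \log(1-\sigma(l)) with respect to l is -\sigma(l) , which vanishes as the discriminator confidently rejects the sample ( l \to -\infty ), whereas the derivative of -\log \sigma(l) is -(1-\sigma(l)) , which stays close to -1. The NumPy sketch below checks this with central finite differences; the function names are illustrative.

```python
import numpy as np

def sigmoid(l):
    return 1.0 / (1.0 + np.exp(-l))

def minimax_gen_loss(l):
    """Generator term of V(D, G) for one fake sample with logit l: log(1 - D)."""
    return np.log1p(-sigmoid(l))

def non_saturating_gen_loss(l):
    """Non-saturating generator loss for the same sample: -log D."""
    return -np.log(sigmoid(l))

def numerical_grad(f, l, h=1e-5):
    """Central finite-difference derivative of f at l."""
    return (f(l + h) - f(l - h)) / (2 * h)

# A more negative logit means the discriminator is more certain the sample is fake.
for l in [0.0, -5.0, -10.0]:
    print(f"D={sigmoid(l):.5f}  "
          f"minimax grad={numerical_grad(minimax_gen_loss, l):+.5f}  "
          f"non-saturating grad={numerical_grad(non_saturating_gen_loss, l):+.5f}")
```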

Empirical tricks

To alleviate the problems of GAN, several empirical tricks have been proposed.

  • Use batch normalization only in the generator
    • but do not use batch normalization in the last layer of the generator
  • Use dropout in the discriminator
  • Calculate the discriminator's accuracy on actual data and on generated data
    • Evaluate learning by referring to these accuracies rather than to the losses
    • Both accuracies should converge to 0.5
  • Let the discriminator learn first
  • If the generator fails to replicate any of the data, the network structure itself may be the problem

Hyper parameters

In [5]:
nc = 1            # number of channels in the image
nz = 100          # dimension of the noise vector z
nbatch = 256      # number of examples in a minibatch
ngfc = 1200       # number of generator units in each fully connected layer
ndfc = 240        # number of maxout units in each discriminator layer
ndp = 5           # number of pieces (linear units) per maxout unit in the discriminator

Definition of networks

We define the network classes below. The discriminator's input is either actual data or fake data generated by the generator. Its output is a logit, so applying the sigmoid function to it gives the probability that the input comes from the actual data. Following the original paper[1], the discriminator uses maxout units, implemented here with max-pooling layers; more recent discriminators usually use leaky ReLU activations instead.

In [6]:
class discriminator(rm.Model):

    def __init__(self):
        super(discriminator, self).__init__()
        self._full1 = rm.Dense(ndfc*ndp, initializer=Uniform(-0.005, 0.005))
        self._full2 = rm.Dense(ndfc*ndp, initializer=Uniform(-0.005, 0.005))
        self._full3 = rm.Dense(1, initializer=Uniform(-0.005, 0.005))
        self._maxpool1 = rm.MaxPool2d(filter=(1, ndp), stride=(1, 1), padding=0)
        self._maxpool2 = rm.MaxPool2d(filter=(1, ndp), stride=(1, 1), padding=0)
        self.dr1 = rm.Dropout()
        self.dr2 = rm.Dropout()


    def forward(self, x):
        bs = len(x)
        h = rm.reshape(self._full1(x), (bs, nc, ndfc, ndp))
        h1 = rm.reshape(self.dr1(self._maxpool1(h)), (bs, -1))
        h1 = rm.reshape(self._full2(h1), (bs, nc, ndfc, ndp))
        h1 = rm.reshape(self.dr2(self._maxpool2(h1)), (bs, -1))
        h2 = self._full3(h1)
        return h2
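Reshaping to (bs, nc, ndfc, ndp) and max-pooling with a (1, ndp) filter takes the maximum over each group of ndp consecutive outputs of the dense layer, which is a maxout layer with ndfc units and ndp pieces. A NumPy sketch of the same computation, assuming rm.reshape uses C-order like NumPy; the function name maxout and the random input are illustrative.

```python
import numpy as np

def maxout(h, n_units, n_pieces):
    """Maxout: split n_units * n_pieces linear outputs into groups of n_pieces and keep each group's max."""
    bs = h.shape[0]
    return h.reshape(bs, n_units, n_pieces).max(axis=-1)

ndfc, ndp = 240, 5
rng = np.random.default_rng(0)
h = rng.normal(size=(8, ndfc * ndp))  # stand-in for the output of self._full1(x)
out = maxout(h, ndfc, ndp)
print(out.shape)  # (8, 240)
```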

The generator's input is a noise vector drawn from a uniform distribution, following the original paper[1], although a normal distribution is more commonly used today.

In [7]:
class generator(rm.Model):

    def __init__(self):
        super(generator, self).__init__()
        self._full1 = rm.Dense(ngfc, initializer=Uniform(-0.05, 0.05))
        self._full2 = rm.Dense(ngfc, initializer=Uniform(-0.05, 0.05))
        self._full3 = rm.Dense(npx*npx, initializer=Uniform(-0.05, 0.05))
        self.bn1 = rm.BatchNormalize()
        self.bn2 = rm.BatchNormalize()
        self.bn3 = rm.BatchNormalize()

    def forward(self, z):
        z = self.bn1(z)
        z1 = rm.relu(self._full1(z))
        z1 = self.bn2(z1)
        z2 = rm.relu(self._full2(z1))
        z2 = self.bn3(z2)
        z3 = self._full3(z2)
        x = rm.sigmoid(z3)
        return x

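In the combined model gan_origin, the minimax option selects the generator's loss function: if minimax=True, the generator minimizes the negated discriminator loss (the minimax objective); otherwise (the default) it minimizes the non-saturating loss \mathbb{E}_{\mathbf{z}\sim p_{\mathbf{z}}(\mathbf{z})}[ -\log D(G(\mathbf{z})) ] , computed as the sigmoid cross entropy of the generated data's logits against the label 1.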

In [8]:
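class gan_origin(rm.Model):
    def __init__(self, gen, dis, minimax=False):
        super(gan_origin, self).__init__()
        self.gen = gen
        self.dis = dis
        self.minimax = minimax  # True: minimax generator loss, False: non-saturating loss

    def forward(self, x):
        bs = len(x)
        x = rm.reshape(x, (bs, -1))  # flatten images into vectors of length npx*npx
        # Noise z drawn from U(0, 1)^nz, as in the original paper
        z = np.random.rand(bs, nz).astype(np.float32)
        self.x_gen = self.gen(z)
        # Discriminator logits and probabilities for actual and generated data
        self.real_dis = self.dis(x)
        self.fake_dis = self.dis(self.x_gen)
        self.prob_real = rm.sigmoid(self.real_dis)
        self.prob_fake = rm.sigmoid(self.fake_dis)
        # Minibatch estimate of -V(D, G): label 1 for actual data, 0 for generated data
        self.dis_loss_real = rm.sigmoid_cross_entropy(self.real_dis, np.ones(bs).reshape(-1,1))
        self.dis_loss_fake = rm.sigmoid_cross_entropy(self.fake_dis, np.zeros(bs).reshape(-1,1))
        self.dis_loss = self.dis_loss_real + self.dis_loss_fake
        if self.minimax:
            # The generator minimizes V(D, G), i.e. the negated discriminator loss
            self.gen_loss = -self.dis_loss
        else:
            # Non-saturating loss -log D(G(z)): label 1 for generated data
            self.gen_loss = rm.sigmoid_cross_entropy(self.fake_dis, np.ones(bs).reshape(-1,1))

        return self.dis_loss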
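To connect gan_origin with the formulas above, the NumPy sketch below re-implements binary cross entropy on a logit and checks, for one real and one fake sample, that dis_loss equals -V(D,G) and that the non-saturating generator loss equals -\log D(G(\mathbf{z})) . The function name sigmoid_cross_entropy and the logit values are illustrative assumptions; ReNom's rm.sigmoid_cross_entropy additionally aggregates over the minibatch.

```python
import numpy as np

def sigmoid_cross_entropy(logit, target):
    """Binary cross entropy of sigmoid(logit) against target, in a numerically stable form."""
    return np.maximum(logit, 0) - logit * target + np.log1p(np.exp(-np.abs(logit)))

real_logit, fake_logit = 1.5, -0.7  # hypothetical discriminator outputs
d_real = 1 / (1 + np.exp(-real_logit))
d_fake = 1 / (1 + np.exp(-fake_logit))

dis_loss = sigmoid_cross_entropy(real_logit, 1.0) + sigmoid_cross_entropy(fake_logit, 0.0)
value = np.log(d_real) + np.log(1 - d_fake)  # single-sample estimate of V(D, G)
non_saturating = sigmoid_cross_entropy(fake_logit, 1.0)

print(np.isclose(dis_loss, -value))                 # True: dis_loss = -V
print(np.isclose(non_saturating, -np.log(d_fake)))  # True: -log D(G(z))
```

The stable form avoids overflow in exp for large negative logits, which is exactly the regime where the discriminator confidently rejects generated samples.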
In [9]:
def imshow(image_set, nrows=4, ncols=10, figsize=None, save=False):
    """Show the first nrows*ncols images of image_set in a grid without spacing."""
    plot_num = nrows * ncols
    if figsize is None:
        figsize = (10, 10*nrows/ncols)
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
    fig.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    ax = ax.ravel()
    for i in range(plot_num):
        # Negate the pixel values so that the digits appear dark on a light background
        ax[i].imshow(-image_set[i].reshape(npx, npx), cmap="gray")
        ax[i].set_xticks([])
        ax[i].set_yticks([])
    if save is not False:
        plt.savefig(save + ".png")

In [10]:
imshow(X_train)
[Figure: the first 40 training images, shown in a 4 x 10 grid]

Training loop

In [11]:
dis_opt = rm.Adam(b=0.5)
gen_opt = rm.Adam(b=0.5)

dis = discriminator()
gen = generator()
gan = gan_origin(gen, dis)

epoch = 30
N = len(X_train)
batch_size = 256
loss_curve_dis = []
loss_curve_gen = []
acc_curve_real = []
acc_curve_fake = []

test_image = []


start = time.time()
for i in range(1, epoch+1):
    perm = np.random.permutation(N)
    total_loss_dis = 0
    total_loss_gen = 0
    total_acc_real = 0
    total_acc_fake = 0
    for j in range(N//batch_size):
        index = perm[j*batch_size:(j+1)*batch_size]
        train_batch_X = X_train[index]
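        with gan.train():
            dis_loss = gan(train_batch_X)
        # Discriminator step: the generator's weights are frozen.
        # detach_graph=False keeps the computational graph so that the
        # generator step below can still backpropagate through the discriminator.
        with gan.gen.prevent_update():
            dl = gan.dis_loss
            dl.grad(detach_graph=False).update(dis_opt)
        # Generator step: the discriminator's weights are frozen.
        with gan.dis.prevent_update():
            gl = gan.gen_loss
            gl.grad().update(gen_opt)
        # Fraction of actual images classified as real and of generated images classified as fake
        real_acc = np.mean(gan.prob_real.as_ndarray() >= 0.5)
        fake_acc = np.mean(gan.prob_fake.as_ndarray() < 0.5)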
        dis_loss_ = gan.dis_loss.as_ndarray()[0]
        gen_loss_ = gan.gen_loss.as_ndarray()[0]
        total_loss_dis += dis_loss_
        total_loss_gen += gen_loss_
        total_acc_real += real_acc
        total_acc_fake += fake_acc
        print('{:05}/{:05} Dis:{:.3f} Gen:{:.3f} Real:{:.3f} Fake:{:.3f}'.format(
            j*batch_size, N, dis_loss_, gen_loss_, real_acc, fake_acc), flush=True, end='\r')
    loss_curve_dis.append(total_loss_dis/(N//batch_size))
    loss_curve_gen.append(total_loss_gen/(N//batch_size))
    acc_curve_real.append(total_acc_real/(N//batch_size))
    acc_curve_fake.append(total_acc_fake/(N//batch_size))

    elapsed_time = time.time() - start

    if i%1 == 0:
        print("Epoch %02d - Loss of Dis:%f - Loss of Gen:%f - Elapsed time:%f"
              %(i, loss_curve_dis[-1], loss_curve_gen[-1], elapsed_time))
        print("         - Accuracy of Real:%f - Accuracy of Fake:%f"
              %(acc_curve_real[-1], acc_curve_fake[-1]))


    test_dim = 10
    gan(X_train[:test_dim**2])
    test_image.append(gan.x_gen.as_ndarray().reshape(test_dim**2, npx, npx))
Epoch 01 - Loss of Dis:0.719061 - Loss of Gen:2.397856 - Elapsed time:14.514677
         - Accuracy of Real:0.785176 - Accuracy of Fake:0.902945
Epoch 02 - Loss of Dis:0.377584 - Loss of Gen:2.544531 - Elapsed time:27.803916
         - Accuracy of Real:0.934615 - Accuracy of Fake:0.944511
Epoch 03 - Loss of Dis:0.540502 - Loss of Gen:2.547249 - Elapsed time:41.026845
         - Accuracy of Real:0.864804 - Accuracy of Fake:0.919631
Epoch 04 - Loss of Dis:0.682744 - Loss of Gen:2.059768 - Elapsed time:54.146722
         - Accuracy of Real:0.820272 - Accuracy of Fake:0.887500
Epoch 05 - Loss of Dis:0.731340 - Loss of Gen:1.900525 - Elapsed time:67.067616
         - Accuracy of Real:0.798918 - Accuracy of Fake:0.881751
Epoch 06 - Loss of Dis:0.739707 - Loss of Gen:1.841624 - Elapsed time:80.087500
         - Accuracy of Real:0.793910 - Accuracy of Fake:0.882031
Epoch 07 - Loss of Dis:0.715174 - Loss of Gen:1.885906 - Elapsed time:93.167669
         - Accuracy of Real:0.797075 - Accuracy of Fake:0.893450
Epoch 08 - Loss of Dis:0.754820 - Loss of Gen:1.787846 - Elapsed time:106.498410
         - Accuracy of Real:0.786579 - Accuracy of Fake:0.884255
Epoch 09 - Loss of Dis:0.780429 - Loss of Gen:1.722660 - Elapsed time:119.919572
         - Accuracy of Real:0.775701 - Accuracy of Fake:0.879828
Epoch 10 - Loss of Dis:0.762232 - Loss of Gen:1.729577 - Elapsed time:133.337572
         - Accuracy of Real:0.777684 - Accuracy of Fake:0.889463
Epoch 11 - Loss of Dis:0.827336 - Loss of Gen:1.605738 - Elapsed time:146.463208
         - Accuracy of Real:0.753425 - Accuracy of Fake:0.871835
Epoch 12 - Loss of Dis:0.836076 - Loss of Gen:1.578897 - Elapsed time:159.570772
         - Accuracy of Real:0.751603 - Accuracy of Fake:0.869591
Epoch 13 - Loss of Dis:0.861161 - Loss of Gen:1.529864 - Elapsed time:172.706551
         - Accuracy of Real:0.736719 - Accuracy of Fake:0.862961
Epoch 14 - Loss of Dis:0.855765 - Loss of Gen:1.537340 - Elapsed time:185.759237
         - Accuracy of Real:0.738281 - Accuracy of Fake:0.863542
Epoch 15 - Loss of Dis:0.847188 - Loss of Gen:1.552502 - Elapsed time:198.824406
         - Accuracy of Real:0.746534 - Accuracy of Fake:0.866386
Epoch 16 - Loss of Dis:0.888899 - Loss of Gen:1.491426 - Elapsed time:212.166279
         - Accuracy of Real:0.729006 - Accuracy of Fake:0.853946
Epoch 17 - Loss of Dis:0.903979 - Loss of Gen:1.453521 - Elapsed time:225.533276
         - Accuracy of Real:0.725821 - Accuracy of Fake:0.847336
Epoch 18 - Loss of Dis:0.863168 - Loss of Gen:1.545666 - Elapsed time:238.840992
         - Accuracy of Real:0.739002 - Accuracy of Fake:0.858534
Epoch 19 - Loss of Dis:0.870038 - Loss of Gen:1.525828 - Elapsed time:252.181412
         - Accuracy of Real:0.734796 - Accuracy of Fake:0.857953
Epoch 20 - Loss of Dis:0.837272 - Loss of Gen:1.589309 - Elapsed time:265.469788
         - Accuracy of Real:0.748958 - Accuracy of Fake:0.867528
Epoch 21 - Loss of Dis:0.791500 - Loss of Gen:1.681866 - Elapsed time:278.726297
         - Accuracy of Real:0.761879 - Accuracy of Fake:0.882812
Epoch 22 - Loss of Dis:0.774528 - Loss of Gen:1.737219 - Elapsed time:291.985094
         - Accuracy of Real:0.773177 - Accuracy of Fake:0.883554
Epoch 23 - Loss of Dis:0.747863 - Loss of Gen:1.782163 - Elapsed time:305.364492
         - Accuracy of Real:0.780349 - Accuracy of Fake:0.895413
Epoch 24 - Loss of Dis:0.764319 - Loss of Gen:1.703816 - Elapsed time:318.692298
         - Accuracy of Real:0.770733 - Accuracy of Fake:0.892488
Epoch 25 - Loss of Dis:0.758058 - Loss of Gen:1.741920 - Elapsed time:332.052672
         - Accuracy of Real:0.777804 - Accuracy of Fake:0.892127
Epoch 26 - Loss of Dis:0.832977 - Loss of Gen:1.610022 - Elapsed time:345.668877
         - Accuracy of Real:0.747276 - Accuracy of Fake:0.872456
Epoch 27 - Loss of Dis:0.851042 - Loss of Gen:1.574273 - Elapsed time:359.241927
         - Accuracy of Real:0.744391 - Accuracy of Fake:0.866727
Epoch 28 - Loss of Dis:0.804519 - Loss of Gen:1.648678 - Elapsed time:372.488189
         - Accuracy of Real:0.758574 - Accuracy of Fake:0.879928
Epoch 29 - Loss of Dis:0.830238 - Loss of Gen:1.607819 - Elapsed time:385.928092
         - Accuracy of Real:0.749459 - Accuracy of Fake:0.871975
Epoch 30 - Loss of Dis:0.840108 - Loss of Gen:1.586321 - Elapsed time:398.959127
         - Accuracy of Real:0.744732 - Accuracy of Fake:0.869812
In [12]:
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
ax = ax.ravel()
ax[0].plot(loss_curve_dis, linewidth=2, label="dis")
ax[0].plot(loss_curve_gen, linewidth=2, label="gen")
ax[0].set_title("Learning curve")
ax[0].set_ylabel("loss")
ax[0].set_xlabel("epoch")
ax[0].legend()
ax[0].grid()
ax[1].plot(acc_curve_real, linewidth=2, label="real")
ax[1].plot(acc_curve_fake, linewidth=2, label="fake")
ax[1].set_title("Accuracy curve")
ax[1].set_ylabel("accuracy")
ax[1].set_xlabel("epoch")
ax[1].legend()
ax[1].grid()
plt.show()
[Figure: left, "Learning curve" (loss vs. epoch for dis and gen); right, "Accuracy curve" (accuracy vs. epoch for real and fake)]

Results

In [13]:
imshow(test_image[-1], test_dim, test_dim)
[Figure: 100 images generated after the final epoch, shown in a 10 x 10 grid]

Mode collapse case

References

[1] Goodfellow, Ian, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. "Generative adversarial nets." In Advances in Neural Information Processing Systems, 2014. https://arxiv.org/abs/1406.2661

[2] Karras, Tero, Timo Aila, Samuli Laine, and Jaakko Lehtinen. "Progressive growing of GANs for improved quality, stability, and variation." arXiv preprint arXiv:1710.10196 (2017). https://arxiv.org/abs/1710.10196

[3] Isola, Phillip, Jun-Yan Zhu, Tinghui Zhou, and Alexei A. Efros. "Image-to-image translation with conditional adversarial networks." arXiv preprint arXiv:1611.07004 (2016). https://arxiv.org/abs/1611.07004

[4] Goodfellow, Ian. "NIPS 2016 tutorial: Generative adversarial networks." arXiv preprint arXiv:1701.00160 (2016). https://arxiv.org/abs/1701.00160