Variational Auto-Encoder

Variational Bayesian inference with “deterministic” neural networks.

Introduction

What is a variational auto-encoder?

This section is based on D. P. Kingma & M. Welling, “Auto-Encoding Variational Bayes” (ICLR, 2014). For more details, please refer to the paper.

As seen in the Auto Encoder (AE) section, an AE can extract features whose dimensionality is much lower than that of the input samples (e.g. from 784 down to 10 for MNIST; see the left figure below). The variational auto-encoder (VAE) is a form of variational Bayesian inference that, like the AE, has an encoder-decoder structure.

[Figure: AE & VAE]

We assume that each observed value x is generated from a latent variable z in the following two steps.

  • z is generated from the prior distribution p_{\theta^*}(z)
  • x is generated from the conditional distribution p_{\theta^*}(x \mid z)

\theta^* denotes the true generative parameters. Our goal is to infer approximate parameters \theta and the values of the latent variables z . However, quantities such as the marginal likelihood p_\theta(x) = \int p_\theta(z)p_\theta(x \mid z)\, dz and the true posterior p_\theta(z\mid x) = p_\theta(z)p_\theta(x \mid z) / p_\theta(x) are intractable to compute.
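
For intuition, here is a toy numpy sketch of this two-step generative process. The prior, the decoder weights and the Bernoulli likelihood below are purely illustrative stand-ins, not the model trained later in this section:

# toy version of the assumed generative process (standalone sketch)
import numpy as np

rng = np.random.RandomState(0)
latent_dim, data_dim = 2, 784
W = rng.randn(latent_dim, data_dim) * 0.1          # made-up decoder parameters ("theta")

z = rng.randn(1, latent_dim)                       # step 1: z ~ p(z) = N(0, I)
pixel_probs = 1.0 / (1.0 + np.exp(-z.dot(W)))      # decoder defines p(x|z)
x = rng.binomial(1, pixel_probs)                   # step 2: x ~ p(x|z), here Bernoulli pixels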

The variational lower bound

We work around this intractability by introducing a recognition model q_{\phi}(z\mid x) that approximates the true posterior, where \phi denotes the variational parameters. The marginal log-likelihood of the dataset is \log p_\theta(x^{(1)}, \ldots, x^{(N)}) = \sum_{i=1}^{N} \log p_\theta(x^{(i)}) , and each term can be rewritten as

\begin{equation} \log p_\theta(x^{(i)}) = D_{KL}(q_{\phi}(z \mid x^{(i)}) \| p_{\theta}(z \mid x^{(i)})) + {\mathcal L}(\theta, \phi ; x^{(i)}) \end{equation}

The first term on the right-hand side is the KL divergence of the approximate posterior from the true posterior, which is always non-negative. The second term {\mathcal L}(\theta, \phi ; x^{(i)}) is called the variational lower bound on the marginal likelihood of datapoint i . A short calculation rewrites it as

\begin{equation} {\mathcal L}(\theta, \phi ; x^{(i)}) = - D_{KL}(q_{\phi}(z \mid x^{(i)}) \| p_{\theta}(z)) + \mathbb{E}_{q_{\phi}(z \mid x^{(i)})}[\log p_{\theta}(x^{(i)} \mid z)] \end{equation}
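
For completeness, here is the calculation. Using the definition of the KL divergence and p_{\theta}(z \mid x^{(i)}) = p_{\theta}(x^{(i)}, z) / p_{\theta}(x^{(i)}) ,

\begin{equation} D_{KL}(q_{\phi}(z \mid x^{(i)}) \| p_{\theta}(z \mid x^{(i)})) = \mathbb{E}_{q_{\phi}(z \mid x^{(i)})}[\log q_{\phi}(z \mid x^{(i)}) - \log p_{\theta}(x^{(i)}, z)] + \log p_{\theta}(x^{(i)}) \end{equation}

Rearranging gives the first identity with {\mathcal L}(\theta, \phi ; x^{(i)}) = \mathbb{E}_{q_{\phi}(z \mid x^{(i)})}[\log p_{\theta}(x^{(i)}, z) - \log q_{\phi}(z \mid x^{(i)})] , and substituting \log p_{\theta}(x^{(i)}, z) = \log p_{\theta}(z) + \log p_{\theta}(x^{(i)} \mid z) yields the second.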

Our goal is to maximize the marginal likelihood, which we do by maximizing this lower bound. If we apply gradient descent naively, the gradient of the lower bound with respect to \phi is problematic: the straightforward Monte Carlo estimator of this gradient exhibits very high variance, making it impractical. We avoid this problem with the “reparametrization trick” described below. Now let’s prepare for the VAE.

[Figure: graphical model]

You can easily download the Fashion-MNIST data as CSV files from Kaggle.

Required libraries

In [1]:
%matplotlib inline
import time
import renom as rm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelBinarizer
In [2]:
from renom.cuda.cuda import set_cuda_active
set_cuda_active(True)

Load fashion-MNIST data

In [3]:
label_names = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
In [4]:
df_train = pd.read_csv("../dataset/fashionmnist/fashion-mnist_train.csv")
df_test = pd.read_csv("../dataset/fashionmnist/fashion-mnist_test.csv")
In [5]:
df_train.head()
Out[5]:
label pixel1 pixel2 pixel3 pixel4 pixel5 pixel6 pixel7 pixel8 pixel9 ... pixel775 pixel776 pixel777 pixel778 pixel779 pixel780 pixel781 pixel782 pixel783 pixel784
0 2 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
1 9 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
2 6 0 0 0 0 0 0 0 5 0 ... 0 0 0 30 43 0 0 0 0 0
3 0 0 0 0 1 2 0 0 0 0 ... 3 0 0 0 0 1 0 0 0 0
4 3 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0

5 rows × 785 columns

The first column contains the label, and the remaining columns (pixel1 to pixel784) contain the pixel values. Now we define a function that converts the above dataframe into numpy arrays.

In [6]:
def make_dataset(df):
    # label data
    labels = df.label.values.reshape(-1, 1)

    # image data
    images = []
    for i in range(len(df.index)):
        img = df.iloc[i, 1:].values
        img = img / 255. # normalize between 0-1
        images.append(img)
    # convert to numpy-array
    images = np.array(images, dtype=np.float32).reshape(-1, 784)

    return images, labels
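
The row-by-row loop above is easy to read but slow for 60,000 images. An equivalent vectorized version (a minimal sketch that produces the same arrays) would be:

def make_dataset_fast(df):
    # label data
    labels = df.label.values.reshape(-1, 1)
    # take all pixel columns at once, scale to [0, 1] and cast to float32
    images = (df.iloc[:, 1:].values / 255.).astype(np.float32)
    return images, labels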
In [7]:
X_train, y_train = make_dataset(df_train)
X_test, y_test = make_dataset(df_test)
In [8]:
print("X_train:%s y_train:%s\nX_test:%s y_test:%s"
      % (X_train.shape, y_train.shape, X_test.shape, y_test.shape))
X_train:(60000, 784) y_train:(60000, 1)
X_test:(10000, 784) y_test:(10000, 1)

We also define a function to show some images in a given dataset.

In [9]:
def imshow(image_set, nrows=4, ncols=10, figsize=(12.5, 5)):
    plot_num = nrows * ncols
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
    fig.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    ax = ax.ravel()
    for i in range(plot_num):
        # negate the pixel values so garments appear dark on a light background
        ax[i].imshow(-image_set[i].reshape(28, 28), cmap="gray")
        ax[i].set_xticks([])
        ax[i].set_yticks([])
In [10]:
imshow(X_train)
[Output: sample images from the training set]

Reparametrization trick

To optimize the lower bound with respect to both the variational parameter \phi and the generative parameter \theta by backpropagation, gradients must flow through the sampling of z ; direct random sampling blocks them. The reparametrization trick solves this: instead of sampling z \sim {\mathcal N}(\mu, \sigma^2 I) directly, we draw noise \varepsilon \sim {\mathcal N}(0, I) and set z = \mu + \sigma \odot \varepsilon , where \odot is the element-wise product. With this substitution, z becomes a deterministic, differentiable function of \mu and \sigma (the randomness is isolated in \varepsilon ), which enables backprop.
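
As a minimal framework-independent sketch (the numbers below are arbitrary illustrations, not model outputs):

# reparametrization trick with plain numpy
mu_toy = np.array([[0.5, -1.0]])        # mean of q(z|x)
logvar_toy = np.array([[0.0, -2.0]])    # log(sigma^2) of q(z|x)
sigma_toy = np.exp(0.5 * logvar_toy)    # standard deviation

eps = np.random.randn(*mu_toy.shape)    # eps ~ N(0, I): the only source of randomness
z = mu_toy + sigma_toy * eps            # deterministic in mu and sigma once eps is drawn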

[Figure: Variational Auto-Encoder]

Now, let’s construct a VAE model.

In [11]:
class VAE(rm.Model):

    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        # encoder
        self.e1 = rm.Dense(400)
        self.e2 = rm.Dense(200)
        self.e3 = rm.Dense(100)
        self.e4 = rm.Dense(20)
        self.em = rm.Dense(latent_dim)
        self.ev = rm.Dense(latent_dim)

        # decoder
        self.d1 = rm.Dense(20)
        self.d2 = rm.Dense(100)
        self.d3 = rm.Dense(200)
        self.d4 = rm.Dense(400)
        self.d5 = rm.Dense(784)

    def encode(self, x):
        z = rm.relu(self.e1(x))
        z = rm.relu(self.e2(z))
        z = rm.relu(self.e3(z))
        z = rm.relu(self.e4(z))
        return self.em(z), self.ev(z)

    # sample z from mu & log(sigma^2)
    def reparametrize(self, mu, logvar):
        std = rm.exp(logvar * 0.5)
        eps = np.random.randn(logvar.shape[0], logvar.shape[1])
        return mu + eps * std

    def decode(self, x):
        z = rm.relu(self.d1(x))
        z = rm.relu(self.d2(z))
        z = rm.relu(self.d3(z))
        z = rm.relu(self.d4(z))
        return rm.sigmoid(self.d5(z))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar
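
As a quick sanity check (illustrative only, not needed for training), a small random batch can be pushed through an untrained model to confirm the output shapes:

# shape check on a throwaway, untrained model
_check_model = VAE(2)
_x = np.random.rand(5, 784).astype(np.float32)
_recon, _mu, _logvar = _check_model(_x)
print(_recon.shape, _mu.shape, _logvar.shape)   # expected: (5, 784) (5, 2) (5, 2)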

Define a loss function

The loss function for the VAE is composed of two terms.

\begin{equation} l_{\rm total} = l_{\rm recon} + D_{KL}(q_{\phi}(z \mid x) \| p_{\theta}(z)) \end{equation}

The first term l_{\rm recon} is the reconstruction error between the input images and the reconstructed images, just as in an ordinary auto-encoder. The second term is the KL divergence between the approximate posterior q_{\phi}(z \mid x) and the prior distribution p_{\theta}(z) .

In the simple case where the prior p_{\theta}(z) is a standard Gaussian and the approximate posterior q_{\phi}(z\mid x) is Gaussian, the KL term has the closed form below. For the detailed calculation, see Appendix B of the paper.

\begin{equation} D_{KL}(q_{\phi}(z \mid x) \| p_{\theta}(z)) = - \frac{1}{2} \sum_{j=1}^J (1 + \log((\sigma_j)^2) - (\mu_j)^2 - (\sigma_j)^2) \end{equation}

Here J is the dimensionality of the latent variable z and j indexes its components.
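
As a quick numerical check of this closed form (illustration only; the values are arbitrary), it can be compared with a Monte Carlo estimate of the same KL divergence:

# closed-form Gaussian KL vs. a Monte Carlo estimate
mu_ex = np.array([0.3, -0.7])
logvar_ex = np.array([-0.5, 0.2])
sigma_ex = np.exp(0.5 * logvar_ex)

closed_form = -0.5 * np.sum(1 + logvar_ex - mu_ex**2 - np.exp(logvar_ex))

z_s = mu_ex + sigma_ex * np.random.randn(100000, 2)   # z ~ q(z|x)
log_q = -0.5 * np.sum(np.log(2*np.pi) + logvar_ex + (z_s - mu_ex)**2 / np.exp(logvar_ex), axis=1)
log_p = -0.5 * np.sum(np.log(2*np.pi) + z_s**2, axis=1)
print(closed_form, np.mean(log_q - log_p))            # the two values should be close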

In [12]:
def loss_function(recon_x, x, mu, logvar, batch_size):
    # Mean-squared-error between reconstructed images & input images
    MSE = rm.mean_squared_error(recon_x, x)
    # KL-divergence from the prior distribution p(z) to the posterior q(z|x)
    KLD = - 0.5 * rm.sum(1 + logvar - mu ** 2 - rm.exp(logvar)) / (batch_size)
    total_loss = MSE + KLD
    return MSE, KLD, total_loss
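
Note that for MNIST-like data the original paper uses a Bernoulli decoder, whose reconstruction term is a binary cross-entropy rather than the MSE used here. A minimal numpy sketch of that alternative (not used in the training below):

def bce_reconstruction(recon_x, x, eps=1e-7):
    # per-batch binary cross-entropy between reconstructed and input pixels
    recon_x = np.clip(recon_x, eps, 1 - eps)   # avoid log(0)
    bce = -(x * np.log(recon_x) + (1 - x) * np.log(1 - recon_x))
    return np.sum(bce) / x.shape[0]            # averaged over the batch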

Train the model

In [13]:
# initialize the VAE model assuming 2 latent-dim
model = VAE(2)

# define optimizer function
optimizer = rm.Adam()
In [14]:
# set parameters
EPOCH = 60
BATCH_SIZE = 200
TEST_BATCH_SIZE = 1000
N = X_train.shape[0] # train data size
M = X_test.shape[0]  # test data size
num_batch_train = N // BATCH_SIZE
num_batch_test  = M // TEST_BATCH_SIZE

# learning curves
train_MSE_curve, train_KLD_curve, train_loss_curve = [], [], []
test_MSE_curve, test_KLD_curve, test_loss_curve = [], [], []

total_start = time.time()
# train loop
for i in range(1, 1+EPOCH):

    start = time.time()
    perm = np.random.permutation(N)
    train_MSE, train_KLD, train_loss = 0, 0, 0

    for j in range(num_batch_train):
        X_batch = X_train[perm[j*BATCH_SIZE:(j+1)*BATCH_SIZE]]

        # forward propagation
        with model.train():
            recon_batch, mu, logvar = model(X_batch)
            mse, kld, loss = loss_function(recon_batch, X_batch, mu, logvar, BATCH_SIZE)

        # backpropagation
        loss.grad().update(optimizer)

        train_MSE += mse.as_ndarray()
        train_KLD += kld.as_ndarray()
        train_loss += loss.as_ndarray()

    train_MSE /= num_batch_train
    train_KLD /= num_batch_train
    train_loss /= num_batch_train

    # test validation
    test_MSE, test_KLD, test_loss = 0, 0, 0
    for j in range(num_batch_test):
        test_batch = X_test[j*TEST_BATCH_SIZE:(j+1)*TEST_BATCH_SIZE]
        recon_test, mu_test, logvar_test = model(test_batch)
        mse, kld, loss = loss_function(recon_test, test_batch, mu_test, logvar_test, TEST_BATCH_SIZE)
        test_MSE  += mse.as_ndarray()
        test_KLD  += kld.as_ndarray()
        test_loss += loss.as_ndarray()

    test_MSE /= num_batch_test
    test_KLD /= num_batch_test
    test_loss /= num_batch_test

    # record a log
    train_MSE_curve.append(train_MSE)
    train_KLD_curve.append(train_KLD)
    train_loss_curve.append(train_loss)
    test_MSE_curve.append(test_MSE)
    test_KLD_curve.append(test_KLD)
    test_loss_curve.append(test_loss)

    elapsed = time.time() - start
    print("epoch %02d - %.1fs - loss %f - test_loss %f (MSE %f | KLD %f)"
          % (i, elapsed, train_loss, test_loss, test_MSE, test_KLD))
epoch 01 - 31.3s - loss 29.292006 - test_loss 23.355804 (MSE 20.918629 | KLD 2.437176)
epoch 02 - 26.0s - loss 21.881132 - test_loss 21.035431 (MSE 18.023841 | KLD 3.011590)
epoch 03 - 26.0s - loss 20.157909 - test_loss 19.423353 (MSE 15.813090 | KLD 3.610263)
epoch 04 - 26.0s - loss 18.991537 - test_loss 18.824402 (MSE 15.328443 | KLD 3.495959)
epoch 05 - 26.0s - loss 18.368986 - test_loss 18.323238 (MSE 14.764117 | KLD 3.559122)
epoch 06 - 26.0s - loss 17.861025 - test_loss 17.551872 (MSE 13.947551 | KLD 3.604319)
epoch 07 - 26.0s - loss 17.403337 - test_loss 17.386518 (MSE 13.809210 | KLD 3.577309)
epoch 08 - 26.1s - loss 17.022255 - test_loss 17.088886 (MSE 13.103191 | KLD 3.985694)
epoch 09 - 25.5s - loss 16.702307 - test_loss 16.500235 (MSE 12.550611 | KLD 3.949624)
epoch 10 - 25.8s - loss 16.450447 - test_loss 16.399370 (MSE 12.371619 | KLD 4.027751)
epoch 11 - 25.8s - loss 16.361910 - test_loss 16.549110 (MSE 12.418163 | KLD 4.130946)
epoch 12 - 25.9s - loss 16.281139 - test_loss 16.272774 (MSE 12.218720 | KLD 4.054053)
epoch 13 - 25.8s - loss 16.189659 - test_loss 16.299711 (MSE 12.575751 | KLD 3.723960)
epoch 14 - 25.8s - loss 16.111042 - test_loss 16.202089 (MSE 12.203213 | KLD 3.998876)
epoch 15 - 25.8s - loss 16.063742 - test_loss 16.148630 (MSE 11.959878 | KLD 4.188750)
epoch 16 - 25.8s - loss 16.152796 - test_loss 16.300724 (MSE 12.155709 | KLD 4.145014)
epoch 17 - 25.8s - loss 15.969066 - test_loss 16.214321 (MSE 11.820733 | KLD 4.393588)
epoch 18 - 25.8s - loss 15.954565 - test_loss 15.959585 (MSE 11.853073 | KLD 4.106514)
epoch 19 - 25.8s - loss 15.878405 - test_loss 15.907331 (MSE 11.661326 | KLD 4.246005)
epoch 20 - 25.8s - loss 15.864445 - test_loss 16.062075 (MSE 11.917901 | KLD 4.144173)
epoch 21 - 25.8s - loss 15.860448 - test_loss 16.399574 (MSE 12.238612 | KLD 4.160963)
epoch 22 - 25.8s - loss 15.803229 - test_loss 15.913763 (MSE 11.692834 | KLD 4.220929)
epoch 23 - 25.8s - loss 15.765581 - test_loss 16.204538 (MSE 11.937332 | KLD 4.267207)
epoch 24 - 25.8s - loss 16.012980 - test_loss 16.049160 (MSE 11.706032 | KLD 4.343127)
epoch 25 - 25.8s - loss 15.797098 - test_loss 15.893946 (MSE 11.503478 | KLD 4.390470)
epoch 26 - 25.8s - loss 15.674995 - test_loss 15.865252 (MSE 11.529056 | KLD 4.336195)
epoch 27 - 25.8s - loss 15.779749 - test_loss 15.717323 (MSE 11.648937 | KLD 4.068388)
epoch 28 - 25.8s - loss 15.614095 - test_loss 15.723486 (MSE 11.553985 | KLD 4.169500)
epoch 29 - 25.8s - loss 15.593273 - test_loss 15.731768 (MSE 11.340910 | KLD 4.390857)
epoch 30 - 25.8s - loss 15.603667 - test_loss 15.585861 (MSE 11.337660 | KLD 4.248201)
epoch 31 - 25.8s - loss 15.550684 - test_loss 15.630290 (MSE 11.385874 | KLD 4.244417)
epoch 32 - 25.8s - loss 15.557311 - test_loss 15.800154 (MSE 11.505608 | KLD 4.294548)
epoch 33 - 25.8s - loss 15.545179 - test_loss 15.559176 (MSE 11.300625 | KLD 4.258552)
epoch 34 - 25.8s - loss 15.531303 - test_loss 15.525171 (MSE 11.286374 | KLD 4.238796)
epoch 35 - 25.8s - loss 15.474730 - test_loss 15.715027 (MSE 11.472173 | KLD 4.242855)
epoch 36 - 25.8s - loss 15.492939 - test_loss 15.683006 (MSE 11.362078 | KLD 4.320929)
epoch 37 - 25.8s - loss 15.436321 - test_loss 15.597822 (MSE 11.351393 | KLD 4.246430)
epoch 38 - 25.8s - loss 15.456154 - test_loss 15.519873 (MSE 11.285704 | KLD 4.234169)
epoch 39 - 25.8s - loss 15.513895 - test_loss 15.487173 (MSE 10.995080 | KLD 4.492094)
epoch 40 - 25.8s - loss 15.501789 - test_loss 15.558016 (MSE 11.280224 | KLD 4.277792)
epoch 41 - 25.8s - loss 15.463036 - test_loss 15.468918 (MSE 11.109931 | KLD 4.358987)
epoch 42 - 25.8s - loss 15.485232 - test_loss 15.594754 (MSE 11.431299 | KLD 4.163455)
epoch 43 - 25.8s - loss 15.384259 - test_loss 15.390335 (MSE 11.046783 | KLD 4.343551)
epoch 44 - 25.8s - loss 15.413895 - test_loss 15.426082 (MSE 11.096325 | KLD 4.329757)
epoch 45 - 25.8s - loss 15.414681 - test_loss 15.417727 (MSE 10.920183 | KLD 4.497544)
epoch 46 - 25.8s - loss 15.416509 - test_loss 15.622869 (MSE 11.132002 | KLD 4.490869)
epoch 47 - 25.8s - loss 15.489601 - test_loss 15.560633 (MSE 11.164411 | KLD 4.396220)
epoch 48 - 25.8s - loss 15.477632 - test_loss 15.449161 (MSE 11.210119 | KLD 4.239041)
epoch 49 - 25.8s - loss 15.398214 - test_loss 15.399889 (MSE 11.240342 | KLD 4.159546)
epoch 50 - 25.8s - loss 15.383414 - test_loss 15.408064 (MSE 11.131777 | KLD 4.276286)
epoch 51 - 25.8s - loss 15.435960 - test_loss 15.378553 (MSE 10.955298 | KLD 4.423256)
epoch 52 - 25.8s - loss 15.419683 - test_loss 15.494047 (MSE 11.012778 | KLD 4.481270)
epoch 53 - 25.8s - loss 15.402197 - test_loss 15.588193 (MSE 11.174954 | KLD 4.413240)
epoch 54 - 25.8s - loss 15.355996 - test_loss 15.387619 (MSE 11.015935 | KLD 4.371685)
epoch 55 - 25.8s - loss 15.392119 - test_loss 15.418429 (MSE 10.925451 | KLD 4.492978)
epoch 56 - 25.8s - loss 15.346709 - test_loss 15.429564 (MSE 11.111563 | KLD 4.318002)
epoch 57 - 25.8s - loss 15.316667 - test_loss 15.386075 (MSE 10.964458 | KLD 4.421617)
epoch 58 - 25.8s - loss 15.363145 - test_loss 15.495132 (MSE 11.136416 | KLD 4.358716)
epoch 59 - 25.8s - loss 15.321315 - test_loss 15.304787 (MSE 10.864760 | KLD 4.440026)
epoch 60 - 25.8s - loss 15.303391 - test_loss 15.255893 (MSE 10.834188 | KLD 4.421706)
In [15]:
plt.figure(figsize=(16, 6))
x = range(1, EPOCH+1)
plt.plot(x, train_MSE_curve, "b:", label="train MSE")
plt.plot(x, train_KLD_curve, "b--", label="train KLD")
plt.plot(x, train_loss_curve, "b", label="train loss")
plt.plot(x, test_MSE_curve, "r:", label="test MSE")
plt.plot(x, test_KLD_curve, "r--", label="test KLD")
plt.plot(x, test_loss_curve, "r", label="test loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid()
[Output: learning curves for training and test MSE, KLD and total loss]

The KL-divergence term acts as a regularizer, pulling the approximate posterior toward the prior while the MSE term drives reconstruction quality.

Reconstructed images

In [16]:
recon_img, mu, logvar = model(X_test)
recon_img = recon_img.as_ndarray()
mu = mu.as_ndarray()
imshow(recon_img[:40])
[Output: reconstructed test images]

Corresponding raw input images

In [17]:
imshow(X_test[:40])
[Output: the corresponding raw test images]

Check the latent variable space

In [18]:
test_label = df_test.label.values
markers = ["*", "s", "v", "^", "+", "o", "P", "D", "p", "d"]
colors = []
cm = plt.get_cmap('jet')
for i in range(10):
    colors.append(cm(1.*i/10))
In [19]:
plt.figure(figsize=(12, 12))
for i, color, marker, label_name in zip(list(range(10)), colors, markers, label_names):
    plt.scatter(mu[test_label==i, 0], mu[test_label==i, 1],
                c=color, marker=marker, label=label_name)
plt.legend()
plt.grid()
[Output: scatter plot of the 2-D latent means, colored by class]

From the above figure, we can observe the following:

  • Even without supervision, some classes such as trousers, dresses and ankle boots form clean, roughly Gaussian clusters.
  • Sandals, sneakers and ankle boots partly share their distributions, but together they form one “shoes” cluster.
  • The shirt distribution overlaps the other upper-body clothes (pullovers, coats, T-shirts/tops and dresses).

In [20]:
size = 20
# build a (size x size) grid of latent points covering [-2, 2] x [-2, 2] in the 2-D latent space
xrange = np.linspace(2, -2, size)
yrange = np.linspace(-2, 2, size)
zmap = []

for i, x in enumerate(xrange):
    for j, y in enumerate(yrange):
        zmap.append([y, x])

# decode every grid point back into image space
output = model.decode(np.array(zmap))
In [21]:
fig, ax = plt.subplots(size, size, figsize=(10, 10))
ax = ax.ravel()
plt.tight_layout(False)
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)
for i in range(size**2):
    ax[i].imshow(1-output[i].as_ndarray().reshape(28, 28), cmap="gray")
    ax[i].set_xticks([])
    ax[i].set_yticks([])
[Output: 20 x 20 grid of images decoded from latent points in [-2, 2] x [-2, 2]]
In [23]:
# Save weights for reproducibility
model.save("vae.hd5")
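
If you later want to reuse the trained model, the weights can be restored into a fresh instance. This assumes ReNom's Model class provides a load method as the counterpart of save (please check the ReNom documentation for your version):

# restore the saved weights (assumes renom.Model exposes `load`; see the ReNom docs)
restored_model = VAE(2)
restored_model.load("vae.hd5")
recon_check, _, _ = restored_model(X_test[:10])   # quick check that the restored model runs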