変分オートエンコーダ

決定論的ニューラルネットワークによる変分ベイズ推定

イントロダクション

変分オートエンコーダとは?

このチュートリアルの内容は論文 D. Kingma & M. Welling(ICLR, 2014) に準拠したものです。詳細は本論文を参照してください。

Auto Encoder (以降、AEと略記)のセクションで見た通り、AEは入力サンプルのものよりも圧倒的に少ない次元数の特徴を抽出することができます(例えばMNISTでは784次元から10次元を抽出)。変分オートエンコーダ(Variational auto-encoder: VAE)は変分ベイズ推定法の一種で、通常のオートエンコーダと同様、エンコーダ部分とデコーダ部分を持つモデルです。

AE&VAE

オートエンコーダと変分オートエンコーダ

観測値 x は潜在変数 z から次の2ステップから生成されていると仮定します。

  • z は事前分布 p_{\theta^*}(z) から生成される

  • x は条件付き確率分布 p_{\theta^*}(x\mid z) から生成される

\theta^* は真の生成パラメータで、 私たちの目的はこれを近似するパラメータ \theta および潜在変数 z を推定することです。しかし、周辺尤度 p_\theta(x) = \int p_\theta(z)p_\theta(x \mid z)dz や真の事後分布 p_\theta(z\mid x) = p_\theta(z)p_\theta(x \mid z) / p_\theta(x) を陽に計算することは困難です。

変分下限

この問題を解決するため、真の事後分布を近似する分布関数として q_{\phi}(z\mid x) を導入します。ここで \phi は変分パラメータです。周辺尤度はこれにより \log p_\theta(x) = \sum_{i=1}^{N} \log p_\theta(x^{(i)}) と書け、総和をとるそれぞれの成分はさらに以下のように表せます。

\begin{equation} \log p_\theta(x^{(i)}) = D_{KL}(q_{\phi}(z \mid x^{(i)}) \| p_{\theta}(z \mid x^{(i)})) + {\mathcal L}(\theta, \phi ; x^{(i)}) \end{equation}

右辺第一項は真の事後分布から推定した事後分布へのKLダイバージェンスを表しており、非負の値をとります。第二項はデータ点 i における周辺尤度の「変分下限」と呼ばれる量で、簡単な計算により以下のように展開できます。

\begin{equation} {\mathcal L}(\theta, \phi ; x^{(i)}) = - D_{KL}(q_{\phi}(z \mid x^{(i)}) \| p_{\theta}(z)) + \mathbb{E}_{q_{\phi}(z \mid x^{(i)}) \| p_{\theta}(z))}[\log p_{\theta}(x^{(i)} \mid z))] \end{equation}

ここでの目的は、上側の式で出てきた周辺尤度の最大化であり、これは変分下限を最大化することになります。しかしながら、例えば勾配降下法による変分下限の最適化をしようとすると、 \phi による偏微分はモンテカルロ法などによって求める必要があり、計算コストが莫大にかかります。この問題は後に見る「reparametrization trick」を用いて解決することができます。それでは、まずVAEの準備に入りましょう。

graphical model

graphical model

Fashiom MNISTのデータは Kaggle からcsvファイルからダウンロードできます。

必要なライブラリ

  • matplotlib 2.0.2
  • numpy 1.12.1
  • scikit-learn 0.18.2
  • pandas 0.20.3
In [1]:
%matplotlib inline
import time
import renom as rm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelBinarizer
In [2]:
from renom.cuda.cuda import set_cuda_active
set_cuda_active(True)

Fashion MNISTのデータを読み込む

In [3]:
label_names = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
In [4]:
df_train = pd.read_csv("../dataset/fashionmnist/fashion-mnist_train.csv")
df_test = pd.read_csv("../dataset/fashionmnist/fashion-mnist_test.csv")
In [5]:
df_train.head()
Out[5]:
label pixel1 pixel2 pixel3 pixel4 pixel5 pixel6 pixel7 pixel8 pixel9 ... pixel775 pixel776 pixel777 pixel778 pixel779 pixel780 pixel781 pixel782 pixel783 pixel784
0 2 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
1 9 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
2 6 0 0 0 0 0 0 0 5 0 ... 0 0 0 30 43 0 0 0 0 0
3 0 0 0 0 1 2 0 0 0 0 ... 3 0 0 0 0 1 0 0 0 0
4 3 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0

5 rows × 785 columns

1列目にラベルが示されており、2列目から最終列までは1画像における各ピクセルの値を表しています。上に表示したデータフレームからnumpy配列に変換しましょう。

In [6]:
def make_dataset(df):
    # label data
    labels = df.label.values.reshape(-1, 1)

    # image data
    images = []
    for i in range(len(df.index)):
        img = df.iloc[i, 1:].values
        img = img / 255. # normalize between 0-1
        images.append(img)
    # convert to numpy-array
    images = np.array(images, dtype=np.float32).reshape(-1, 784)

    return images, labels
In [7]:
X_train, y_train = make_dataset(df_train)
X_test, y_test = make_dataset(df_test)
In [8]:
print("X_train:%s y_train:%s\nX_test:%s y_test:%s"
      % (X_train.shape, y_train.shape, X_test.shape, y_test.shape))
X_train:(60000, 784) y_train:(60000, 1)
X_test:(10000, 784) y_test:(10000, 1)

入力したデータセットからいくつか画像を表示する関数も定義します。

In [9]:
def imshow(image_set, nrows=4, ncols=10, figsize=(12.5, 5)):
    plot_num = nrows * ncols
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(10, 10*nrows/ncols))
    plt.tight_layout(False)
    fig.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    ax = ax.ravel()
    for i in range(plot_num):
        ax[i].imshow(-image_set[i].reshape(28, 28), cmap="gray")
        ax[i].set_xticks([])
        ax[i].set_yticks([])
In [10]:
imshow(X_train)
../../../_images/notebooks_generative-model_VAE_notebook_24_0.png

Reparametrization trick

変分下限の最適化における計算困難性(=逆伝播の困難)を解決するため、Kingmaらが導入した reparametrization trick を用います。 z のランダムサンプリング z ~\sim {\mathcal N}(\mu, \sigma^2 I) の代わりに、ランダムノイズ \varepsilon ~\sim {\mathcal N}(0, I) を用いて z = \mu + \sigma^2 \odot \varepsilon とします。ここで odot は要素積です。この「トリック」を用いることにより、 z \mu \sigma^2 に対して確定的となり、逆伝播を可能とします。

Variational Auto-Encoder

変分オートエンコーダ

それではVAEモデルを構築しましょう。

In [11]:
class VAE(rm.Model):

    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        # encoder
        self.e1 = rm.Dense(400)
        self.e2 = rm.Dense(200)
        self.e3 = rm.Dense(100)
        self.e4 = rm.Dense(20)
        self.em = rm.Dense(latent_dim)
        self.ev = rm.Dense(latent_dim)

        # decoder
        self.d1 = rm.Dense(20)
        self.d2 = rm.Dense(100)
        self.d3 = rm.Dense(200)
        self.d4 = rm.Dense(400)
        self.d5 = rm.Dense(784)

    def encode(self, x):
        z = rm.relu(self.e1(x))
        z = rm.relu(self.e2(z))
        z = rm.relu(self.e3(z))
        z = rm.relu(self.e4(z))
        return self.em(z), self.ev(z)

    # sample z from mu & log(sigma^2)
    def reparametrize(self, mu, logvar):
        std = rm.exp(logvar * 0.5)
        eps = np.random.randn(logvar.shape[0], logvar.shape[1])
        return mu + eps * std

    def decode(self, x):
        z = rm.relu(self.d1(x))
        z = rm.relu(self.d2(z))
        z = rm.relu(self.d3(z))
        z = rm.relu(self.d4(z))
        return rm.sigmoid(self.d5(z))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar

損失関数の定義

VAEの損失関数は2つの項から成ります。

\begin{equation} l_{\rm total} = l_{\rm recon} + D_{KL}(q_{\phi}(z \mid x) \| p_{\theta}(z)) \end{equation}

右辺第1項の l_{\rm recon} は入力画像と再構成画像の間の誤差を表し、これは通常のオートエンコーダで定義する損失関数と同様のものです。第2項は事前分布 p_{\theta}(z) から事後分布 q_{\theta}(z \mid x) へのKLダイバージェンスです。

事前分布 p_{\theta}(z) および事後分布 q_{\phi}(z\mid x) が共にガウス分布であるような簡単な場合を考えると、KLダイバージェンスは解析的に計算することができ、以下のようになります。

\begin{equation} D_{KL}(q_{\phi}(z \mid x) \| p_{\theta}(z)) = - \frac{1}{2} \sum_{j=1}^J (1 + \log((\sigma_j)^2) - (\mu_j)^2 - (\sigma_j)^2) \end{equation}

ここで J は潜在空間 z の次元数で j はそれぞれの次元を表します。

In [12]:
def loss_function(recon_x, x, mu, logvar, batch_size):
    # Mean-squared-error between reconstructed images & input images
    MSE = rm.mean_squared_error(recon_x, x)
    # KL-divergence from the prior distribution p(z) to the posterior q(z|x)
    KLD = - 0.5 * rm.sum(1 + logvar - mu ** 2 - rm.exp(logvar)) / (batch_size)
    total_loss = MSE + KLD
    return MSE, KLD, total_loss

モデルの学習

In [13]:
# initialize the VAE model assuming 2 latent-dim
model = VAE(2)

# define optimizer function
optimizer = rm.Adam()
In [14]:
# set parameters
EPOCH = 60
BATCH_SIZE = 200
TEST_BATCH_SIZE = 1000
N = X_train.shape[0] # train data size
M = X_test.shape[0]  # test data size
num_batch_train = N // BATCH_SIZE
num_batch_test  = M // TEST_BATCH_SIZE

# learning curves
train_MSE_curve, train_KLD_curve, train_loss_curve = [], [], []
test_MSE_curve, test_KLD_curve, test_loss_curve = [], [], []

total_start = time.time()
# train loop
for i in range(1, 1+EPOCH):

    start = time.time()
    perm = np.random.permutation(N)
    train_MSE ,train_KLD ,train_loss = 0, 0, 0

    for j in range(num_batch_train):
        X_batch = X_train[perm[j*BATCH_SIZE:(j+1)*BATCH_SIZE]]

        # forward propagation
        with model.train():
            recon_batch, mu, logvar = model(X_batch)
            mse, kld, loss = loss_function(recon_batch, X_batch, mu, logvar, BATCH_SIZE)

        # backpropagation
        loss.grad().update(optimizer)

        train_MSE += mse.as_ndarray()
        train_KLD += kld.as_ndarray()
        train_loss += loss.as_ndarray()

    train_MSE /= num_batch_train
    train_KLD /= num_batch_train
    train_loss /= num_batch_train

    # test validation
    test_MSE, test_KLD, test_loss = 0, 0, 0
    for j in range(num_batch_test):
        test_batch = X_test[j*TEST_BATCH_SIZE:(j+1)*TEST_BATCH_SIZE]
        recon_test, mu_test, logvar_test = model(test_batch)
        mse, kld, loss = loss_function(recon_test, test_batch, mu_test, logvar_test, TEST_BATCH_SIZE)
        test_MSE  += mse.as_ndarray()
        test_KLD  += kld.as_ndarray()
        test_loss += loss.as_ndarray()

    test_MSE /= num_batch_test
    test_KLD /= num_batch_test
    test_loss /= num_batch_test

    # record a log
    train_MSE_curve.append(train_MSE)
    train_KLD_curve.append(train_KLD)
    train_loss_curve.append(train_loss)
    test_MSE_curve.append(test_MSE)
    test_KLD_curve.append(test_KLD)
    test_loss_curve.append(test_loss)

    elapsed = time.time() - start
    print("epoch %02d - %.1fs - loss %f - test_loss %f (MSE %f | KLD %f)"
          % (i, elapsed, train_loss, test_loss, test_MSE, test_KLD))
epoch 01 - 31.3s - loss 29.292006 - test_loss 23.355804 (MSE 20.918629 | KLD 2.437176)
epoch 02 - 26.0s - loss 21.881132 - test_loss 21.035431 (MSE 18.023841 | KLD 3.011590)
epoch 03 - 26.0s - loss 20.157909 - test_loss 19.423353 (MSE 15.813090 | KLD 3.610263)
epoch 04 - 26.0s - loss 18.991537 - test_loss 18.824402 (MSE 15.328443 | KLD 3.495959)
epoch 05 - 26.0s - loss 18.368986 - test_loss 18.323238 (MSE 14.764117 | KLD 3.559122)
epoch 06 - 26.0s - loss 17.861025 - test_loss 17.551872 (MSE 13.947551 | KLD 3.604319)
epoch 07 - 26.0s - loss 17.403337 - test_loss 17.386518 (MSE 13.809210 | KLD 3.577309)
epoch 08 - 26.1s - loss 17.022255 - test_loss 17.088886 (MSE 13.103191 | KLD 3.985694)
epoch 09 - 25.5s - loss 16.702307 - test_loss 16.500235 (MSE 12.550611 | KLD 3.949624)
epoch 10 - 25.8s - loss 16.450447 - test_loss 16.399370 (MSE 12.371619 | KLD 4.027751)
epoch 11 - 25.8s - loss 16.361910 - test_loss 16.549110 (MSE 12.418163 | KLD 4.130946)
epoch 12 - 25.9s - loss 16.281139 - test_loss 16.272774 (MSE 12.218720 | KLD 4.054053)
epoch 13 - 25.8s - loss 16.189659 - test_loss 16.299711 (MSE 12.575751 | KLD 3.723960)
epoch 14 - 25.8s - loss 16.111042 - test_loss 16.202089 (MSE 12.203213 | KLD 3.998876)
epoch 15 - 25.8s - loss 16.063742 - test_loss 16.148630 (MSE 11.959878 | KLD 4.188750)
epoch 16 - 25.8s - loss 16.152796 - test_loss 16.300724 (MSE 12.155709 | KLD 4.145014)
epoch 17 - 25.8s - loss 15.969066 - test_loss 16.214321 (MSE 11.820733 | KLD 4.393588)
epoch 18 - 25.8s - loss 15.954565 - test_loss 15.959585 (MSE 11.853073 | KLD 4.106514)
epoch 19 - 25.8s - loss 15.878405 - test_loss 15.907331 (MSE 11.661326 | KLD 4.246005)
epoch 20 - 25.8s - loss 15.864445 - test_loss 16.062075 (MSE 11.917901 | KLD 4.144173)
epoch 21 - 25.8s - loss 15.860448 - test_loss 16.399574 (MSE 12.238612 | KLD 4.160963)
epoch 22 - 25.8s - loss 15.803229 - test_loss 15.913763 (MSE 11.692834 | KLD 4.220929)
epoch 23 - 25.8s - loss 15.765581 - test_loss 16.204538 (MSE 11.937332 | KLD 4.267207)
epoch 24 - 25.8s - loss 16.012980 - test_loss 16.049160 (MSE 11.706032 | KLD 4.343127)
epoch 25 - 25.8s - loss 15.797098 - test_loss 15.893946 (MSE 11.503478 | KLD 4.390470)
epoch 26 - 25.8s - loss 15.674995 - test_loss 15.865252 (MSE 11.529056 | KLD 4.336195)
epoch 27 - 25.8s - loss 15.779749 - test_loss 15.717323 (MSE 11.648937 | KLD 4.068388)
epoch 28 - 25.8s - loss 15.614095 - test_loss 15.723486 (MSE 11.553985 | KLD 4.169500)
epoch 29 - 25.8s - loss 15.593273 - test_loss 15.731768 (MSE 11.340910 | KLD 4.390857)
epoch 30 - 25.8s - loss 15.603667 - test_loss 15.585861 (MSE 11.337660 | KLD 4.248201)
epoch 31 - 25.8s - loss 15.550684 - test_loss 15.630290 (MSE 11.385874 | KLD 4.244417)
epoch 32 - 25.8s - loss 15.557311 - test_loss 15.800154 (MSE 11.505608 | KLD 4.294548)
epoch 33 - 25.8s - loss 15.545179 - test_loss 15.559176 (MSE 11.300625 | KLD 4.258552)
epoch 34 - 25.8s - loss 15.531303 - test_loss 15.525171 (MSE 11.286374 | KLD 4.238796)
epoch 35 - 25.8s - loss 15.474730 - test_loss 15.715027 (MSE 11.472173 | KLD 4.242855)
epoch 36 - 25.8s - loss 15.492939 - test_loss 15.683006 (MSE 11.362078 | KLD 4.320929)
epoch 37 - 25.8s - loss 15.436321 - test_loss 15.597822 (MSE 11.351393 | KLD 4.246430)
epoch 38 - 25.8s - loss 15.456154 - test_loss 15.519873 (MSE 11.285704 | KLD 4.234169)
epoch 39 - 25.8s - loss 15.513895 - test_loss 15.487173 (MSE 10.995080 | KLD 4.492094)
epoch 40 - 25.8s - loss 15.501789 - test_loss 15.558016 (MSE 11.280224 | KLD 4.277792)
epoch 41 - 25.8s - loss 15.463036 - test_loss 15.468918 (MSE 11.109931 | KLD 4.358987)
epoch 42 - 25.8s - loss 15.485232 - test_loss 15.594754 (MSE 11.431299 | KLD 4.163455)
epoch 43 - 25.8s - loss 15.384259 - test_loss 15.390335 (MSE 11.046783 | KLD 4.343551)
epoch 44 - 25.8s - loss 15.413895 - test_loss 15.426082 (MSE 11.096325 | KLD 4.329757)
epoch 45 - 25.8s - loss 15.414681 - test_loss 15.417727 (MSE 10.920183 | KLD 4.497544)
epoch 46 - 25.8s - loss 15.416509 - test_loss 15.622869 (MSE 11.132002 | KLD 4.490869)
epoch 47 - 25.8s - loss 15.489601 - test_loss 15.560633 (MSE 11.164411 | KLD 4.396220)
epoch 48 - 25.8s - loss 15.477632 - test_loss 15.449161 (MSE 11.210119 | KLD 4.239041)
epoch 49 - 25.8s - loss 15.398214 - test_loss 15.399889 (MSE 11.240342 | KLD 4.159546)
epoch 50 - 25.8s - loss 15.383414 - test_loss 15.408064 (MSE 11.131777 | KLD 4.276286)
epoch 51 - 25.8s - loss 15.435960 - test_loss 15.378553 (MSE 10.955298 | KLD 4.423256)
epoch 52 - 25.8s - loss 15.419683 - test_loss 15.494047 (MSE 11.012778 | KLD 4.481270)
epoch 53 - 25.8s - loss 15.402197 - test_loss 15.588193 (MSE 11.174954 | KLD 4.413240)
epoch 54 - 25.8s - loss 15.355996 - test_loss 15.387619 (MSE 11.015935 | KLD 4.371685)
epoch 55 - 25.8s - loss 15.392119 - test_loss 15.418429 (MSE 10.925451 | KLD 4.492978)
epoch 56 - 25.8s - loss 15.346709 - test_loss 15.429564 (MSE 11.111563 | KLD 4.318002)
epoch 57 - 25.8s - loss 15.316667 - test_loss 15.386075 (MSE 10.964458 | KLD 4.421617)
epoch 58 - 25.8s - loss 15.363145 - test_loss 15.495132 (MSE 11.136416 | KLD 4.358716)
epoch 59 - 25.8s - loss 15.321315 - test_loss 15.304787 (MSE 10.864760 | KLD 4.440026)
epoch 60 - 25.8s - loss 15.303391 - test_loss 15.255893 (MSE 10.834188 | KLD 4.421706)
In [15]:
plt.figure(figsize=(16, 6))
x = range(1, EPOCH+1)
plt.plot(x, train_MSE_curve, "b:", label="train MSE")
plt.plot(x, train_KLD_curve, "b--", label="train KLD")
plt.plot(x, train_loss_curve, "b", label="train loss")
plt.plot(x, test_MSE_curve, "r:", label="test MSE")
plt.plot(x, test_KLD_curve, "r--", label="test KLD")
plt.plot(x, test_loss_curve, "r", label="test loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid()
../../../_images/notebooks_generative-model_VAE_notebook_36_0.png

KLダイバージェンスは正則化項として働いています。

再構成画像

In [16]:
recon_img, mu, logvar = model(X_test)
recon_img = recon_img.as_ndarray()
mu = mu.as_ndarray()
imshow(recon_img[:40])
../../../_images/notebooks_generative-model_VAE_notebook_39_0.png

対応する元の入力画像

In [17]:
imshow(X_test[:40])
../../../_images/notebooks_generative-model_VAE_notebook_41_0.png

潜在変数空間の確認

In [18]:
test_label = df_test.label.values
markers = ["*", "s", "v", "^", "+", "o", "P", "D", "p", "d"]
colors = []
cm = plt.get_cmap('jet')
for i in range(10):
    colors.append(cm(1.*i/10))
In [19]:
plt.figure(figsize=(12, 12))
for i, color, marker, label_name in zip(list(range(10)), colors, markers, label_names):
    plt.scatter(mu[test_label==i, 0], mu[test_label==i, 1],
                c=color, marker=marker, label=label_name)
plt.legend()
plt.grid()
../../../_images/notebooks_generative-model_VAE_notebook_44_0.png

上の図から以下のことが分かります。- 教師なし学習にも関わらず、パンツやドレス、ブーツなどのクラスは綺麗なガウス分布を形成している- サンダル・スニーカー・ブーツが全体として一つの「靴」というクラスタを形成している「シャツ」の分布はプルオーバー・コート・Tシャツ・ドレスなどの他の上着の分布全体を覆うように分布している

In [20]:
size = 20
xrange = np.linspace(2, -2, size)
yrange = np.linspace(-2, 2, size)
zmap = []

for i, x in enumerate(xrange):
    for j, y in enumerate(yrange):
        zmap.append([y, x])

output = model.decode(np.array(zmap))
In [21]:
fig, ax = plt.subplots(size, size, figsize=(10, 10))
ax = ax.ravel()
plt.tight_layout(False)
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)
for i in range(size**2):
    ax[i].imshow(1-output[i].as_ndarray().reshape(28, 28), cmap="gray")
    ax[i].set_xticks([])
    ax[i].set_yticks([])
../../../_images/notebooks_generative-model_VAE_notebook_47_0.png
In [23]:
# Save weights for reproducibility
model.save("vae.hd5")