U-Net: Biomedical Image セグメンテーションのためのConvolutionalネットワーク

本章では、最も有名なセマンティックセグメンテーションのモデルの一つであるU-Netを紹介します。このモデルはバイオメディカル画像などに対して欲使われますが、その他様々な画像に対しても応用可能です。

必要なライブラリ

本章では以下のライブラリを必要とします。

  • matplotlib
  • numpy
  • scipy
  • tqdm
  • cv2(opencv-python)
In [1]:
import sys, os
import numpy as np
import renom as rm
from random import shuffle
import cv2
from renom.cuda import set_cuda_active
import matplotlib.pylab as plt
from tqdm import tqdm

GPUを用いた計算

GPUを用いて計算を行う場合、 set_cuda_active() を用いて、cudaを有効化します.その際、引数に True を渡す必要があります。

In [2]:
set_cuda_active(True)

# define your prefix path
prefix_path = '/home/jun/projects/data/'

データのロード

データセットから画像データとラベルデータを取得するための load_data 関数を定義します。引数のmodeにはデータセット内のディレクトリ名が入ります。

本章では、Camvidデータセットと呼ばれるものを利用します。このデータセットはhttp://docs.renom.jp/downloads/CamVid.zipでダウンロードをすることができます。ダウンロード後、画像データの正規化をする必要があります。ラベルデータはマスク、つまり各クラスのidが画像のそれぞれのピクセルに対して割り当てられています。したがって、このclass idをOne Hot Vectorに変換する必要があります。これらの作業により、最終的に (#クラス, 幅、高さ) となります。

In [3]:
data_path = './CamVid/'

def normalized(rgb):
    return rgb / 255.

def one_hot_it(labels,w,h):
    x = np.zeros([w,h,12])
    for i in range(w):
        for j in range(h):
            x[i,j,labels[i][j]]=1
    return x

def load_data(mode):
    data = []
    label = []
    with open(data_path + mode +'.txt') as f:
        txt = f.readlines()
        txt = [line.split(' ') for line in txt]
    for i in range(len(txt)):
        data.append(np.rollaxis(normalized(cv2.imread(data_path+txt[i][0][15:])[136:,256:]),2))
        label.append(one_hot_it(cv2.imread(data_path + txt[i][1][15:][:-1])[136:,256:][:,:,0],224,224))
    return np.array(data), np.array(label)

train_images, train_label = load_data("train")
test_images, test_label = load_data("test")
val_images, val_label = load_data("val")

train_label = np.transpose(train_label, (0, 3, 1, 2))

test_label = np.transpose(test_label, (0, 3, 1, 2))

val_label = np.transpose(val_label, (0, 3, 1, 2))

可視化

各クラスに対する色を定義します。 visualize 関数ではU-Netから予測された出力を変換して、One Hot Vectorではなくクラスのidが割り当てられた対応するセグメンテーション画像に変換します。

In [4]:
Sky = [128,128,128]
Building = [128,0,0]
Pole = [192,192,128]
Road = [128,64,128]
Pavement = [60,40,222]
Tree = [128,128,0]
SignSymbol = [192,128,128]
Fence = [64,64,128]
Car = [64,0,128]
Pedestrian = [64,64,0]
Bicyclist = [0,128,192]
Unlabelled = [0,0,0]

label_name = ['Sky', 'Building', 'Pole', 'Road', 'Pavement', 'Tree', 'SignSymbol', 'Fence', 'Car',
              'Pedestrial', 'Bicyclist', 'Unlabelled']
label_colors = np.array([Sky, Building, Pole, Road, Pavement,
                          Tree, SignSymbol, Fence, Car, Pedestrian, Bicyclist, Unlabelled])

def visualize(temp):
    r = temp.copy()
    g = temp.copy()
    b = temp.copy()
    for l in range(0,11):
        r[temp==l]=label_colors[l,0]
        g[temp==l]=label_colors[l,1]
        b[temp==l]=label_colors[l,2]

    rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
    rgb[:,:,0] = (r/255.0)#[:,:,0]
    rgb[:,:,1] = (g/255.0)#[:,:,1]
    rgb[:,:,2] = (b/255.0)#[:,:,2]
    return rgb

色とクラス

In [5]:
fig = plt.figure(figsize=(10, 10))
for i in range(len(label_colors)):
    fig.add_subplot(3, 6, i+1)
    plt.title(label_name[i])
    plt.imshow(np.broadcast_to(label_colors[i], (12 ,12 ,3))/255.)
plt.show()
../../../_images/notebooks_image_processing_u-net_notebook_11_0.png

マスクのargmaxを取ることにより、One Hot Vectorを対応するクラスのidに変換します。 visualize 関数はこの変換されたマスクを引数として受け取ります。

In [6]:
index = np.random.randint(len(train_images))
fig=plt.figure(figsize=(8, 8))
fig.add_subplot(1, 2, 1)
plt.imshow(train_images[index].transpose((1, 2, 0)), cmap='gray')
fig.add_subplot(1, 2, 2)
plt.imshow(visualize(np.argmax(train_label[index].transpose((1, 2, 0)), axis=2)))
plt.show()


../../../_images/notebooks_image_processing_u-net_notebook_13_0.png

Generator

学習前に、画像の前処理等を効率よく行えるようGeneratorのクラスを作成します。本章では、主に2つのオーグメンテーション: 縦にフリップ、横にフリップ、を利用します。

In [7]:
class Generator(object):
    def __init__(self, images, masks, val_images, val_masks, batch_size):
        self.images = images
        self.masks = masks
        self.val_images = val_images
        self.val_masks = val_masks
        self.batch_size = batch_size

    def hflip(self, img, y):
        if np.random.random() < 0.5:
            img = img[:, ::-1]
            y = y[:, ::-1]
        return img, y

    def vflip(self, img, y):
        if np.random.random() < 0.5:
            img = img[::-1]
            y = y[::-1]
        return img, y

    def generate(self, train=True):
        while True:
            if train:
                perm = np.random.permutation(len(self.images))
            else:
                perm = np.random.permutation(len(self.val_images))

            inputs = []
            targets = []
            for i in perm:
                if train:
                    img = self.images[i]
                    mask = self.masks[i]
                    img, mask = self.hflip(img, mask)
                    img, mask = self.vflip(img, mask)
                else:
                    img = self.val_images[i]
                    mask = self.val_masks[i]
                inputs.append(img)
                targets.append(mask)

                if len(targets) == self.batch_size:
                    tmp_inp = np.array(inputs)
                    tmp_targets = np.array(targets)
                    inputs = []
                    targets = []
                    yield tmp_inp, tmp_targets

U-Net

U-Netを構築するために、畳み込み層、逆畳込み層、Max プーリング, そしてRelu (活性化関数)を使います。畳み込み層による処理のあと、逆畳み込み層を利用することでアップサンプリングを行います。それぞれの畳み込み層の後にはRelu関数を利用します。

In [8]:
class UNet(rm.Model):
    def __init__(self, num_classes):
        self.conv1_1 = rm.Conv2d(64, padding=1, filter=3)
        self.conv1_2 = rm.Conv2d(64, padding=1, filter=3)
        self.conv2_1 = rm.Conv2d(128, padding=1, filter=3)
        self.conv2_2 = rm.Conv2d(128, padding=1, filter=3)
        self.conv3_1 = rm.Conv2d(256, padding=1, filter=3)
        self.conv3_2 = rm.Conv2d(256, padding=1, filter=3)
        self.conv4_1 = rm.Conv2d(512, padding=1, filter=3)
        self.conv4_2 = rm.Conv2d(512, padding=1, filter=3)
        self.conv5_1 = rm.Conv2d(1024, padding=1, filter=3)
        self.conv5_2 = rm.Conv2d(1024, padding=1, filter=3)

        self.deconv1 = rm.Deconv2d(512, stride=2)
        self.conv6_1 = rm.Conv2d(256, padding=1)
        self.conv6_2 = rm.Conv2d(256, padding=1)
        self.deconv2 = rm.Deconv2d(256, stride=2)
        self.conv7_1 = rm.Conv2d(128, padding=1)
        self.conv7_2 = rm.Conv2d(128, padding=1)
        self.deconv3 = rm.Deconv2d(128, stride=2)
        self.conv8_1 = rm.Conv2d(64, padding=1)
        self.conv8_2 = rm.Conv2d(64, padding=1)
        self.deconv4 = rm.Deconv2d(64, stride=2)
        self.conv9 = rm.Conv2d(num_classes, filter=1)



    def forward(self, x):
        t = rm.relu(self.conv1_1(x))
        c1 = rm.relu(self.conv1_2(t))
        t = rm.max_pool2d(c1, filter=2, stride=2)
        t = rm.relu(self.conv2_1(t))
        c2 = rm.relu(self.conv2_2(t))
        t = rm.max_pool2d(c2, filter=2, stride=2)
        t = rm.relu(self.conv3_1(t))
        c3 = rm.relu(self.conv3_2(t))
        t = rm.max_pool2d(c3, filter=2, stride=2)
        t = rm.relu(self.conv4_1(t))
        c4 = rm.relu(self.conv4_2(t))
        t = rm.max_pool2d(c4, filter=2, stride=2)
        t = rm.relu(self.conv5_1(t))
        t = rm.relu(self.conv5_2(t))

        t = self.deconv1(t)[:, :, :c4.shape[2], :c4.shape[3]]
        t = rm.concat([c4, t])
        t = rm.relu(self.conv6_1(t))
        t = rm.relu(self.conv6_2(t))
        t = self.deconv2(t)[:, :, :c3.shape[2], :c3.shape[3]]
        t = rm.concat([c3, t])

        t = rm.relu(self.conv7_1(t))
        t = rm.relu(self.conv7_2(t))
        t = self.deconv3(t)[:, :, :c2.shape[2], :c2.shape[3]]
        t = rm.concat([c2, t])

        t = rm.relu(self.conv8_1(t))
        t = rm.relu(self.conv8_2(t))
        t = self.deconv4(t)[:, :, :c1.shape[2], :c1.shape[3]]
        t = rm.concat([c1, t])

        t = self.conv9(t)

        return t

学習

本章ではバッチサイス2を利用します。しかしバッチサイズはGPUの容量が許す限り好きなように設定することができます。UNetクラスのインスタンスを呼ぶ際は、引数にクラスの数を渡さなければなりません。

その他の設定は下の通りです。

  • 勾配法: Adam
  • 学習係数: 1e-3
  • エポック: 30
In [ ]:
num_classes = 12
batch_size = 2

unet = UNet(num_classes)
gen = Generator(train_images, train_label, val_images, val_label, batch_size)
best_loss = np.inf
epochs = 30
N = len(train_images)
img_shape = (512, 512)
opt = rm.Adam(lr=1e-3)
best_loss = np.inf
for epoch in range(epochs):
    loss = 0
    val_loss = 0
    bar = tqdm(range(N//batch_size))
    for i in range(N//batch_size):
        with unet.train():
            x, mask = gen.generate(True).__next__()
            t = unet(x)
            l = rm.softmax_cross_entropy(t, mask) / (img_shape[0]*img_shape[1])
        l.grad().update(opt)
        bar.set_description("epoch {:03d} train loss:{:6.4f}".format(epoch, float(l.as_ndarray())))
        loss += l.as_ndarray()
        bar.update(1)

    for j in range(len(val_images)//batch_size):
        x, mask = gen.generate(False).__next__()
        t = unet(x)
        val_l = rm.softmax_cross_entropy(t, mask) / (img_shape[0]*img_shape[1])
        val_loss += val_l.as_ndarray()
    if val_loss/(j+1) < best_loss:
        unet.save('weight.best2.h5')

    bar.set_description("epoch {:03d} avg loss:{:6.4f} val loss:{:6.4f}".format(epoch, float((loss/(i+1))), float((val_loss/(j+1)))))
    bar.update(0)
    bar.refresh()
    bar.close()

予測結果の可視化

バリデーション画像とその予測結果の表示をします。

In [19]:
images = val_images
In [20]:
print('------- priginal image ------- | -------- prediction --------')
for img in images[:10]:
    t = unet(np.array([img]))

    fig = plt.figure(figsize=(8, 8))
    fig.add_subplot(1, 2, 1)
    plt.imshow(img.transpose((1, 2, 0)))
    fig.add_subplot(1, 2, 2)
    plt.imshow(visualize(np.argmax(t[0].as_ndarray().transpose((1, 2, 0)), axis=2)))
    plt.show()
------- priginal image ------- | -------- prediction --------
../../../_images/notebooks_image_processing_u-net_notebook_22_1.png
../../../_images/notebooks_image_processing_u-net_notebook_22_2.png
../../../_images/notebooks_image_processing_u-net_notebook_22_3.png
../../../_images/notebooks_image_processing_u-net_notebook_22_4.png
../../../_images/notebooks_image_processing_u-net_notebook_22_5.png
../../../_images/notebooks_image_processing_u-net_notebook_22_6.png
../../../_images/notebooks_image_processing_u-net_notebook_22_7.png
../../../_images/notebooks_image_processing_u-net_notebook_22_8.png
../../../_images/notebooks_image_processing_u-net_notebook_22_9.png
../../../_images/notebooks_image_processing_u-net_notebook_22_10.png