U-Net: Convolutional Networks for Biomedical Image Segmentation

In this section, we introduce one of the most famous semantic segmentation models, U-Net. The model is most often used for biomedical image segmentation, but it applies well to a wide range of segmentation problems.

Required Libraries

In this tutorial, the following modules are required (all of them can be installed with pip).

  • matplotlib
  • numpy
  • scipy
  • tqdm
  • cv2 (opencv-python)
In [1]:
import sys, os
import numpy as np
import renom as rm
from random import shuffle
import cv2
from renom.cuda import set_cuda_active
import matplotlib.pylab as plt
from tqdm import tqdm

GPU-enabled Computing

If you wish to use GPUs, call set_cuda_active() with the argument True. This makes training much faster than training on CPUs alone. Before calling this function, make sure that a GPU is available on your machine.

In [2]:
set_cuda_active(True)

# define your prefix path
prefix_path = '/home/jun/projects/data/'
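
If you are not sure whether your machine has a usable GPU, one option is to wrap the call in a try/except and fall back to CPU training. This is only a minimal sketch; it assumes that set_cuda_active raises an error when no CUDA device can be initialized.

In [ ]:
# Minimal sketch: fall back to CPU training if CUDA cannot be activated.
try:
    set_cuda_active(True)
except Exception:
    print("No usable GPU found; training on the CPU instead.")
    set_cuda_active(False)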

Load data

We define the load_data function to fetch image data and label data from the dataset. The mode argument corresponds to the directory name.

We use the CamVid dataset, which can be downloaded from http://docs.renom.jp/downloads/CamVid.zip . After downloading the CamVid dataset, we normalize the image data. The label data is a mask, meaning that a class id is assigned to each pixel. Thus, we convert the class id of each pixel to a one-hot vector. Through this process, the shape of the target data becomes (#classes, height, width) .

In [3]:
data_path = './CamVid/'

def normalized(rgb):
    return rgb / 255.

def one_hot_it(labels,w,h):
    x = np.zeros([w,h,12])
    for i in range(w):
        for j in range(h):
            x[i,j,labels[i][j]]=1
    return x

def load_data(mode):
    data = []
    label = []
    with open(data_path + mode +'.txt') as f:
        txt = f.readlines()
        txt = [line.split(' ') for line in txt]
    for i in range(len(txt)):
        # [15:] drops the path prefix stored in the txt file; [136:, 256:] crops each frame to 224x224,
        # and np.rollaxis moves the channel axis to the front.
        data.append(np.rollaxis(normalized(cv2.imread(data_path+txt[i][0][15:])[136:,256:]),2))
        # The label image stores a class id in every pixel; take channel 0 and one-hot encode it.
        label.append(one_hot_it(cv2.imread(data_path + txt[i][1][15:][:-1])[136:,256:][:,:,0],224,224))
    return np.array(data), np.array(label)

train_images, train_label = load_data("train")
test_images, test_label = load_data("test")
val_images, val_label = load_data("val")

train_label = np.transpose(train_label, (0, 3, 1, 2))

test_label = np.transpose(test_label, (0, 3, 1, 2))

val_label = np.transpose(val_label, (0, 3, 1, 2))
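
To see exactly what one_hot_it produces, the cell below (purely illustrative, not part of the pipeline) encodes a tiny 2x2 toy mask:

In [ ]:
# Purely illustrative: one-hot encode a 2x2 toy mask whose pixels hold class ids.
toy_mask = np.array([[0, 3],
                     [3, 11]])
toy_one_hot = one_hot_it(toy_mask, 2, 2)
print(toy_one_hot.shape)   # (2, 2, 12)
print(toy_one_hot[0, 1])   # class 3 -> [0. 0. 0. 1. 0. ... 0.]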

Visualization

We define colors for each class as follows. The visualize function converts a prediction of the U-Net into the corresponding color segmentation image.

In [4]:
Sky = [128,128,128]
Building = [128,0,0]
Pole = [192,192,128]
Road = [128,64,128]
Pavement = [60,40,222]
Tree = [128,128,0]
SignSymbol = [192,128,128]
Fence = [64,64,128]
Car = [64,0,128]
Pedestrian = [64,64,0]
Bicyclist = [0,128,192]
Unlabelled = [0,0,0]

label_name = ['Sky', 'Building', 'Pole', 'Road', 'Pavement', 'Tree', 'SignSymbol', 'Fence', 'Car',
              'Pedestrian', 'Bicyclist', 'Unlabelled']
label_colors = np.array([Sky, Building, Pole, Road, Pavement,
                          Tree, SignSymbol, Fence, Car, Pedestrian, Bicyclist, Unlabelled])

def visualize(temp):
    r = temp.copy()
    g = temp.copy()
    b = temp.copy()
    for l in range(12):
        r[temp==l] = label_colors[l,0]
        g[temp==l] = label_colors[l,1]
        b[temp==l] = label_colors[l,2]

    rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
    rgb[:,:,0] = r/255.0
    rgb[:,:,1] = g/255.0
    rgb[:,:,2] = b/255.0
    return rgb

Relationship between colors and classes

In [5]:
fig = plt.figure(figsize=(10, 10))
for i in range(len(label_colors)):
    fig.add_subplot(3, 6, i+1)
    plt.title(label_name[i])
    plt.imshow(np.broadcast_to(label_colors[i], (12 ,12 ,3))/255.)
plt.show()
(Figure: a color swatch for each of the 12 class labels.)

Taking argmax of a mask, we convert a one-hot-vector to the corresponding class id. The visualize function takes this converted mask as an argument.

In [6]:
index = np.random.randint(len(train_images))
fig=plt.figure(figsize=(8, 8))
fig.add_subplot(1, 2, 1)
plt.imshow(train_images[index].transpose((1, 2, 0)), cmap='gray')
fig.add_subplot(1, 2, 2)
plt.imshow(visualize(np.argmax(train_label[index].transpose((1, 2, 0)), axis=2)))
plt.show()


(Figure: a randomly selected training image and its ground-truth segmentation mask.)

Generator

We define a generator class to pre-process the images efficiently before the training steps. In this chapter, we use two augmentation techniques: vertical flip and horizontal flip.

In [7]:
class Generator(object):
    def __init__(self, images, masks, val_images, val_masks, batch_size):
        self.images = images
        self.masks = masks
        self.val_images = val_images
        self.val_masks = val_masks
        self.batch_size = batch_size

    def hflip(self, img, y):
        # Flip along the width axis (images and masks are channel-first arrays: (C, H, W)).
        if np.random.random() < 0.5:
            img = img[:, :, ::-1]
            y = y[:, :, ::-1]
        return img, y

    def vflip(self, img, y):
        # Flip along the height axis.
        if np.random.random() < 0.5:
            img = img[:, ::-1]
            y = y[:, ::-1]
        return img, y

    def generate(self, train=True):
        while True:
            if train:
                perm = np.random.permutation(len(self.images))
            else:
                perm = np.random.permutation(len(self.val_images))

            inputs = []
            targets = []
            for i in perm:
                if train:
                    img = self.images[i]
                    mask = self.masks[i]
                    img, mask = self.hflip(img, mask)
                    img, mask = self.vflip(img, mask)
                else:
                    img = self.val_images[i]
                    mask = self.val_masks[i]
                inputs.append(img)
                targets.append(mask)

                if len(targets) == self.batch_size:
                    tmp_inp = np.array(inputs)
                    tmp_targets = np.array(targets)
                    inputs = []
                    targets = []
                    yield tmp_inp, tmp_targets
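
As a quick sanity check, we can draw a single augmented batch from the generator and confirm the array shapes (this cell is illustrative and not required for training):

In [ ]:
# Illustrative only: draw one augmented training batch and inspect the shapes.
gen_check = Generator(train_images, train_label, val_images, val_label, batch_size=1)
x_batch, y_batch = next(gen_check.generate(train=True))
print(x_batch.shape)   # (1, 3, 224, 224)
print(y_batch.shape)   # (1, 12, 224, 224)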

U-Net

To build U-Net, we use convolutional layers, deconvolutional layers (for upsampling), max pooling, and the ReLU activation function. The contracting path stacks convolutional layers and max pooling to extract features; the expanding path applies deconvolutional layers for upsampling and concatenates the corresponding feature maps from the contracting path (the skip connections). After each convolutional layer, a ReLU activation is applied to the derived feature maps.

In [8]:
class UNet(rm.Model):
    def __init__(self, num_classes):
        self.conv1_1 = rm.Conv2d(64, padding=1, filter=3)
        self.conv1_2 = rm.Conv2d(64, padding=1, filter=3)
        self.conv2_1 = rm.Conv2d(128, padding=1, filter=3)
        self.conv2_2 = rm.Conv2d(128, padding=1, filter=3)
        self.conv3_1 = rm.Conv2d(256, padding=1, filter=3)
        self.conv3_2 = rm.Conv2d(256, padding=1, filter=3)
        self.conv4_1 = rm.Conv2d(512, padding=1, filter=3)
        self.conv4_2 = rm.Conv2d(512, padding=1, filter=3)
        self.conv5_1 = rm.Conv2d(1024, padding=1, filter=3)
        self.conv5_2 = rm.Conv2d(1024, padding=1, filter=3)

        self.deconv1 = rm.Deconv2d(512, stride=2)
        self.conv6_1 = rm.Conv2d(256, padding=1)
        self.conv6_2 = rm.Conv2d(256, padding=1)
        self.deconv2 = rm.Deconv2d(256, stride=2)
        self.conv7_1 = rm.Conv2d(128, padding=1)
        self.conv7_2 = rm.Conv2d(128, padding=1)
        self.deconv3 = rm.Deconv2d(128, stride=2)
        self.conv8_1 = rm.Conv2d(64, padding=1)
        self.conv8_2 = rm.Conv2d(64, padding=1)
        self.deconv4 = rm.Deconv2d(64, stride=2)
        self.conv9 = rm.Conv2d(num_classes, filter=1)



    def forward(self, x):
        # Contracting path (encoder): stacked 3x3 conv + ReLU blocks, each followed by 2x2 max pooling.
        t = rm.relu(self.conv1_1(x))
        c1 = rm.relu(self.conv1_2(t))
        t = rm.max_pool2d(c1, filter=2, stride=2)
        t = rm.relu(self.conv2_1(t))
        c2 = rm.relu(self.conv2_2(t))
        t = rm.max_pool2d(c2, filter=2, stride=2)
        t = rm.relu(self.conv3_1(t))
        c3 = rm.relu(self.conv3_2(t))
        t = rm.max_pool2d(c3, filter=2, stride=2)
        t = rm.relu(self.conv4_1(t))
        c4 = rm.relu(self.conv4_2(t))
        t = rm.max_pool2d(c4, filter=2, stride=2)
        t = rm.relu(self.conv5_1(t))
        t = rm.relu(self.conv5_2(t))

        # Expanding path (decoder): upsample with Deconv2d, crop to the skip feature map's size, and concatenate.
        t = self.deconv1(t)[:, :, :c4.shape[2], :c4.shape[3]]
        t = rm.concat([c4, t])
        t = rm.relu(self.conv6_1(t))
        t = rm.relu(self.conv6_2(t))
        t = self.deconv2(t)[:, :, :c3.shape[2], :c3.shape[3]]
        t = rm.concat([c3, t])

        t = rm.relu(self.conv7_1(t))
        t = rm.relu(self.conv7_2(t))
        t = self.deconv3(t)[:, :, :c2.shape[2], :c2.shape[3]]
        t = rm.concat([c2, t])

        t = rm.relu(self.conv8_1(t))
        t = rm.relu(self.conv8_2(t))
        t = self.deconv4(t)[:, :, :c1.shape[2], :c1.shape[3]]
        t = rm.concat([c1, t])

        t = self.conv9(t)

        return t
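
To confirm that the upsampling and skip connections line up, we can feed a dummy input through an untrained network and check the output shape. This is only a sketch; the expected shape assumes the 224x224 crops prepared above.

In [ ]:
# Sketch: run a random 224x224 image through an untrained UNet and check the output shape.
dummy = np.random.rand(1, 3, 224, 224).astype(np.float32)
out = UNet(num_classes=12)(dummy)
print(out.shape)   # expected: (1, 12, 224, 224)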

Training step

We use a batch size of 2 in this chapter, but you can choose any batch size as long as it fits into your GPU memory. When you create an instance of the UNet class, you have to pass the number of classes as the argument. The softmax cross-entropy loss is divided by the number of pixels, so the reported value is a per-pixel loss.

Other settings are as follows.

  • Optimizer: Adam
  • Learning rate: 1e-3
  • Epochs: 30
In [ ]:
num_classes = 12
batch_size = 2

unet = UNet(num_classes)
gen = Generator(train_images, train_label, val_images, val_label, batch_size)
# Create the training and validation generators once and keep drawing batches from them.
train_gen = gen.generate(True)
val_gen = gen.generate(False)
epochs = 30
N = len(train_images)
img_shape = (224, 224)  # the images were cropped to 224x224 when loading the data
opt = rm.Adam(lr=1e-3)
best_loss = np.inf
for epoch in range(epochs):
    loss = 0
    val_loss = 0
    bar = tqdm(range(N//batch_size))
    for i in range(N//batch_size):
        with unet.train():
            x, mask = next(train_gen)
            t = unet(x)
            # Normalize the summed cross-entropy by the number of pixels.
            l = rm.softmax_cross_entropy(t, mask) / (img_shape[0]*img_shape[1])
        l.grad().update(opt)
        bar.set_description("epoch {:03d} train loss:{:6.4f}".format(epoch, float(l.as_ndarray())))
        loss += l.as_ndarray()
        bar.update(1)

    for j in range(len(val_images)//batch_size):
        x, mask = next(val_gen)
        t = unet(x)
        val_l = rm.softmax_cross_entropy(t, mask) / (img_shape[0]*img_shape[1])
        val_loss += val_l.as_ndarray()
    # Save the weights whenever the validation loss improves on the best value so far.
    if val_loss/(j+1) < best_loss:
        best_loss = val_loss/(j+1)
        unet.save('weight.best2.h5')

    bar.set_description("epoch {:03d} avg loss:{:6.4f} val loss:{:6.4f}".format(epoch, float((loss/(i+1))), float((val_loss/(j+1)))))
    bar.update(0)
    bar.refresh()
    bar.close()
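
Before visualizing predictions, you may want to restore the checkpoint with the lowest validation loss. A minimal sketch, assuming Model.load() mirrors the save() call used in the training loop:

In [ ]:
# Sketch: restore the best weights saved during training (assumes Model.load mirrors Model.save).
unet.load('weight.best2.h5')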

Visualization of predicted segmentation images

We display the first ten validation images alongside the predicted segmentation maps.

In [19]:
images = val_images
In [20]:
print('------- original image ------- | -------- prediction --------')
for img in images[:10]:
    t = unet(np.array([img]))

    fig = plt.figure(figsize=(8, 8))
    fig.add_subplot(1, 2, 1)
    plt.imshow(img.transpose((1, 2, 0)))
    fig.add_subplot(1, 2, 2)
    plt.imshow(visualize(np.argmax(t[0].as_ndarray().transpose((1, 2, 0)), axis=2)))
    plt.show()
------- original image ------- | -------- prediction --------
(Figure: the first ten validation images alongside their predicted segmentation maps.)