Shift Preprocessing for Image Classification

Shift preprocessing for image classification using the Caltech 101 dataset

Preprocessing the input images helps a model learn more robustly.
There are many preprocessing methods, such as zoom, color jitter and shift.
Shift is useful when the position of the object varies widely from image to image.
In this tutorial, we show how to shift the input images and build batch data from them.
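To make the idea concrete, here is a minimal NumPy sketch of shifting a single image by a fixed pixel offset. It assumes an (H, W) or (H, W, C) array and fills the uncovered border with zeros. It is not ReNom's Shift implementation, whose border handling may differ, and the names shift_image and _shift_slices are hypothetical.

```python
import numpy as np


def _shift_slices(size, offset):
    """Return (source, destination) slices that move `size` elements by `offset`."""
    kept = max(0, size - abs(offset))  # number of elements that stay inside the image
    if offset >= 0:
        return slice(0, kept), slice(offset, offset + kept)
    return slice(-offset, -offset + kept), slice(0, kept)


def shift_image(image, dy, dx):
    """Shift an (H, W) or (H, W, C) image by (dy, dx) pixels.

    Positive dy moves the content down, positive dx moves it right.
    Pixels shifted in from outside the image are filled with zeros.
    """
    h, w = image.shape[:2]
    src_y, dst_y = _shift_slices(h, dy)
    src_x, dst_x = _shift_slices(w, dx)
    shifted = np.zeros_like(image)
    shifted[dst_y, dst_x] = image[src_y, src_x]
    return shifted


# Example: move a 3x3 image down by one row; the top row becomes zeros
img = np.arange(9).reshape(3, 3)
print(shift_image(img, 1, 0))
```

A random shift augmentation draws a different (dy, dx) for every image, so the same object appears at different positions across training batches.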

The dataset reference is as follows.

Caltech 101
L. Fei-Fei, R. Fergus and P. Perona. One-Shot learning of object categories. IEEE Trans. Pattern Recognition and Machine Intelligence. In press.
In [1]:
from renom.utility.distributor import ImageClassificationDistributor
# DataAugmentation and Shift are the only names used from this module
from renom.utility.image import DataAugmentation, Shift
import matplotlib.pyplot as plt
import numpy as np
import math
import os

Load for classification

Collect the image file paths and their one-hot labels for classification.

In [2]:
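def load_for_classification(path):
    # Each subdirectory of `path` is one class. os.listdir returns entries in
    # arbitrary order, so sort them to get a reproducible class order, and
    # skip stray files that are not directories.
    class_list = sorted(d for d in os.listdir(path)
                        if os.path.isdir(os.path.join(path, d)))
    # One-hot vector for class i: all zeros except a 1 at index i
    onehot_vectors = []
    for i in range(len(class_list)):
        temp = [0] * len(class_list)
        temp[i] = 1
        onehot_vectors.append(temp)
    X_list = []  # image file paths
    y_list = []  # matching one-hot labels
    for class_index, classname in enumerate(class_list):
        class_dir = os.path.join(path, classname)
        for filename in os.listdir(class_dir):
            X_list.append(os.path.join(class_dir, filename))
            y_list.append(onehot_vectors[class_index])

    return X_list, y_list, class_list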
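The nested loop that builds onehot_vectors produces the rows of an identity matrix, so np.eye gives the same vectors. A minimal sketch, using a hypothetical three-class subset of the Caltech 101 class names:

```python
import numpy as np

class_list = ["airplanes", "ferry", "watch"]  # hypothetical subset of class names
# Row i of the identity matrix is the one-hot vector for class i
onehot_vectors = np.eye(len(class_list), dtype=int)
label = onehot_vectors[class_list.index("ferry")]
print(label.tolist())  # [0, 1, 0]
```

Plain Python lists are kept in the tutorial code because get_class_from_onehot later calls list.index on the label.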

Get class from onehot

Get the class name from a one-hot vector.

In [3]:
def get_class_from_onehot(onehot_vector, class_list):
    return class_list[onehot_vector.index(1)]
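get_class_from_onehot relies on list.index(1), so it needs a Python list that contains an exact 1 (which is why the batch loop calls label.tolist()). A variant based on np.argmax, sketched here under the hypothetical name get_class_from_scores, also accepts NumPy arrays and score vectors such as model outputs:

```python
import numpy as np


def get_class_from_scores(vector, class_list):
    """Return the class name at the position of the largest entry.

    Works for exact one-hot labels (lists or NumPy arrays) and for
    score or probability vectors, e.g. a classifier's output.
    """
    return class_list[int(np.argmax(vector))]


print(get_class_from_scores(np.array([0.1, 0.7, 0.2]), ["cat", "dog", "bird"]))  # dog
```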

Show batch images

Arrange a batch of images in a near-square grid, whose size is computed from the batch size, and display it.
This lets you check how the preprocessed images differ from the original input images.
In [4]:
def imshow_batch(images):
    # images: batch of shape (N, H, W, C) or (N, H, W)
    n_images = images.shape[0]
    images = images.astype(np.uint8)
    # Near-square grid: ceil(sqrt(N)) columns and as many rows as needed
    map_width = int(math.ceil(math.sqrt(n_images)))
    map_height = int(math.ceil(n_images / map_width))
    # Pad with black images so that every grid cell is filled
    n_padding = map_height * map_width - n_images
    padding = np.zeros((n_padding,) + images.shape[1:], dtype=images.dtype)
    images = np.concatenate((images, padding), axis=0)
    # Join the images of each row side by side, then stack the rows vertically
    rows = [np.concatenate(list(images[h * map_width:(h + 1) * map_width]), axis=1)
            for h in range(map_height)]
    concatenated_image = np.concatenate(rows, axis=0)

    fig, ax = plt.subplots(figsize=(15, 15))
    ax.imshow(concatenated_image)
    plt.axis("off")
    plt.tight_layout()
    plt.show()
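The same grid can be built without explicit loops by padding the batch and then reshaping and transposing it. The sketch below assumes a batch of shape (N, H, W, C) and only returns the tiled array, which makes the layout easy to check; tile_images is a hypothetical helper, not part of ReNom.

```python
import math

import numpy as np


def tile_images(images):
    """Arrange a batch of shape (N, H, W, C) into one (rows*H, cols*W, C) grid.

    Same layout as imshow_batch: cols = ceil(sqrt(N)), cells filled left to
    right and top to bottom, empty cells padded with zeros (black).
    """
    n, h, w, c = images.shape
    cols = int(math.ceil(math.sqrt(n)))
    rows = int(math.ceil(n / cols))
    padded = np.zeros((rows * cols, h, w, c), dtype=images.dtype)
    padded[:n] = images
    # (rows, cols, H, W, C) -> (rows, H, cols, W, C) -> (rows*H, cols*W, C)
    grid = padded.reshape(rows, cols, h, w, c).transpose(0, 2, 1, 3, 4)
    return grid.reshape(rows * h, cols * w, c)
```

For the 12-image batch in this tutorial, this layout gives 3 rows of 4 images.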

Load the image data and label data

The one-hot labels are created from the directory names: each subdirectory of the dataset holds the images of one class.

Shift((20, 50)) is the part that shifts the images.
The tuple gives the maximum shift size, and each image is shifted by this maximum multiplied by a random shift scale between -1 and 1.
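The offset rule described above can be sketched as follows: one random scale in [-1, 1) per axis and per image, multiplied by the maximum shift and rounded to whole pixels. This is an illustration under assumptions, not ReNom's code: the rounding, the uniform distribution and the (vertical, horizontal) order of the tuple are all assumed here, and sample_shift_offsets is a hypothetical name.

```python
import numpy as np


def sample_shift_offsets(batch_size, max_shift, rng=None):
    """Sample one integer (dy, dx) offset per image.

    Each offset is max_shift times a scale drawn uniformly from [-1, 1),
    rounded to whole pixels, so |dy| <= max_shift[0] and |dx| <= max_shift[1].
    """
    rng = np.random.default_rng() if rng is None else rng
    scale = rng.uniform(-1.0, 1.0, size=(batch_size, 2))
    return np.rint(scale * np.asarray(max_shift)).astype(int)


offsets = sample_shift_offsets(12, (20, 50), rng=np.random.default_rng(0))
print(offsets)
```

For example, with a maximum of (20, 50), a scale of (0.5, -0.2) gives an offset of (10, -10).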
In [5]:
path = "101_ObjectCategories/"
X_list, Y_list, class_list = load_for_classification(path)

# Shift each image by a random amount of at most (20, 50)
augmentation = DataAugmentation([
    Shift((20, 50)),
], random=True)
distributor = ImageClassificationDistributor(image_path_list=X_list,
                                             y_list=Y_list,
                                             class_list=class_list,
                                             imsize=(200, 200),
                                             color="RGB",
                                             augmentation=augmentation)
# Take one shuffled batch of 12 images, print the labels and show the images
for (x, y) in distributor.batch(12, shuffle=True):
    for label in y:
        print(get_class_from_onehot(label.tolist(), class_list))
    # (N, C, H, W) -> (N, H, W, C), the layout matplotlib expects
    images = x.transpose((0, 2, 3, 1))
    imshow_batch(images)
    break
pyramid
soccer_ball
Motorbikes
airplanes
hawksbill
Motorbikes
tick
platypus
Motorbikes
ferry
watch
electric_guitar
[Output: the batch of 12 shifted images displayed as a grid]
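In the cell above, x.transpose((0, 2, 3, 1)) reorders the batch from channel-first (N, C, H, W), the layout that this transpose implies the distributor returns, to channel-last (N, H, W, C), which is what matplotlib's imshow expects for color images. A shape-only sketch with a dummy array standing in for a real batch:

```python
import numpy as np

# Dummy channel-first batch: 12 RGB images of 200x200, matching imsize=(200, 200)
x = np.zeros((12, 3, 200, 200), dtype=np.float32)
# Move the channel axis to the end: (N, C, H, W) -> (N, H, W, C)
images = x.transpose((0, 2, 3, 1))
print(images.shape)  # (12, 200, 200, 3)
```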