Weight Decay

How to use weight decay in ReNom with a fully connected neural network model on MNIST.

In this tutorial, we build a fully connected neural network model for classifying digit images. You will learn the following:

  • How to calculate the regularization term used for weight decay.

Required libraries

In [1]:
from __future__ import division, print_function
import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import fetch_mldata
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import confusion_matrix, classification_report

import renom as rm
from renom.optimizer import Sgd

Load data

Next, we load the raw MNIST data and shape it into training-ready objects. To accomplish this, we’ll use the fetch_mldata function included in the scikit-learn package.

The MNIST dataset consists of 70000 digit images. Before we do anything else, we have to split the data into a training set and a test set. We’ll then do two important pre-processing steps that make for a smoother training process: 1) Re-scale the image data (originally integer values 0-255) to have a range from 0 to 1. 2) “Binarize” the labels: map each digit (0-9) to a vector of 0s and 1s.
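For intuition, here is what these two steps do on a few toy values (this snippet is illustrative only and not part of the pipeline below):

import numpy as np
from sklearn.preprocessing import LabelBinarizer

# 1) Rescaling: integer pixel values 0-255 become floats in the range 0-1.
pixels = np.array([0, 128, 255], dtype=np.float32)
print(pixels / pixels.max())       # [0.0, ~0.502, 1.0]

# 2) "Binarizing": each digit label becomes a length-10 one-hot vector.
lb = LabelBinarizer().fit(np.arange(10))
print(lb.transform([3]))           # [[0 0 0 1 0 0 0 0 0 0]]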

In [2]:
# Datapath must point to the directory containing the mldata folder.
data_path = "../dataset"
mnist = fetch_mldata('MNIST original', data_home=data_path)

X = mnist.data
y = mnist.target

# Rescale the image data to 0 ~ 1.
X = X.astype(np.float32)
X /= X.max()

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)
labels_train = LabelBinarizer().fit_transform(y_train).astype(np.float32)
labels_test = LabelBinarizer().fit_transform(y_test).astype(np.float32)

# Training data size.
N = len(X_train)

Neural network with an L2 regularization term

Weight decay is one of the regularization methods. Overfitting can occur for many reasons; we think one of them is imbalanced data. At every epoch, the model is updated by backpropagation, and with imbalanced data the weights tend to change sharply in favor of the largest class. We therefore recommend applying weight decay when training on imbalanced data, to restrict how much the weights change at each update.
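Concretely, the model below minimizes the cross entropy loss plus an L2 penalty on the two weight matrices; the small coefficient applied in the training loop later plays the role of the decay rate lambda:

    L_total = L_cross_entropy + lambda * (sum(W1 ** 2) + sum(W2 ** 2))
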
In [3]:
class Mnist(rm.Model):

    def __init__(self):
        super(Mnist, self).__init__()
        self._layer1 = rm.Dense(100)
        self._layer2 = rm.Dense(10)

    def forward(self, x):
        out = self._layer2(rm.relu(self._layer1(x)))
        return out

    def weight_decay(self):
        weight_decay = rm.sum(self._layer1.params.w**2) + rm.sum(self._layer2.params.w**2)
        return weight_decay
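Adding this sum of squared weights to the loss means every gradient step also pulls the weights toward zero. The sketch below shows the effect for a single plain SGD step in NumPy (illustrative only, not ReNom; the lr and decay values match the ones used later in this tutorial):

import numpy as np

lr, decay = 0.1, 0.0001
w = np.array([0.5, -1.2])               # example weight values
grad_loss = np.array([0.02, -0.01])     # made-up gradient of the data loss

# The gradient of decay * sum(w**2) with respect to w is 2 * decay * w,
# so each update shrinks the weights slightly in addition to the data term.
w_new = w - lr * (grad_loss + 2 * decay * w)
print(w_new)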

Instantiation

In [4]:
# Choose neural network.
network = Mnist()

Training loop

Now that the network is built, we can start to do the actual training. Rather than using vanilla “batch” gradient descent, which is computationally expensive, we’ll use mini-batch stochastic gradient descent (SGD). This method trains on a handful of examples per iteration (the “batch-size”), allowing us to make “stochastic” updates to the weights in a short time. The learning curve will appear noisier, but this method tends to converge much faster.
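For intuition, this is how a random permutation gets sliced into mini-batches; the loop below does the same with N of roughly 63,000 and a batch size of 64, i.e. close to a thousand updates per epoch (toy sizes here, illustrative only):

import numpy as np

N, batch = 10, 4
perm = np.random.permutation(N)
for j in range(N // batch):
    print(perm[j * batch:(j + 1) * batch])   # indices of one mini-batch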

In [5]:
# Hyper parameters
batch = 64
epoch = 10

optimizer = Sgd(lr = 0.1)

learning_curve = []
test_learning_curve = []

for i in range(epoch):
    perm = np.random.permutation(N)
    loss = 0
    for j in range(0, N // batch):
        train_batch = X_train[perm[j * batch:(j + 1) * batch]]
        response_batch = labels_train[perm[j * batch:(j + 1) * batch]]

        # The computational graph is only generated for this block:
        with network.train():
            l = rm.softmax_cross_entropy(network(train_batch), response_batch)
            if hasattr(network, "weight_decay"):
                l += 0.0001 * network.weight_decay()

        # Back propagation
        grad = l.grad()

        # Update
        grad.update(optimizer)

        # Changing type to ndarray is recommended.
        loss += l.as_ndarray()

    train_loss = loss / (N // batch)

    # Validation
    test_loss = rm.softmax_cross_entropy(network(X_test), labels_test).as_ndarray()
    test_learning_curve.append(test_loss)
    learning_curve.append(train_loss)
    print("epoch %03d train_loss:%f test_loss:%f"%(i, train_loss, test_loss))
epoch 000 train_loss:0.337253 test_loss:0.191109
epoch 001 train_loss:0.189557 test_loss:0.148555
epoch 002 train_loss:0.149692 test_loss:0.122226
epoch 003 train_loss:0.128814 test_loss:0.103989
epoch 004 train_loss:0.115948 test_loss:0.096587
epoch 005 train_loss:0.107144 test_loss:0.089462
epoch 006 train_loss:0.099740 test_loss:0.086061
epoch 007 train_loss:0.094767 test_loss:0.082598
epoch 008 train_loss:0.090281 test_loss:0.080142
epoch 009 train_loss:0.086551 test_loss:0.081469

Model evaluation

After training our model, we have to evaluate it. For each class (digit), we’ll look at precision, recall, and F1 score, along with the support (the number of test samples in that class), to get a full sense of how the model performs on our test data.
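For reference, these per-class quantities follow the standard definitions in terms of true positives (TP), false positives (FP), and false negatives (FN); support is simply the number of test samples belonging to that class:

    precision = TP / (TP + FP)
    recall    = TP / (TP + FN)
    f1-score  = 2 * precision * recall / (precision + recall)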

In [6]:
predictions = np.argmax(network(X_test).as_ndarray(), axis=1)

# Confusion matrix and classification report.
print(confusion_matrix(y_test, predictions))
print(classification_report(y_test, predictions))

# Learning curve.
plt.plot(learning_curve, linewidth=3, label="train")
plt.plot(test_learning_curve, linewidth=3, label="test")
plt.title("Learning curve")
plt.ylabel("error")
plt.xlabel("epoch")
plt.legend()
plt.grid()
plt.show()
[[707   0   1   1   2   2   2   0   5   0]
 [  0 797   2   0   0   1   1   1   1   0]
 [  1   0 663   1   0   0   1   9   1   0]
 [  0   1   8 690   0   8   1   7   5   3]
 [  0   2   0   0 631   0   3   0   0  10]
 [  2   1   2   1   0 655   7   0   2   1]
 [  1   0   1   0   3   1 689   0   3   0]
 [  0   0   5   0   1   0   0 708   0   3]
 [  1   3   6   2   2   8   2   1 630   3]
 [  0   1   0   1  12   2   1  10   2 659]]
             precision    recall  f1-score   support

        0.0       0.99      0.98      0.99       720
        1.0       0.99      0.99      0.99       803
        2.0       0.96      0.98      0.97       676
        3.0       0.99      0.95      0.97       723
        4.0       0.97      0.98      0.97       646
        5.0       0.97      0.98      0.97       671
        6.0       0.97      0.99      0.98       698
        7.0       0.96      0.99      0.97       717
        8.0       0.97      0.96      0.96       658
        9.0       0.97      0.96      0.96       688

avg / total       0.98      0.98      0.98      7000

[Figure: Learning curve, training and test error per epoch]