Overwhelming Growth

A daily memorandum

The black magic called Spatial Transformer

While reading the curious document that is Dense Pose Transfer, I noticed it warps images with a Spatial Transformer. Not the Transformer that is all the rage these days: this one first appeared in the 2015 paper of the same name.

I gave it a quick read and thought: really??? Can it be that easy??? So this is the tale of me actually running it.

For the implementation I leaned heavily on this, plus a certain library's tutorial.

Required packages and a couple of helper functions.

import numpy as np

import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S
from nnabla.monitor import tile_images

from mnist_data import data_iterator_mnist
import matplotlib.pyplot as plt


def plot_stats(images):
    imshow_opt = dict(cmap='gray', interpolation='nearest')
    print("Num images:", images.shape[0])
    print("Image shape:", images.shape[1:])
    plt.imshow(tile_images(images), **imshow_opt)


def categorical_error(pred, label):
    """
    Compute categorical error given score vectors and labels as
    numpy.ndarray.
    """
    pred_label = pred.argmax(1)
    return (pred_label != label.flat).mean()

Now for the network itself. It is basically a shallow (LeNet-like) 10-class MNIST classifier with a Spatial Transformer inserted at the input.

def prediction_with_stn(image, test=False, with_stn=False):

    def stn(x):
        # Spatial Transformer localization network
        xs = PF.convolution(x, 8, kernel=(7, 7), name='px1')
        xs = F.relu(F.max_pooling(xs, kernel=(2, 2), stride=(2, 2)), inplace=True)
        xs = PF.convolution(xs, 10, kernel=(5, 5), name='px2')
        xs = F.relu(F.max_pooling(xs, kernel=(2, 2), stride=(2, 2)), inplace=True)
        # 28x28 -> conv7 -> 22x22 -> pool -> 11x11 -> conv5 -> 7x7 -> pool -> 3x3
        xs = F.reshape(xs, (-1, 10 * 3 * 3))

        # Regressor for the 3 * 2 affine matrix
        theta = F.relu(PF.affine(xs, 32, name="theta_param1"))
        theta = PF.affine(theta, 2*3,
                          w_init=np.zeros((32, 6)),
                          b_init=np.array([1, 0, 0, 0, 1, 0]),  # initial value REALLY matters.
                          name="theta_param2")
        theta = F.reshape(theta, (-1, 2, 3))
        grid = F.affine_grid(theta=theta, size=x.shape[2:])
        x.need_grad = True  # super important
        warped = F.warp_by_grid(x, grid, mode='linear')
        return warped

    image /= 255.0

    if with_stn:
        warped = stn(image)
    else:
        warped = image
    warped.persistent = True

    c1 = PF.convolution(warped, 10, (5, 5), name='conv1')
    c1 = F.relu(F.max_pooling(c1, (2, 2)), inplace=True)

    c2 = PF.convolution(c1, 20, (5, 5), name='conv2')
    if not test:
        c2 = F.dropout(c2)
    c2 = F.relu(F.max_pooling(c2, (2, 2)), inplace=True)

    c3 = F.relu(PF.affine(c2, 50, name='fc3'), inplace=True)

    if not test:
        c3 = F.dropout(c3)
    c4 = PF.affine(c3, 10, name='fc4')
    return c4, warped

The parameter names are arbitrary.
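To convince myself of what `F.affine_grid` and `F.warp_by_grid` are actually computing, here is a pure-NumPy sketch of the sampling math as I understand it from the STN paper. The function names and the exact coordinate/boundary conventions below are my own assumptions for illustration, not nnabla's:

```python
import numpy as np

def affine_grid(theta, size):
    """Source coordinates for each target pixel, given a 2x3 affine theta."""
    H, W = size
    # normalized target coordinates in [-1, 1]
    ys, xs = np.meshgrid(np.linspace(-1, 1, H), np.linspace(-1, 1, W), indexing='ij')
    coords = np.stack([xs, ys, np.ones_like(xs)], axis=-1)  # (H, W, 3) homogeneous
    return coords @ theta.T                                 # (H, W, 2) source (x, y)

def warp_bilinear(img, grid):
    """Bilinearly sample img at the normalized source coordinates in grid."""
    H, W = img.shape
    # map [-1, 1] back to pixel indices
    x = (grid[..., 0] + 1) * (W - 1) / 2
    y = (grid[..., 1] + 1) * (H - 1) / 2
    x0 = np.clip(np.floor(x).astype(int), 0, W - 1)
    y0 = np.clip(np.floor(y).astype(int), 0, H - 1)
    x1 = np.clip(x0 + 1, 0, W - 1)
    y1 = np.clip(y0 + 1, 0, H - 1)
    wx, wy = x - np.floor(x), y - np.floor(y)
    # interpolate the four neighboring pixels
    top = img[y0, x0] * (1 - wx) + img[y0, x1] * wx
    bot = img[y1, x0] * (1 - wx) + img[y1, x1] * wx
    return top * (1 - wy) + bot * wy

img = np.arange(16.0).reshape(4, 4)

# the identity theta leaves the image untouched
theta_id = np.array([[1., 0., 0.],
                     [0., 1., 0.]])
out_id = warp_bilinear(img, affine_grid(theta_id, (4, 4)))

# an all-zero theta collapses every source coordinate to (0, 0),
# so every target pixel samples the image center
theta_zero = np.zeros((2, 3))
out_zero = warp_bilinear(img, affine_grid(theta_zero, (4, 4)))
```

This is also why the `b_init` of `theta_param2` above is the flattened identity matrix: with zero weights and an identity bias the warp starts as a no-op, whereas an all-zero initialization would collapse the whole image to its center pixel and destroy the classifier's input from step one.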

The part that corresponds to a main function: graph definition plus everything else.

from numpy.random import seed
seed(0)

# Get context.
from nnabla.ext_utils import get_extension_context
ctx = get_extension_context("cudnn")
nn.set_default_context(ctx)

# Create CNN network for both training and testing.

batch_size = 16

# TRAIN
# Create input variables.
image = nn.Variable([batch_size, 1, 28, 28])
label = nn.Variable([batch_size, 1])
# Create prediction graph.
pred, warped = prediction_with_stn(image, test=False, with_stn=True)
pred.persistent = True
# Create loss function.
loss = F.mean(F.softmax_cross_entropy(pred, label))

# TEST
# Create input variables.
vimage = nn.Variable([batch_size, 1, 28, 28])
vlabel = nn.Variable([batch_size, 1])
# Create prediction graph.
vpred, vwarped = prediction_with_stn(vimage, test=True, with_stn=True)

# Create Solver. If training from checkpoint, load the info.
solver = S.Sgd(lr=0.01)
solver.set_parameters(nn.get_parameters())

# Initialize DataIterator for MNIST.
from numpy.random import RandomState
data = data_iterator_mnist(batch_size, True, rng=RandomState(223))
vdata = data_iterator_mnist(batch_size, False)

iter_per_epoch = int(data.size / batch_size)
viter_per_epoch = int(vdata.size / batch_size)

# Create monitor.
from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
monitor = Monitor("test")
monitor_loss = MonitorSeries("Training loss", monitor, interval=iter_per_epoch)
monitor_err = MonitorSeries("Training error", monitor, interval=iter_per_epoch)
monitor_time = MonitorTimeElapsed("Training time", monitor, interval=iter_per_epoch)
monitor_verr = MonitorSeries("Test error", monitor, interval=1)

# Training loop.
for e in range(5):
    print(f"epoch {e}:")
    for i in range(iter_per_epoch):
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward()
        loss.backward()
        solver.update()
        err = categorical_error(pred.d, label.d)
        monitor_loss.add(e * iter_per_epoch + i, loss.d.copy())
        monitor_err.add(e * iter_per_epoch + i, err)
        monitor_time.add(e * iter_per_epoch + i)

    ve = 0.0
    for j in range(viter_per_epoch):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    monitor_verr.add(e * iter_per_epoch + i, ve / viter_per_epoch)

I trained it for 5 epochs for a start. The log follows.

epoch 0:
2021-02-23 04:33:43,057 [nnabla][INFO]: iter=3749 {Training loss}=1.069490671157837
2021-02-23 04:33:43,059 [nnabla][INFO]: iter=3749 {Training error}=0.3807833333333333
2021-02-23 04:33:43,063 [nnabla][INFO]: iter=3749 {Training time}=8.390070915222168[sec/3750iter] 8.390070915222168[sec]
2021-02-23 04:33:43,781 [nnabla][INFO]: iter=3749 {Test error}=0.0883
epoch 1:
2021-02-23 04:33:51,559 [nnabla][INFO]: iter=7499 {Training loss}=0.5968227386474609
2021-02-23 04:33:51,562 [nnabla][INFO]: iter=7499 {Training error}=0.21185
2021-02-23 04:33:51,563 [nnabla][INFO]: iter=7499 {Training time}=8.500881433486938[sec/3750iter] 16.890952348709106[sec]
2021-02-23 04:33:52,354 [nnabla][INFO]: iter=7499 {Test error}=0.0568
epoch 2:
2021-02-23 04:34:00,514 [nnabla][INFO]: iter=11249 {Training loss}=0.4672476351261139
2021-02-23 04:34:00,516 [nnabla][INFO]: iter=11249 {Training error}=0.16818333333333332
2021-02-23 04:34:00,518 [nnabla][INFO]: iter=11249 {Training time}=8.954166650772095[sec/3750iter] 25.8451189994812[sec]
2021-02-23 04:34:01,298 [nnabla][INFO]: iter=11249 {Test error}=0.0457
epoch 3:
2021-02-23 04:34:09,600 [nnabla][INFO]: iter=14999 {Training loss}=0.4128764271736145
2021-02-23 04:34:09,603 [nnabla][INFO]: iter=14999 {Training error}=0.14976666666666666
2021-02-23 04:34:09,605 [nnabla][INFO]: iter=14999 {Training time}=9.08706521987915[sec/3750iter] 34.93218421936035[sec]
2021-02-23 04:34:10,351 [nnabla][INFO]: iter=14999 {Test error}=0.0389
epoch 4:
2021-02-23 04:34:18,186 [nnabla][INFO]: iter=18749 {Training loss}=0.39158934354782104
2021-02-23 04:34:18,189 [nnabla][INFO]: iter=18749 {Training error}=0.14516666666666667
2021-02-23 04:34:18,194 [nnabla][INFO]: iter=18749 {Training time}=8.588963031768799[sec/3750iter] 43.52114725112915[sec]
2021-02-23 04:34:19,017 [nnabla][INFO]: iter=18749 {Test error}=0.029

It is learning properly. Let's visualize the results.

f:id:destroy_linux:20210223135239p:plain
Sample input images (before warping by the Spatial Transformer)

f:id:destroy_linux:20210223135336p:plain
The inputs after warping by the Spatial Transformer

It really does warp them. Wow. With MNIST a simple deformation is probably enough, so I can sort of see why it works, but it still puzzles me that it also works in Dense Pose Transfer, which looked like it needed much more complex warps. I should read more generative work that warps images with Spatial Transformers.
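One way to see why MNIST is the easy case: a single 2x3 theta can only express a global affine warp (rotation, scale, shear, translation), which is plenty for straightening a digit but not for body-part-level deformations. A toy check of what one theta does to the sampling grid (my own example, not from the code above; if I remember right, the STN paper also proposes more flexible transforms such as thin plate splines for exactly this reason):

```python
import numpy as np

# a 2x3 theta encoding a 45-degree rotation about the image center
a = np.pi / 4
theta_rot = np.array([[np.cos(a), -np.sin(a), 0.0],
                      [np.sin(a),  np.cos(a), 0.0]])

# the four corners of the normalized [-1, 1] x [-1, 1] sampling grid
# (homogeneous coordinates: x, y, 1)
corners = np.array([[-1., -1., 1.],
                    [ 1., -1., 1.],
                    [ 1.,  1., 1.],
                    [-1.,  1., 1.]])

# where each corner samples from: the whole grid rotates rigidly,
# every pixel undergoes the same transform
src = corners @ theta_rot.T
```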

Also, I should hook this up to GitHub or something to make the code easier to read.