Spatial Transformer: Black Magic
While reading the curious paper Dense Pose Transfer, I noticed it warps images with a Spatial Transformer. Not the Transformer that's all the rage these days; this one first appeared in a 2015 paper of the same name.
I skimmed it and my reaction was: really??? Can something that convenient actually work??? So I decided to run it myself. Here's how that went.
For the implementation I leaned heavily on the code here, plus a certain library's tutorial.
First, the required packages and some helper definitions.
```python
import numpy as np
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S
from nnabla.monitor import tile_images
from mnist_data import data_iterator_mnist
import matplotlib.pyplot as plt


def plot_stats(images):
    imshow_opt = dict(cmap='gray', interpolation='nearest')
    print("Num images:", images.shape[0])
    print("Image shape:", images.shape[1:])
    plt.imshow(tile_images(images), **imshow_opt)


def categorical_error(pred, label):
    """
    Compute categorical error given score vectors and labels as
    numpy.ndarray.
    """
    pred_label = pred.argmax(1)
    return (pred_label != label.flat).mean()
```
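As a quick sanity check on the helper, here is what `categorical_error` returns for a toy batch (plain NumPy; the helper is repeated so the snippet stands alone):

```python
import numpy as np

def categorical_error(pred, label):
    # Same as the helper above: fraction of rows whose argmax
    # does not match the label.
    pred_label = pred.argmax(1)
    return (pred_label != label.flat).mean()

scores = np.array([[0.1, 0.9],
                   [0.8, 0.2],
                   [0.3, 0.7]])
labels = np.array([[1], [0], [0]])
print(categorical_error(scores, labels))  # 1 of 3 wrong -> ~0.333
```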
Now the network itself. The base is a shallow (LeNet-like) network for 10-class MNIST classification; all I did was insert a Spatial Transformer at the input.
```python
def prediction_with_stn(image, test=False, with_stn=False):
    def stn(x):
        #x.need_grad = True
        # Spatial transformer localization-network
        xs = PF.convolution(x, 8, kernel=(7, 7), name='px1')
        xs = F.relu(F.max_pooling(xs, kernel=(2, 2), stride=(2, 2)), inplace=True)
        xs = PF.convolution(xs, 10, kernel=(5, 5), name='px2')
        xs = F.relu(F.max_pooling(xs, kernel=(2, 2), stride=(2, 2)), inplace=True)
        xs = F.reshape(xs, (-1, 10 * 3 * 3))
        # Regressor for the 3 * 2 affine matrix
        theta = F.relu(PF.affine(xs, 32, name="theta_param1"))
        theta = PF.affine(theta, 2 * 3,
                          w_init=np.zeros((32, 6)),
                          b_init=np.array([1, 0, 0, 0, 1, 0]),  # initial value REALLY matters.
                          name="theta_param2")
        theta = F.reshape(theta, (-1, 2, 3))
        grid = F.affine_grid(theta=theta, size=x.shape[2:])
        x.need_grad = True  # super important
        warped = F.warp_by_grid(x, grid, mode='linear')
        return warped

    image /= 255.0
    if with_stn:
        warped = stn(image)
    else:
        warped = image
    warped.persistent = True
    c1 = PF.convolution(warped, 10, (5, 5), name='conv1')
    c1 = F.relu(F.max_pooling(c1, (2, 2)), inplace=True)
    c2 = PF.convolution(c1, 20, (5, 5), name='conv2')
    if not test:
        # Apply dropout only during training.
        c2 = F.dropout(c2)
    c2 = F.relu(F.max_pooling(c2, (2, 2)), inplace=True)
    c3 = F.relu(PF.affine(c2, 50, name='fc3'), inplace=True)
    if not test:
        c3 = F.dropout(c3)
    c4 = PF.affine(c3, 10, name='fc4')
    return c4, warped
```
The parameter names are arbitrary.
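To see why the bias initialization `[1, 0, 0, 0, 1, 0]` matters so much: it is the flattened 2x3 identity matrix, so the warp starts out as a no-op and the network can learn deviations from there. Below is a rough NumPy-only sketch of what `F.affine_grid` + `F.warp_by_grid` compute. The function names are my own, and I use nearest-neighbor sampling for brevity instead of the bilinear interpolation the real `mode='linear'` op performs:

```python
import numpy as np

def affine_grid(theta, H, W):
    # Target coordinates normalized to [-1, 1], pushed through the
    # 2x3 affine matrix theta to get source coordinates.
    ys, xs = np.meshgrid(np.linspace(-1, 1, H),
                         np.linspace(-1, 1, W), indexing='ij')
    coords = np.stack([xs.ravel(), ys.ravel(), np.ones(H * W)])  # (3, H*W)
    src = theta @ coords                                         # (2, H*W)
    return src.reshape(2, H, W)

def warp_nearest(img, grid):
    # Sample the source image at the grid locations
    # (nearest-neighbor; the real op interpolates bilinearly).
    H, W = img.shape
    xs = np.clip(np.round((grid[0] + 1) * (W - 1) / 2), 0, W - 1).astype(int)
    ys = np.clip(np.round((grid[1] + 1) * (H - 1) / 2), 0, H - 1).astype(int)
    return img[ys, xs]

img = np.arange(16, dtype=float).reshape(4, 4)
identity = np.array([[1., 0., 0.],
                     [0., 1., 0.]])  # flattened: [1, 0, 0, 0, 1, 0]
# The identity theta reproduces the input exactly.
assert np.allclose(warp_nearest(img, affine_grid(identity, 4, 4)), img)
```

So at initialization the classifier sees the untouched image, and gradients through the warp gradually bend theta away from the identity.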
The part that corresponds to a main function: graph definition plus assorted setup.
```python
from numpy.random import seed
seed(0)

# Get context.
from nnabla.ext_utils import get_extension_context
ctx = get_extension_context("cudnn")
nn.set_default_context(ctx)

# Create CNN network for both training and testing.
batch_size = 16

# TRAIN
# Create input variables.
image = nn.Variable([batch_size, 1, 28, 28])
label = nn.Variable([batch_size, 1])
# Create prediction graph.
pred, warped = prediction_with_stn(image, test=False, with_stn=True)
pred.persistent = True
# Create loss function.
loss = F.mean(F.softmax_cross_entropy(pred, label))

# TEST
# Create input variables.
vimage = nn.Variable([batch_size, 1, 28, 28])
vlabel = nn.Variable([batch_size, 1])
# Create prediction graph.
vpred, vwarped = prediction_with_stn(vimage, test=True, with_stn=True)

# Create Solver. If training from checkpoint, load the info.
solver = S.Sgd(lr=0.01)
solver.set_parameters(nn.get_parameters())

# Initialize DataIterator for MNIST.
from numpy.random import RandomState
data = data_iterator_mnist(batch_size, True, rng=RandomState(223))
vdata = data_iterator_mnist(batch_size, False)
iter_per_epoch = int(data.size / batch_size)
viter_per_epoch = int(vdata.size / batch_size)

# Create monitor.
from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
monitor = Monitor("test")
monitor_loss = MonitorSeries("Training loss", monitor, interval=iter_per_epoch)
monitor_err = MonitorSeries("Training error", monitor, interval=iter_per_epoch)
monitor_time = MonitorTimeElapsed("Training time", monitor, interval=iter_per_epoch)
monitor_verr = MonitorSeries("Test error", monitor, interval=1)
```
```python
# Training loop.
for e in range(5):
    print(f"epoch {e}:")
    for i in range(iter_per_epoch):
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward()
        loss.backward()
        solver.update()
        err = categorical_error(pred.d, label.d)
        monitor_loss.add(e * iter_per_epoch + i, loss.d.copy())
        monitor_err.add(e * iter_per_epoch + i, err)
        monitor_time.add(e * iter_per_epoch + i)
    # Validation after each epoch.
    ve = 0.0
    for j in range(viter_per_epoch):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    monitor_verr.add(e * iter_per_epoch + i, ve / viter_per_epoch)
```
I trained it for 5 epochs to start. The log follows.
```
epoch 0:
2021-02-23 04:33:43,057 [nnabla][INFO]: iter=3749 {Training loss}=1.069490671157837
2021-02-23 04:33:43,059 [nnabla][INFO]: iter=3749 {Training error}=0.3807833333333333
2021-02-23 04:33:43,063 [nnabla][INFO]: iter=3749 {Training time}=8.390070915222168[sec/3750iter] 8.390070915222168[sec]
2021-02-23 04:33:43,781 [nnabla][INFO]: iter=3749 {Test error}=0.0883
epoch 1:
2021-02-23 04:33:51,559 [nnabla][INFO]: iter=7499 {Training loss}=0.5968227386474609
2021-02-23 04:33:51,562 [nnabla][INFO]: iter=7499 {Training error}=0.21185
2021-02-23 04:33:51,563 [nnabla][INFO]: iter=7499 {Training time}=8.500881433486938[sec/3750iter] 16.890952348709106[sec]
2021-02-23 04:33:52,354 [nnabla][INFO]: iter=7499 {Test error}=0.0568
epoch 2:
2021-02-23 04:34:00,514 [nnabla][INFO]: iter=11249 {Training loss}=0.4672476351261139
2021-02-23 04:34:00,516 [nnabla][INFO]: iter=11249 {Training error}=0.16818333333333332
2021-02-23 04:34:00,518 [nnabla][INFO]: iter=11249 {Training time}=8.954166650772095[sec/3750iter] 25.8451189994812[sec]
2021-02-23 04:34:01,298 [nnabla][INFO]: iter=11249 {Test error}=0.0457
epoch 3:
2021-02-23 04:34:09,600 [nnabla][INFO]: iter=14999 {Training loss}=0.4128764271736145
2021-02-23 04:34:09,603 [nnabla][INFO]: iter=14999 {Training error}=0.14976666666666666
2021-02-23 04:34:09,605 [nnabla][INFO]: iter=14999 {Training time}=9.08706521987915[sec/3750iter] 34.93218421936035[sec]
2021-02-23 04:34:10,351 [nnabla][INFO]: iter=14999 {Test error}=0.0389
epoch 4:
2021-02-23 04:34:18,186 [nnabla][INFO]: iter=18749 {Training loss}=0.39158934354782104
2021-02-23 04:34:18,189 [nnabla][INFO]: iter=18749 {Training error}=0.14516666666666667
2021-02-23 04:34:18,194 [nnabla][INFO]: iter=18749 {Training time}=8.588963031768799[sec/3750iter] 43.52114725112915[sec]
2021-02-23 04:34:19,017 [nnabla][INFO]: iter=18749 {Test error}=0.029
```
It is learning properly. Let's visualize the results.
It really does deform the input. Wow. For MNIST a simple deformation is enough, so I can see why it works here, but it is strange that it also works for Dense Pose Transfer, where the deformations looked considerably more complex. I should read more generative work that warps with Spatial Transformers.
Also, I'd like to make the code easier to read, maybe by linking this up with GitHub.