chainer 1.11.0以降のmnistを解説


はじめに

chainerが1.11.0になってから結構変わっていたので、自分なりの理解を書きます。
できるだけ、初めてpythonとchainerをやってみる人にも分かるようします(つもりです)。

コードはここ
サンプルの中のtrain_mnist.pyというファイルです。

MNIST

mnistとは28x28のサイズの数字が書かれた画像のデータセットです。
機械学習で入門用としてよく使われるものです。

ネットワーク

class MLP(chainer.Chain):
    def __init__(self, n_in, n_units, n_out):
        super(MLP, self).__init__(
            l1=L.Linear(n_in, n_units),
            l2=L.Linear(n_units, n_units), 
            l3=L.Linear(n_units, n_out), 
        )

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

ネットワークの定義では__init__の方で、使う層を定義する。
今回は、
* 入力がn_in、出力がn_unitsの全結合層 l1
* 入力がn_units、出力がn_unitsの全結合層 l2
* 入力がn_units、出力がn_outsの全結合層 l3

__call__では、具体的なネットワークを記述する。
今回は、l1とl2の出力にreluという活性化関数を使っている。

Parser

一応parserのことも

    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                help='Number of units')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

parserはpythonをコマンドで実行するときにパラメーターを設定しやすくしてくれる便利なやつです。
ターミナルで例えば以下のように実行すると

> $ python train_mnist.py -g 0 -u 100
GPU: 0
# unit: 100
# Minibatch-size: 100
# epoch: 20

と表示されます。
指定してないepochなどは、defaultで初期化されている値になります。
自分で追加したいときは

add_argument('後で呼ぶための名前', '-ターミナルでの指定方法', 数字ならtype=int, 指定のなかった場合のdefault値)

のような形で利用できます。

データの初期化

chainerではtrainデータと、testデータを用意します。

train, test = chainer.datasets.get_mnist()

これはmnistで使われるデータを取ってきてtrainとtestに入れてるだけです。
なかがどんなカタチになっているかというと、一つの行(train[0])に
[[.234809284, .324039284, .34809382 …. .04843098], 3]
というように、左に入力値と右にその答え(ラベル値)がセットで入っています。
また、chainerではtrainで学習して、testで試してみて正解率を見ていく感じになります。

イテレータ

従来では自分でfor分を用意して何回も回して学習させてとやっていたのですが、1.11.0からは上の train のようにデータを入れておいて、これを使いますと言ってあげればfor分を書く必要はありません。

train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)

もうこれでいいらしい。おまじない感ある

Trainer

trainerというものが追加されてこれがもうほぼほぼ勝手にいろいろやってくれるそう。
問題集と答えを家庭教師に渡して、子供をよろしくお願いします。的な
自分で勉強を教えてたのを、家庭教師に任せるイメージ(合っているかはわからない)

まず、trainerを設定する。

updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
trainer = training.Trainer(updater, (args.epoch, 'epoch'),

この train_iter (問題集)を使って、この optimizer (勉強方法)で最適化してもらって、
それを _epoch _ (何周)回してください。

以下については必ずしも必要なわけではないものもある。

trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))
    # これはいる。test_iterを使ってepochごとに評価してる(と思う)
trainer.extend(extensions.dump_graph('main/loss'))
    # ネットワークの形をグラフで表示できるようにdot形式で保存する。
trainer.extend(extensions.snapshot(), trigger=(args.epoch, 'epoch'))
    # epochごとのtrainerの情報を保存する。それを読み込んで、途中から再開などができる。これけすと結構早くなったりした?
trainer.extend(extensions.LogReport())
    # epochごとにlogをだす
trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy']))
    # logで出す情報を指定する。
trainer.extend(extensions.ProgressBar())
    # 今全体と、epochごとでどのぐらい進んでいるかを教えてくれる。

trainer.run()
    # trainerをいろいろ設定した後、これをやって実際に実行する。これは必須

main/lossは答えとの差の大きさ。
mian/accuracyは正解率。
validation/main/accuracyが何を指しているかは、よくわかりません。(誰かコメントしていただけると...)

なんでここ説明して、あそこ説明しないのとかになると思うけどそれはまだ良くわかって無いからだったり

実際にどう弄ったかなどは、まだ上げる予定です。