chainerのconnectionをいじって新しい層を作る(1)


環境

GPU GTX1070
ubuntu 14.04
chainer 1.14.0
など

はじめに

chainerで最新のモデルを実装する際には、links/connectionやfunctions/connectionをいじる必要がある。

そこで最も単純なlinear.pyをいじって、新しい層を作ってみる。前回はlinear.pyの中身を確認した。
http://qiita.com/masataka46/items/d66997ac94ec7aa3bcb4

今回はchainer/functions/connection/linear.pyのforward関数をいじって順伝播を改良する。

改良モデルの概要

元々の全結合3層を以下の図のように改良する。

2層目だけを改良する。この2層目は具体的に以下のように入力側に関して重みを共有する。

この演算処理は以下の図のようになる。

Wは入力側n個で重みを共有するので、W(out_size, in_size / n)となる。この重みを1度の行列積で計算できるよう、in_size / n側をn倍し、in_sizeとする。

この重みと入力側からのデータxとの行列積を求めると、y(batch_size, out_size)が出力される。これにより重みのパラメーター数が減るので、演算は速くなるだろう。そして、性能が若干低下するだろう。また今回、計算を簡略化するためbiasは使わないでおく。

tain_mnist.pyを修正する

train_mnist.pyも若干変わってくるので修正する。

common_num = 10
out_units = 900
    #chnged model
    def __init__(self, n_in, n_units, n_out):
        super(MLP, self).__init__(
            l1=L.Linear(n_in, n_units),  # first layer
            l2=linear_link.Linear(n_units / common_num, out_units, nobias=True),  # second layer
            l3=L.Linear(out_units, n_out),  # output layer
        )

グローバル変数で定義したcommon_numが、共有する数。また特に意味は無いが3層目のunit数を900に変えた。

function下のlinear.pyを修正する

chainer/functions/connection/linear.pyのLinearFunctionクラス内forward関数を以下のように修正する。

import cupy
    #modified forward function
    def forward(self, inputs):
        x = _as_mat(inputs[0])
        W = inputs[1]

        #modify to original model
        W_tile = cupy.tile(W.T, (common_num, 1)).astype(W.dtype, copy=False)
        y = x.dot(W_tile).astype(x.dtype, copy=False)

        if len(inputs) == 3:
            b = inputs[2]
            y += b

        return y,

Wをcommon_num倍する際に、cupy(numpy)のtile()を使った。GPU使うのを想定してcupyをimportしているが、使わないならnumpyに変える必要がある。

またcheck_type_forward関数があるとエラーが出るので、コメントアウトする。

    '''
    def check_type_forward(self, in_types):
        n_in = in_types.size()
        type_check.expect(2 <= n_in, n_in <= 3)
        x_type, w_type = in_types[:2]

        type_check.expect(
            x_type.dtype.kind == 'f',
            w_type.dtype.kind == 'f',
            x_type.ndim >= 2,
            w_type.ndim == 2,
            type_check.prod(x_type.shape[1:]) == w_type.shape[1],
        )
        if n_in.eval() == 3:
            b_type = in_types[2]
            type_check.expect(
                b_type.dtype == x_type.dtype,
                b_type.ndim == 1,
                b_type.shape[0] == w_type.shape[0],
            )
    '''

この関数がお節介にもxとWの大きさが対応しているか調べてるみたい。今回、明らかにWだけ小さくしてるので、これが機能するとエラーとなる。