MATLABでNNを作る話の備忘録 その1


無とは

無については、こちら
本編とは正直そんなに関係もない。
今回は無Keras第1回でやっている簡単な2値分類問題をMATLABで実装していこうと思っている。

概要

この度必要にかられてMATLABでニューラルネットワーク(NN)を作る事になったので、以前自分がKerasを勉強した時に見た「無から〜」シリーズを参考にしながら、備忘録としてまとめていきたい。

自分用の備忘録だから、冒頭みたいな遊びを入れる意味は特にないと言っていい。

なぜMATLABか

諸般の事情によるものです。MATLABの経験はありません。
以前はPythonでNNを書いたり上記のようにKerasを触ってみたりしていました。
あと僕は情報系じゃないしコーディング力も低いので、コードがクソとかそういうのは許してほしいと思っている。備忘録なので。

参考にしたもの

準備

必要なもの

MATLABと、Neural Network Toolboxが必要です。

今回の問題

今回やるのは、

適当な5次元の配列を入力したら、その和が2.5より大きいときに1と教えてくれるニューラルネットを作ろうという話。

という話。
ぶっちゃけ、カニ雌雄判別問題とほとんど同じである。
練習を兼ねて、データセットをランダムに生成するところからやってみた、って話。

データ生成

まずrand関数で5*250の行列を生成する。
次に各データの5個の要素の和が2.5より小さいか、2.5以上かでラベリングをする。2.5より小さいなら(1,0)、2.5以上なら(0,1)になる。

%%データ生成
%5*250のデータ行列を生成
data = rand(5,250);
%2*250のラベル配列を生成。
label = [sum(data)<2.5 ; sum(data)>=2.5];

特にラベル付の方はもっといいやり方ありそうだな…。

ニューラルネットワークを生成する

%%ニューラルネットを作る
%重み行列の初期値を生成するための乱数は固定になっている
setdemorandstream(491218382)
net = patternnet(20);
view(net)

patternnetというクラスを使う。「Pattern Recognition Neural Network」と説明には書いてある。生成されるのは、2層(隠れ層が1層)のNNで、出力層の活性化関数はsoftmax関数かな? 隠れ層の活性化関数が何なのかわからないけど、viewを使って表示させてみると非線形関数の形っぽいのが書いてある。
中身がなんにもわからないものを使うのはどうなんだって感じだけど、今回はとにかくNNを走らせることが目標なので。
patternnetの引数で与えたのは隠れ層のノード数。他にtrainFcn、つまりどういうふうに誤差を最小化していくかのアルゴリズムを選べるっぽい。

実際に積み上げて層を作るのはここに書いてある気がするので今度勉強します。

viewを使うとネットワークの構造が可視化される。各層の下の数字はノード数だが、見てわかるように入出力のノード数は0になっている。
これは訓練時にデータを与えてやると勝手に判別してくれる。すごーい!

訓練する

%%訓練する
%訓練データから勝手に入出力を読んでくれる
[net,tr] = train(net,data,label);
nntraintool
plotperform(tr)

trainを使うと、データを訓練用、交差検定用(?)、テスト用に分けてくれる。trには訓練過程が保存されるっぽい?

テストする

test_data = data(:,tr.testInd);
test_label= label(:,tr.testInd);
test_result = net(test_data);
testIndices = vec2ind(test_result)

plotconfusion(test_label,test_result)

[c,cm] = confusion(test_label,test_result)

fprintf('Percentage Correct Classification   : %f%%\n', 100*(1-c));
fprintf('Percentage Incorrect Classification : %f%%\n', 100*c);

netにデータを与えてやると、ネットワークを通した後の出力を返してくれる。vec2indは、one-hotなベクトルをラベルになおしてくれる。
confusion関数は、正解と出力を比較して、confusion matrixを出力してくれる。cは誤り率。

自分が実行したときの実行結果は、こんな感じ。

cm =

    17     0
     1    20

Percentage Correct Classification   : 97.368421%
Percentage Incorrect Classification : 2.631579%

正解率97%。めっちゃ正解するNN。

まとめ

なんとか2値分類をするNNをMATLABで動かすことができた。今回はあくまで動かすのが目的だったので、これでよしとする。
次は自分でNNを組みたい。フルスクラッチではないが、中身の分からないネットワークを使うのもいかがなものかって感じがするので。

疲れた。