MATLABのNeural Network Toolboxで自作のレイヤを定義するための知見まとめ


この記事について

MATLABのNeural Network Toolboxを利用してDeep Learningを行う際、自作のレイヤを使いたいときに必要な知見をいくつかまとめます。
まだ調べている途中でもあるので、今後まだ増えていくと思います。
主に自分用メモなので、自分の使いたい用途のみに絞った知見がまとめられてます。そこはご了承ください。

自作回帰出力層の定義

基本的には以下のテンプレートに沿って作ります。下記のドキュメンテーションからの転載です。
https://jp.mathworks.com/help/nnet/ug/define-regression-output-layer.html

classdef myRegressionLayer < nnet.layer.RegressionLayer

    properties
        % (Optional) Layer properties

        % Layer properties go here
    end

    methods
        function layer = myRegressionLayer()           
            % (Optional) Create a myRegressionLayer

            % Layer constructor function goes here
        end

        function loss = forwardLoss(layer, Y, T)
            % Return the loss between the predictions Y and the 
            % training targets T
            %
            % Inputs:
            %         layer - Output layer
            %         Y     – Predictions made by network
            %         T     – Training targets
            %
            % Output:
            %         loss  - Loss between Y and T

            % Layer forward loss function goes here
        end

        function dLdX = backwardLoss(layer, Y, T)
            % Backward propagate the derivative of the loss function
            %
            % Inputs:
            %         layer - Output layer
            %         Y     – Predictions made by network
            %         T     – Training targets
            %
            % Output:
            %         dLdX  - Derivative of the loss with respect to the input X        

            % Layer backward loss function goes here
        end
    end
end

このテンプレートに沿って自分のやりたいことを書いていけばいいのですが、forwardLossやbackwardLossの引数として与えられるYとTのサイズがわからないと自作のレイヤの定義するのが難しいので、それを調べていきます。

fullyConnectedLayerの出力サイズ

最終的な予測値の出力層として全結合層を考える場合、fullyConnectedLayerの出力のサイズがわからないと自作回帰出力層を作れないので、それを調べました。
まず、サイズを調べる出力層を定義します。

testLayer.m
classdef testLayer < nnet.layer.RegressionLayer 
    methods
        function layer = testLayer()           
        end

        function loss = forwardLoss(layer, Y, T)
            size(Y)
            size(T)
            loss = gpuArray(single(0));
        end

        function dLdX = backwardLoss(layer, Y, T)
            dLdX = gpuArray(single(zeros(size(Y))));
        end
    end
end

続いて、次のように実行スクリプトを定義します。

run_test.m
%% 初期化
clear
close

%% 学習データの定義
x_in = rand(10, 1, 1, 6); % イメージ入力を想定
y_tr = rand(6, 5);

%% 層構造の定義
layers = [
    imageInputLayer([10 1 1], 'Name', 'Input')
    fullyConnectedLayer(5, 'Name', 'Output')
    testLayer
    ];
layers(end).Name = 'Test';

%% オプションの定義
options = trainingOptions(...
            'sgdm',...
            'InitialLearnRate', 0.001, ...
            'MiniBatchSize', 3, ...
            'MaxEpochs', 1);

%% 実行
net = trainNetwork(x_in, y_tr, layers, options);

run_test.mを実行すると、

ans =

     1     1     5    3


ans =

     1     1     5    3

が出力されます。この結果から、fullyConnectedLayerの出力サイズは、
1 x 1 x OutputSize x MinibatchSize
となると考えられます。

depthConcatenationLayerの動作

DAGネットワークを想定して、depthConcatenationLayerを利用して2つの層を連結したときのサイズについて調べます。
先ほどのrun_test.mを次のように変更します。

run_test.m
%% 初期化
clear
close

%% 学習データの定義
x_in = rand(10, 1, 1, 6); % イメージ入力を想定
y_tr = rand(6, 7);

%% 層構造の定義
layers = [
    imageInputLayer([10 1 1], 'Normalization', 'none', 'Name', 'Input')
    fullyConnectedLayer(2, 'Name', 'Layer1')
    reluLayer('Name', 'ReLU1')
    fullyConnectedLayer(5, 'Name', 'Output')
    depthConcatenationLayer(2, 'Name', 'Concate')
    testLayer
    ];
layers(end).Name = 'Test';
% DAGネットワークの定義
lGraph = layerGraph(layers);
lGraph = connectLayers(lGraph, 'Layer1', 'Concate/in2');
plot(lGraph)

%% オプションの定義
options = trainingOptions(...
            'sgdm',...
            'InitialLearnRate', 0.001, ...
            'MiniBatchSize', 3, ...
            'MaxEpochs', 1);

%% 実行
net = trainNetwork(x_in, y_tr, lGraph, options);

これを実行すると、

ans =

     1     1     7     3


ans =

     1     1     7     3

が出力されます。この結果から、depthConcatenationLayerの出力サイズは
1 x 1 x (OutputSize1 + OutputSize2) x MinibatchSize
となると考えられます。