【matlab】SGD等バッチ処理実装において行列からランダムにn個ずつ要素を抽出する方法


経緯

最適化に関する研究に従事してる中で,バッチ処理が必要となるタイミングがありました.

あるN*Mの行列があり,その行列からランダムにK個の要素を取り出す.
それを全ての要素が取り出せるまでイテレーションを繰り返す.
したがってイテレーション回数はN*M/K回となる.

参考: (30) 配列要素をランダムにシャッフル
https://www.dogrow.net/octave/blog30/

順序

  1. 行列Aに対し,抽出する要素の数 Kとイテレーション回数を用意
  2. 行列サイズの分だけ行列のインデックスに相当する乱数配列を作成
  3. イテレーション毎に,乱数配列からK個分インデックスを取得し,行列Aの要素の抽出に利用

文章が究極に下手なので,次項目にて実例を見て頂きましょう.

実装

N=3, M=3の要素数9の行列Aを作成.
行列AからK=3個ずつ重複なくランダムに要素を繰り返し抽出する.

>> A = rand(3,3)

A =

    0.8147    0.9134    0.2785
    0.9058    0.6324    0.5469
    0.1270    0.0975    0.9575

>> K=3

イテレーション回数(N*M/K回)も初期化

>> iteration=3*3/K

iteration =

     3

matlabでは行列の要素を呼び出す際, A(3,2) と呼びだすこともできるし,単一インデックスで呼び出すことも可能.
今回,A(3,2)には「0.0975」が入っているが,単一インデックスで同様の要素を取り出すにはA(6)と呼び出す.

この性質を利用する.

予め行列サイズ(今回は3*3=9)に合わせてインデックスの乱数を作成し,そのインデックスに基づき任意の数の要素をランダムに取り出す.
乱数の作成にはrandperm関数を利用する.

参考: 配列インデックス付け (matlab公式)
https://jp.mathworks.com/help/matlab/math/array-indexing.html

>> batch = randperm(3*3)

batch =

     5     7     8     3     6     1     2     4     9

イテレーション毎に,抽出数 K 個分インデックスをbatch配列から取得.
取得したインデックスはindex配列に代入し,そのまま行列Aの要素呼び出しに利用.

>> for i = 1 : iteration
index = batch( K*(i-1)+1 : K*i )
A(index)
end

index =
     5     7     8

ans =
    0.6324    0.2785    0.5469

index =
     3     6     1

ans =
    0.1270    0.0975    0.8147

index =
     2     4     9

ans =
    0.9058    0.9134    0.9575

まとめ

バッチ処理のように,ランダムに要素を抽出する手法をまとめてみました.

matlabって公式がかなりhelp記事を書いてくれる一方で,かゆい処理について苦労する部分がありますよね...
何かお役に立てられればと記事を作成しました.

最後にすべてまとめたコードを記載します.

A = rand(3,3);
K=3;
iteration=3*3/K;

batch = randperm(3*3);

for i = 1 : iteration
  index = batch( K*(i-1)+1 : K*i )
  A(index)
end