Alink漫談(七):トレーニングデータセットとテストデータセットをどのように区分するか
12541 ワード
Alink漫談(七):トレーニングデータセットとテストデータセットをどのように区分するか
目次
0 x 00サマリ
AlinkはアリババがリアルタイムコンピューティングエンジンFlinkに基づいて開発した次世代機械学習アルゴリズムプラットフォームであり、業界初のバッチアルゴリズム、フローアルゴリズムを同時にサポートする機械学習プラットフォームである.ここでは、Alinkがトレーニングデータセットとテストデータセットをどのように区分するかを示します.
0 x 01トレーニングデータセットとテストデータセット
にぶんほう
一般的に予測分析を行う場合、データは2つの大部分に分けられます.一部はトレーニングデータであり、モデルを構築するために使用され、一部はテストデータであり、モデルを検証するために使用される.
さんぶんほう
しかし、モデルの構築過程でもモデル/補助モデルの構築を検証する必要がある場合があります.この場合、訓練データは2つの部分に分けられます.1)訓練データ;2)検証データ(Validation Data).この場合、データは3つの部分に分けられます.
Training setは、ANNの重み値など、モデルを訓練したり、モデルパラメータを決定したりするために使用されます.
Validation setはモデル選択(model selection)を行うために使用され、すなわちANNの構造のようなモデルの最終的な最適化と決定を行う.
Test setは純粋に訓練されたモデルの普及能力をテストするためである.もちろんtest setはモデルの正確性を保証することはできません.彼は似たようなデータがこのモデルで似たような結果を出すと言っただけです.
実際の応用
実際の応用では、一般的にデータセットをtraining setとtest setの2つに分類するだけで、多くの文章はvalidation setには関連していません.私たちもここでは触れません.みんなでよく使うスカリーンのtrain_test_split関数は、マトリクスをランダムにトレーニングサブセットとテストサブセットに分割し、分割されたトレーニングセットテストセットサンプルとトレーニングセットテストセットラベルを返します.
0 x 02 Alinkサンプルコード
まず、サンプルコードを示し、次に深く分析します.
public class SplitExample {
public static void main(String[] args) throws Exception {
String url = "iris.csv";
String schema = "sepal_length double, sepal_width double, petal_length double, petal_width double, category string";
//
BatchOperator data = new CsvSourceBatchOp().setFilePath(url).setSchemaStr(schema);
SplitBatchOp spliter = new SplitBatchOp().setFraction(0.8);
spliter.linkFrom(data);
BatchOperator trainData = spliter;
BatchOperator testData = spliter.getSideOutput(0);
//
CsvSourceStreamOp dataS = new CsvSourceStreamOp().setFilePath(url).setSchemaStr(schema);
SplitStreamOp spliterS = new SplitStreamOp().setFraction(0.4);
spliterS.linkFrom(dataS);
StreamOperator train_data = spliterS;
StreamOperator test_data = spliterS.getSideOutput(0);
}
}
0 x 03バッチ
SplitBatchOpはバッチを分割する主なクラスであり,DAGを具体的に構築する作業はlinkFromで行われる.
全体的な考え方は簡単です.
numTarget = totCount * fraction
task_n_count * fraction
totSelect = task_1_count * fraction + task_2_count * fraction + ... task_n_count * fraction
numTarget - totSelect
をあるtaskに追加することをランダムに決定した.3.1得られた記録数
データを分割するには、まずデータセットのレコード数を知る必要があります.例えばこのDataSetの記録は1万個?それとも10万個ですか.データセットが大きい可能性があるため、このステップではパラレル処理、すなわちデータをパーティション化し、mapPartition操作によって各パーティション上の要素の数を得ることもできます.
DataSet> countsPerPartition = DataSetUtils.countElementsPerPartition(rows); // task
DataSet numPickedPerPartition = countsPerPartition
.mapPartition(new CountInPartition(fraction)) //
.setParallelism(1)
.name("decide_count_of_each_partition");
各パーティションはtaskに対応するので,taskごとのレコード数を取得したと考えてもよい.
具体的な仕事はDataSetUtilsです.countElementsPerPartitionで完成しました.戻りタイプは、例えば3番taskが30個のレコードを持っている.
public static DataSet> countElementsPerPartition(DataSet input) {
return input.mapPartition(new RichMapPartitionFunction>() {
@Override
public void mapPartition(Iterable values, Collector> out) throws Exception {
long counter = 0;
for (T value : values) {
counter++; // task
}
out.collect(new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(), counter));
}
});
}
総数を計算する作業は、実は次の段階の演算子で行われます.
3.2ランダム選択記録
次の仕事は主にCountInPartitionです.mapPartitionは、taskごとにどれだけのレコードを選択するかをランダムに決定する役割を果たします.
この時点で並行する必要はありませんので、
.setParallelism(1)
3.2.1総記録数を得る
各パーティションレコード数を得た後,taskごとのレコード数を巡り,合計レコード数totCount(上から下へ計算した総数)を蓄積した.
public void mapPartition(Iterable> values, Collector out) throws Exception {
long totCount = 0L;
List> buffer = new ArrayList<>();
for (Tuple2 value : values) { //
totCount += value.f1; //f1 Long
buffer.add(value);
}
...
// 。
}
3.2.2 taskごとに選択するレコード数を決定する
そしてCountInPartition.mapPartition関数ではtaskごとに選択されるレコード数がランダムに決定されます.mapPartitionのパラメータIterable>valuesは、前の段階の結果です.元祖です.
これらの元祖を結びつけてbufferというリストに記録します.
buffer = {ArrayList@8972} size = 4
0 = {Tuple2@8975} "(3,38)" // 3 task, partition 38 。
1 = {Tuple2@8976} "(2,0)"
2 = {Tuple2@8977} "(0,38)"
3 = {Tuple2@8978} "(1,74)"
システムのtask数はbufferサイズです.
int npart = buffer.size(); // num tasks
そして、「記録総数」から「ランダムトレーニングデータの個数numTarget」を算出する.例えば総数は1万で、ランダムに20%割り当てるべきで、numTargetは2千であるべきです.この数字は後で使います.
long numTarget = Math.round((totCount * fraction));
各taskの記録数、例えば上のbufferの38、0、38、または74がeachCountに記録される.
for (Tuple2 value : buffer) {
eachCount[value.f0] = value.f1;
}
各taskでランダムに選択された訓練記録数をeachSelectに記録した.各taskが現在「記録数字*fraction」です.例えば3番taskの記録数は38個で、20%を選ぶべきで、38*20%=8個です.
そしてこれらのtask自身の「ランダムトレーニング記録数」をさらに加算してtotSelect(下から計算した総数)を得る.
long totSelect = 0L;
for (int i = 0; i < npart; i++) {
eachSelect[i] = Math.round(Math.floor(eachCount[i] * fraction));
totSelect += eachSelect[i];
}
このときtotSelectと以前計算したnumTargetには具体的な微細な違いがあり、理論的な数字ですが、上から下へ計算するのと下から計算するのでは、結果が異なる可能性があります.下を見るとわかります.
numTarget = all count * fraction
totSelect = task_1_count * fraction + task_2_count * fraction + ...
そこで、次のステップでこの微細な出入りを処理すると、remainが得られます.これは、「全体的に算出されたランダム数」numTargetと「すべてのtaskから選択されたランダムな訓練記録数の蓄積」totSelectの差です.
if (totSelect < numTarget) {
long remain = numTarget - totSelect;
remain = Math.min(remain, totCount - totSelect);
ちょうど個数が等しい場合は、正常に割り当てられます.
if (remain == totCount - totSelect) {
数が異なる場合は、eachSelect配列の任意のレコードに「より多くのremain」を追加することをランダムに決定します.
for (int i = 0; i < Math.min(remain, npart); i++) {
int taskId = shuffle.get(i);
while (eachSelect[taskId] >= eachCount[taskId]) {
taskId = (taskId + 1) % npart;
}
eachSelect[taskId]++;
}
最後にすべての情報を与えます
long[] statistics = new long[npart * 2];
for (int i = 0; i < npart; i++) {
statistics[i] = eachCount[i];
statistics[i + npart] = eachSelect[i];
}
out.collect(statistics);
// 4 , eachCount, eachSelect
statistics = {long[8]@9003}
0 = 38 //eachCount
1 = 38
2 = 36
3 = 38
4 = 31 //eachSelect
5 = 31
6 = 28
7 = 30
これらの情報はブロードキャスト変数として格納されており,すぐに以下に用いられる.
.withBroadcastSet(numPickedPerPartition, "counts")
3.2.3 taskごとにレコードを選択する
CountInPartition.PickInPartition関数では、taskごとにランダムにレコードが選択されます.
まずtask数と以前に格納されたブロードキャスト変数(すなわち、以前に格納されたばかり)が得られる.
int npart = getRuntimeContext().getNumberOfParallelSubtasks();
List bc = getRuntimeContext().getBroadcastVariable("counts");
countとselectを分離します.
long[] eachCount = Arrays.copyOfRange(bc.get(0), 0, npart);
long[] eachSelect = Arrays.copyOfRange(bc.get(0), npart, npart * 2);
合計task数を得る
int taskId = getRuntimeContext().getIndexOfThisSubtask();
自分のtask対応のcountを得る
long count = eachCount[taskId];
long select = eachSelect[taskId];
本task対応のレコードを追加し、ランダムにシャッフルして順番を狂わせる
for (int i = 0; i < count; i++) {
shuffle.add(i); // count
}
Collections.shuffle(shuffle, new Random(taskId)); //
// suffle
shuffle = {ArrayList@8987} size = 38
0 = {Integer@8994} 17
1 = {Integer@8995} 8
2 = {Integer@8996} 33
3 = {Integer@8997} 34
4 = {Integer@8998} 20
5 = {Integer@8999} 0
6 = {Integer@9000} 26
7 = {Integer@9001} 27
8 = {Integer@9002} 23
9 = {Integer@9003} 28
10 = {Integer@9004} 9
11 = {Integer@9005} 16
12 = {Integer@9006} 13
13 = {Integer@9007} 2
14 = {Integer@9008} 5
15 = {Integer@9009} 31
16 = {Integer@9010} 15
17 = {Integer@9011} 22
18 = {Integer@9012} 18
19 = {Integer@9013} 35
20 = {Integer@9014} 36
21 = {Integer@9015} 12
22 = {Integer@9016} 7
23 = {Integer@9017} 21
24 = {Integer@9018} 14
25 = {Integer@9019} 1
26 = {Integer@9020} 10
27 = {Integer@9021} 30
28 = {Integer@9022} 29
29 = {Integer@9023} 19
30 = {Integer@9024} 25
31 = {Integer@9025} 32
32 = {Integer@9026} 37
33 = {Integer@9027} 4
34 = {Integer@9028} 11
35 = {Integer@9029} 6
36 = {Integer@9030} 3
37 = {Integer@9031} 24
ランダムに選択し、選択したものを並べ替えます
for (int i = 0; i < select; i++) {
selected[i] = shuffle.get(i); // select , suffle
}
Arrays.sort(selected); //
// selected , 30
selected = {int[30]@8991}
0 = 0
1 = 1
2 = 2
3 = 5
4 = 7
5 = 8
6 = 9
7 = 10
8 = 12
9 = 13
10 = 14
11 = 15
12 = 16
13 = 17
14 = 18
15 = 19
16 = 20
17 = 21
18 = 22
19 = 23
20 = 26
21 = 27
22 = 28
23 = 29
24 = 30
25 = 31
26 = 33
27 = 34
28 = 35
29 = 36
選択したデータの送信
if (numEmits < selected.length && iRow == selected[numEmits]) {
out.collect(row);
numEmits++;
}
3.3トレーニングデータセットとテストデータセットの設定
outputはトレーニングデータセット、SideOutputはテストデータセットです.この2つのデータセットはいずれもAlink内部でTableタイプであるため、SQL演算子
minusAll
を直接使用して分割を完了する.this.setOutput(out, in.getSchema());
this.setSideOutputTables(new Table[]{in.getOutputTable().minusAll(this.getOutputTable())});
0 x 04フロー処理
訓練はSplitStreamOpクラスで行われ,linkFromによってモデルの構築が完了した.
ストリーム処理はSplitStreamとSelectTransformationの2つのクラスに依存して分割ストリームを完了する.具体的には物理的な操作は確立されていないが,上流演算子が下流演算子とどのように関連し,どのように記録を選択するかに影響を及ぼしている.
SplitStream splited = in.getDataStream().split(new RandomSelectorOp(getFraction()));
まず,出力時にどのストリームを選択するかをRandomSelectorOpでランダムに決定する.ここには「a」と「b」の2つの名前が勝手につけられていることがわかります.
class RandomSelectorOp implements OutputSelector {
private double fraction;
private Random random = null;
@Override
public Iterable select(Row value) {
if (null == random) {
random = new Random(System.currentTimeMillis());
}
List output = new ArrayList (1);
output.add((random.nextDouble() < fraction ? "a" : "b")); // ,
return output;
}
}
次に,その2つのランダムに生成されたストリームを得た.
DataStream partA = splited.select("a");
DataStream partB = splited.select("b");
最後にこの2つのストリームをそれぞれoutputとsideOutputに設定します.
this.setOutput(partA, in.getSchema()); //
this.setSideOutputTables(new Table[]{
DataStreamConversionUtil.toTable(getMLEnvironmentId(), partB, in.getSchema())}); //
最後に、SplitStreamOpには2つのメンバー変数があります.
this.outputは訓練集です.
this.SideOutPutは検証セットです.
return this;
0 x 05リファレンス
トレーニングデータ、検証データ、テストデータ分析