Alink漫談(七):トレーニングデータセットとテストデータセットをどのように区分するか

12541 ワード

Alink漫談(七):トレーニングデータセットとテストデータセットをどのように区分するか


目次
  • Alink漫談(七):トレーニングデータセットとテストデータセットをどのように区分するか
  • 0 0 x 00要約
  • 0 x 01トレーニングデータセットとテストデータセット
  • 0 x 02 Alinkサンプルコード
  • 0 x 03バッチ
  • 3.1得られた記録数
  • 3.2ランダム選択レコード
  • 3.2.1総記録数
  • を得る.
  • 3.2.2 task毎に選択する記録数
  • を決定する.
  • 3.2.3 task毎選択レコード
  • 3.3トレーニングデータセットとテストデータセット
  • を設定する
  • 0 x 04ストリーム処理
  • 0 x 05参照

  • 0 x 00サマリ


    AlinkはアリババがリアルタイムコンピューティングエンジンFlinkに基づいて開発した次世代機械学習アルゴリズムプラットフォームであり、業界初のバッチアルゴリズム、フローアルゴリズムを同時にサポートする機械学習プラットフォームである.ここでは、Alinkがトレーニングデータセットとテストデータセットをどのように区分するかを示します.

    0 x 01トレーニングデータセットとテストデータセット


    にぶんほう
    一般的に予測分析を行う場合、データは2つの大部分に分けられます.一部はトレーニングデータであり、モデルを構築するために使用され、一部はテストデータであり、モデルを検証するために使用される.
    さんぶんほう
    しかし、モデルの構築過程でもモデル/補助モデルの構築を検証する必要がある場合があります.この場合、訓練データは2つの部分に分けられます.1)訓練データ;2)検証データ(Validation Data).この場合、データは3つの部分に分けられます.
  • トレーニングデータ(Train Data):モデル構築用.
  • 検証データ(Validation Data):オプションで、モデル構築を支援し、繰り返し使用できます.
  • テストデータ(Test Data):モデル構築を検出するために使用され、このデータはモデル検証時にのみ使用され、モデルの精度を評価するために使用されます.モデル構築プロセスでは絶対に使用できません.そうしないと、遷移フィッティングが発生します.

  • Training setは、ANNの重み値など、モデルを訓練したり、モデルパラメータを決定したりするために使用されます.
    Validation setはモデル選択(model selection)を行うために使用され、すなわちANNの構造のようなモデルの最終的な最適化と決定を行う.
    Test setは純粋に訓練されたモデルの普及能力をテストするためである.もちろんtest setはモデルの正確性を保証することはできません.彼は似たようなデータがこのモデルで似たような結果を出すと言っただけです.
    実際の応用
    実際の応用では、一般的にデータセットをtraining setとtest setの2つに分類するだけで、多くの文章はvalidation setには関連していません.私たちもここでは触れません.みんなでよく使うスカリーンのtrain_test_split関数は、マトリクスをランダムにトレーニングサブセットとテストサブセットに分割し、分割されたトレーニングセットテストセットサンプルとトレーニングセットテストセットラベルを返します.

    0 x 02 Alinkサンプルコード


    まず、サンプルコードを示し、次に深く分析します.
    public class SplitExample {
      public static void main(String[] args) throws Exception {
        String url = "iris.csv";
        String schema = "sepal_length double, sepal_width double, petal_length double, petal_width double, category string";
    
        // 
        BatchOperator data = new CsvSourceBatchOp().setFilePath(url).setSchemaStr(schema);
        SplitBatchOp spliter = new SplitBatchOp().setFraction(0.8);
        spliter.linkFrom(data);
        BatchOperator trainData = spliter;
        BatchOperator testData = spliter.getSideOutput(0);
    
        //  
        CsvSourceStreamOp dataS = new CsvSourceStreamOp().setFilePath(url).setSchemaStr(schema);
        SplitStreamOp spliterS = new SplitStreamOp().setFraction(0.4);
        spliterS.linkFrom(dataS);
        StreamOperator train_data = spliterS;
        StreamOperator test_data = spliterS.getSideOutput(0);
      }
    }
    

    0 x 03バッチ


    SplitBatchOpはバッチを分割する主なクラスであり,DAGを具体的に構築する作業はlinkFromで行われる.
    全体的な考え方は簡単です.
  • サンプリングスケールfraction
  • があると仮定する
  • データセットをパーティション化し、各パーティションのレコード数
  • を並列に計算する.
  • 各パーティションの記録数を累積し、全記録総数totCount
  • を得る.
  • 上から下へのサンプリング総数:numTarget = totCount * fraction
  • 具体的な選択要素は各パーティションで行うので、各パーティションにおいて、このパーティションがサンプリングすべき記録数、例えばn番目のパーティションでサンプリングすべき記録数:task_n_count * fraction
  • をそれぞれ算出する.
  • これらのパーティションの「サンプリングすべき記録数」を累積し、下から計算するサンプリング総数:totSelect = task_1_count * fraction + task_2_count * fraction + ... task_n_count * fraction
  • を得る.
  • numTargetとtotSelectは等しくない可能性があるので、余分なnumTarget - totSelectをあるtaskに追加することをランダムに決定した.
  • は、各task上でサンプリングして特定の記録を得る.

  • 3.1得られた記録数


    データを分割するには、まずデータセットのレコード数を知る必要があります.例えばこのDataSetの記録は1万個?それとも10万個ですか.データセットが大きい可能性があるため、このステップではパラレル処理、すなわちデータをパーティション化し、mapPartition操作によって各パーティション上の要素の数を得ることもできます.
    DataSet> countsPerPartition = DataSetUtils.countElementsPerPartition(rows); // task 
    
    DataSet numPickedPerPartition = countsPerPartition
        .mapPartition(new CountInPartition(fraction)) // 
        .setParallelism(1)
        .name("decide_count_of_each_partition");
    

    各パーティションはtaskに対応するので,taskごとのレコード数を取得したと考えてもよい.
    具体的な仕事はDataSetUtilsです.countElementsPerPartitionで完成しました.戻りタイプは、例えば3番taskが30個のレコードを持っている.
    public static  DataSet> countElementsPerPartition(DataSet input) {
       return input.mapPartition(new RichMapPartitionFunction>() {
          @Override
          public void mapPartition(Iterable values, Collector> out) throws Exception {
             long counter = 0;
             for (T value : values) {
                counter++; // task 
             }
             out.collect(new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(), counter));
          }
       });
    }
    

    総数を計算する作業は、実は次の段階の演算子で行われます.

    3.2ランダム選択記録


    次の仕事は主にCountInPartitionです.mapPartitionは、taskごとにどれだけのレコードを選択するかをランダムに決定する役割を果たします.
    この時点で並行する必要はありませんので、.setParallelism(1)

    3.2.1総記録数を得る


    各パーティションレコード数を得た後,taskごとのレコード数を巡り,合計レコード数totCount(上から下へ計算した総数)を蓄積した.
    public void mapPartition(Iterable> values, Collector out) throws Exception {
    	long totCount = 0L;
    	List> buffer = new ArrayList<>();
    	for (Tuple2 value : values) { // 
        totCount += value.f1; //f1 Long 
        buffer.add(value);
    	}
      ...
      // 。  
    }
    

    3.2.2 taskごとに選択するレコード数を決定する


    そしてCountInPartition.mapPartition関数ではtaskごとに選択されるレコード数がランダムに決定されます.mapPartitionのパラメータIterable>valuesは、前の段階の結果です.元祖です.
    これらの元祖を結びつけてbufferというリストに記録します.
    buffer = {ArrayList@8972}  size = 4
     0 = {Tuple2@8975} "(3,38)" // 3 task, partition 38 。
     1 = {Tuple2@8976} "(2,0)"
     2 = {Tuple2@8977} "(0,38)"
     3 = {Tuple2@8978} "(1,74)"
    

    システムのtask数はbufferサイズです.
    int npart = buffer.size(); // num tasks
    

    そして、「記録総数」から「ランダムトレーニングデータの個数numTarget」を算出する.例えば総数は1万で、ランダムに20%割り当てるべきで、numTargetは2千であるべきです.この数字は後で使います.
    long numTarget = Math.round((totCount * fraction));
    

    各taskの記録数、例えば上のbufferの38、0、38、または74がeachCountに記録される.
    for (Tuple2 value : buffer) {
        eachCount[value.f0] = value.f1;
    }
    

    各taskでランダムに選択された訓練記録数をeachSelectに記録した.各taskが現在「記録数字*fraction」です.例えば3番taskの記録数は38個で、20%を選ぶべきで、38*20%=8個です.
    そしてこれらのtask自身の「ランダムトレーニング記録数」をさらに加算してtotSelect(下から計算した総数)を得る.
    long totSelect = 0L;
    for (int i = 0; i < npart; i++) {
        eachSelect[i] = Math.round(Math.floor(eachCount[i] * fraction));
        totSelect += eachSelect[i];
    }
    

    このときtotSelectと以前計算したnumTargetには具体的な微細な違いがあり、理論的な数字ですが、上から下へ計算するのと下から計算するのでは、結果が異なる可能性があります.下を見るとわかります.
    numTarget = all count * fraction
    
    totSelect = task_1_count * fraction + task_2_count * fraction + ...
    

    そこで、次のステップでこの微細な出入りを処理すると、remainが得られます.これは、「全体的に算出されたランダム数」numTargetと「すべてのtaskから選択されたランダムな訓練記録数の蓄積」totSelectの差です.
    if (totSelect < numTarget) {
        long remain = numTarget - totSelect;
        remain = Math.min(remain, totCount - totSelect);
    

    ちょうど個数が等しい場合は、正常に割り当てられます.
    if (remain == totCount - totSelect) {
    

    数が異なる場合は、eachSelect配列の任意のレコードに「より多くのremain」を追加することをランダムに決定します.
    for (int i = 0; i < Math.min(remain, npart); i++) {
        int taskId = shuffle.get(i);
        while (eachSelect[taskId] >= eachCount[taskId]) {
              taskId = (taskId + 1) % npart;
        }
        eachSelect[taskId]++;
    }
    

    最後にすべての情報を与えます
    long[] statistics = new long[npart * 2];
    for (int i = 0; i < npart; i++) {
        statistics[i] = eachCount[i];
        statistics[i + npart] = eachSelect[i];
    }
    out.collect(statistics);
    
    //  4 , eachCount, eachSelect
    statistics = {long[8]@9003} 
     0 = 38 //eachCount
     1 = 38
     2 = 36
     3 = 38
       
     4 = 31 //eachSelect
     5 = 31
     6 = 28
     7 = 30
    

    これらの情報はブロードキャスト変数として格納されており,すぐに以下に用いられる.
     .withBroadcastSet(numPickedPerPartition, "counts")
    

    3.2.3 taskごとにレコードを選択する


    CountInPartition.PickInPartition関数では、taskごとにランダムにレコードが選択されます.
    まずtask数と以前に格納されたブロードキャスト変数(すなわち、以前に格納されたばかり)が得られる.
    int npart = getRuntimeContext().getNumberOfParallelSubtasks();
    List bc = getRuntimeContext().getBroadcastVariable("counts");
    

    countとselectを分離します.
    long[] eachCount = Arrays.copyOfRange(bc.get(0), 0, npart);
    long[] eachSelect = Arrays.copyOfRange(bc.get(0), npart, npart * 2);
    

    合計task数を得る
    int taskId = getRuntimeContext().getIndexOfThisSubtask();
    

    自分のtask対応のcountを得る
    long count = eachCount[taskId];
    long select = eachSelect[taskId];
    

    本task対応のレコードを追加し、ランダムにシャッフルして順番を狂わせる
    for (int i = 0; i < count; i++) {
         shuffle.add(i); // count 
    }
    Collections.shuffle(shuffle, new Random(taskId)); // 
    
    // suffle 
    shuffle = {ArrayList@8987}  size = 38
     0 = {Integer@8994} 17
     1 = {Integer@8995} 8
     2 = {Integer@8996} 33
     3 = {Integer@8997} 34
     4 = {Integer@8998} 20
     5 = {Integer@8999} 0
     6 = {Integer@9000} 26
     7 = {Integer@9001} 27
     8 = {Integer@9002} 23
     9 = {Integer@9003} 28
     10 = {Integer@9004} 9
     11 = {Integer@9005} 16
     12 = {Integer@9006} 13
     13 = {Integer@9007} 2
     14 = {Integer@9008} 5
     15 = {Integer@9009} 31
     16 = {Integer@9010} 15
     17 = {Integer@9011} 22
     18 = {Integer@9012} 18
     19 = {Integer@9013} 35
     20 = {Integer@9014} 36
     21 = {Integer@9015} 12
     22 = {Integer@9016} 7
     23 = {Integer@9017} 21
     24 = {Integer@9018} 14
     25 = {Integer@9019} 1
     26 = {Integer@9020} 10
     27 = {Integer@9021} 30
     28 = {Integer@9022} 29
     29 = {Integer@9023} 19
     30 = {Integer@9024} 25
     31 = {Integer@9025} 32
     32 = {Integer@9026} 37
     33 = {Integer@9027} 4
     34 = {Integer@9028} 11
     35 = {Integer@9029} 6
     36 = {Integer@9030} 3
     37 = {Integer@9031} 24
    

    ランダムに選択し、選択したものを並べ替えます
    for (int i = 0; i < select; i++) {
        selected[i] = shuffle.get(i); // select , suffle 
    }
    Arrays.sort(selected); // 
    
    // selected , 30 
    selected = {int[30]@8991} 
     0 = 0
     1 = 1
     2 = 2
     3 = 5
     4 = 7
     5 = 8
     6 = 9
     7 = 10
     8 = 12
     9 = 13
     10 = 14
     11 = 15
     12 = 16
     13 = 17
     14 = 18
     15 = 19
     16 = 20
     17 = 21
     18 = 22
     19 = 23
     20 = 26
     21 = 27
     22 = 28
     23 = 29
     24 = 30
     25 = 31
     26 = 33
     27 = 34
     28 = 35
     29 = 36
    

    選択したデータの送信
    if (numEmits < selected.length && iRow == selected[numEmits]) {
        out.collect(row);
        numEmits++;
    }
    

    3.3トレーニングデータセットとテストデータセットの設定


    outputはトレーニングデータセット、SideOutputはテストデータセットです.この2つのデータセットはいずれもAlink内部でTableタイプであるため、SQL演算子minusAllを直接使用して分割を完了する.
    this.setOutput(out, in.getSchema());
    this.setSideOutputTables(new Table[]{in.getOutputTable().minusAll(this.getOutputTable())});
    

    0 x 04フロー処理


    訓練はSplitStreamOpクラスで行われ,linkFromによってモデルの構築が完了した.
    ストリーム処理はSplitStreamとSelectTransformationの2つのクラスに依存して分割ストリームを完了する.具体的には物理的な操作は確立されていないが,上流演算子が下流演算子とどのように関連し,どのように記録を選択するかに影響を及ぼしている.
    SplitStream  splited = in.getDataStream().split(new RandomSelectorOp(getFraction()));
    

    まず,出力時にどのストリームを選択するかをRandomSelectorOpでランダムに決定する.ここには「a」と「b」の2つの名前が勝手につけられていることがわかります.
    class RandomSelectorOp implements OutputSelector  {
       private double fraction;
       private Random random = null;
       @Override
       public Iterable  select(Row value) {
          if (null == random) {
             random = new Random(System.currentTimeMillis());
          }
          List  output = new ArrayList (1);
          output.add((random.nextDouble() < fraction ? "a" : "b")); // , 
          return output;
       }
    }
    

    次に,その2つのランダムに生成されたストリームを得た.
    DataStream  partA = splited.select("a");
    DataStream  partB = splited.select("b");
    

    最後にこの2つのストリームをそれぞれoutputとsideOutputに設定します.
    this.setOutput(partA, in.getSchema()); // 
    this.setSideOutputTables(new Table[]{
    DataStreamConversionUtil.toTable(getMLEnvironmentId(), partB, in.getSchema())}); // 
    

    最後に、SplitStreamOpには2つのメンバー変数があります.
    this.outputは訓練集です.
    this.SideOutPutは検証セットです.
    return this;
    

    0 x 05リファレンス


    トレーニングデータ、検証データ、テストデータ分析