TFJSチュートリアルフィットデータ

9916 ワード

紹介する
前にtfjsのコア機能を紹介しましたが、今回はTensorFlowを使用します.jsはカーブを合成データセットにフィットします.多項式を使用してデータを生成し、モデルを訓練して生成データの係数を予測します.
コード#コード#
tfjsの公式コードを実行します:リンクをダウンロードしてローカルで実行yarnがインストールされている場合は、直接実行します:
1
2
yarn
yarn watch

yarnダウンロードアドレス:リンク
win 10,yarn watch時にエラー表示が発生した場合:'NODE_ENV’は内部や外部のコマンドではなく、実行可能なプログラムではありません.jsonの内容:
1
2
3
4
"scripts": {
"watch": "NODE_ENV=development parcel --no-hmr --open index.html ",
"build": "NODE_ENV=production parcel build index.html  --no-minify --public-url ./"
}

次のように変更
1
2
3
4
"scripts": {
"watch": "set NODE_ENV=development && parcel --no-hmr --open index.html ",
"build": "set NODE_ENV=production && parcel build index.html  --no-minify --public-url ./"
}

再実行:yarn watchでWebサイトが表示されます.次のようになります.
曲線のフィット過程が直感的にわかります.
入力データ
このアイコンのデータは三次関数を使用しています
y=ax3+bx2+cx+dy=ax3+bx2+cx+d
我々の任務は,この関数の係数:a,b,c,dがデータに最も適した値であることを学習することである.
ステップ1:変数の設定
まず、変数ごとに乱数を割り当てます.
1
2
3
4
const a = tf.variable(tf.scalar(Math.random()));
const b = tf.variable(tf.scalar(Math.random()));
const c = tf.variable(tf.scalar(Math.random()));
const d = tf.variable(tf.scalar(Math.random()));

ステップ2:モデルの作成
tfjsで式を作成し、predict関数を構築し、xを入力としてyを返します.
1
2
3
4
5
6
7
8
9
function predict(x) {
  // y = a * x ^ 3 + b * x ^ 2 + c * x + d
  return tf.tidy(() => {
    return a.mul(x.pow(tf.scalar(3))) // a * x^3
      .add(b.mul(x.square())) // + b * x ^ 2
      .add(c.mul(x)) // + c * x
      .add(d); // + d
  });
}

ステップ3:モデルのトレーニング
我々の最後のステップは,係数の良好な値を学習するためにモデルを訓練することである.モデルを訓練するには、3つのことを定義する必要があります.
  • 損失関数は、多項式の適合度を測定するデータです.損失値が低いほど多項式がフィットします.
  • オプティマイザは,損失関数の出力に基づいて我々の係数値を修正するアルゴリズムを実行する.オプティマイザの目標は、損失関数の出力値を最小化することです.
  • サイクルトレーニングは、損失を最小限に抑えるためにオプティマイザを反復的に実行します.

  • 損失関数の定義
    損失関数として平均二乗誤差(MSE)を用い,我々のデータセットの各x値の実際のy値と予測y値との差を二乗し,その後,すべての結果アイテムの平均値を取り,MSEを計算する.
    MSE=1N∑t=1N(observedt−predictedt)2MSE=1N∑t=1N(observedt−predictedt)2
    私たちはTensorFlowでjsでMSE損失関数を定義します.以下に示します.
    1
    2
    3
    4
    5
    
    function loss(predictions, labels) {
      //            (   ),       ,     。
      const meanSquareError = predictions.sub(labels).square().mean();
      return meanSquareError;
    }
    

    オプティマイザの定義
    私たちのオプティマイザではランダム勾配降下を用いて(SGD).SGDは、データセットのランダムポイントの勾配を取得し、その値を使用してモデル係数の増加または減少の値を通知する.TensorFlow.jsは、SGDの実行に便利な機能を提供し、損失関数の値を最適化するために呼び出すことができるSGDOptimizerのオブジェクトを返す.学習率は、予測を改善するモデルの調整量を制御する.学習率が低いと、学習プロセスが実行されるより遅く(良い係数を学習するためにより多くの訓練反復が必要である)、高い学習率は学習を加速させるが、モデルが正しい値を中心に振動し、常に過度に矯正される可能性がある.
    以下のコードはSGDオプティマイザを構築し、学習率は0.5である.
    1
    2
    
    const learningRate = 0.5;
    const optimizer = tf.train.sgd(learningRate);
    

    トレーニングサイクルの定義
    損失関数とオプティマイザを定義し,SGDを反復的に実行してモデルの係数を最適化して損失(MSE)を最小化するトレーニングサイクルを構築できるようになった.
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    
    function train(xs, ys, numIterations = 75) {
    
      const learningRate = 0.5;
      const optimizer = tf.train.sgd(learningRate);
    
      for (let iter = 0; iter < numIterations; iter++) {
        optimizer.minimize(() => {
          const predsYs = predict(xs);
          return loss(predsYs, ys);
        });
      }
    }
    

    コードを一歩一歩詳しく検討しましょう.まず、トレーニング関数を定義し、データセットのxとyの値と指定した反復回数を入力します.
    1
    2
    3
    
    function train(xs, ys, numIterations) {
    ...
    }
    

    次に、前節で説明したように、学習率とSGDオプティマイザを定義します.
    1
    2
    
    const learningRate = 0.5;
    const optimizer = tf.train.sgd(learningRate);
    

    最後に,for運転numIterations訓練反復のサイクルを確立した.反復のたびにminimizeオプティマイザを呼び出します.
    1
    2
    3
    4
    5
    6
    
    for (let iter = 0; iter < numIterations; iter++) {
      optimizer.minimize(() => {
        const predsYs = predict(xs);
        return loss(predsYs, ys);
      });
    }
    
    minimizeには、2つのことができる機能が必要です.
  • 1.ステップ2で定義した予測モデル関数を使用して、すべてのx値のy値(predY)を予測します.
  • 2.損失関数を定義するときに定義した損失関数を使用して、予測された平均二乗誤差損失を返します.
  • minimizeは次に、この関数で使用される任意の変数(ここでは係数a,b,c,d)を自動的に調整して、戻り値(我々の損失)を最小限に抑える.我々のトレーニングサイクルを実行すると、a,b,c,dは、SGDを75回反復した後にモデルによって学習される係数値を含む.
    結果
    プログラムは実行を終了し、変数a,b,c,dの最終値を取得し、それらを使用して曲線を描画することができます.
    著者:StevenKe このリンクは次のとおりです.http://www.stevenke.com/2018/tensorflowjs-fitting.html
    本ブログのすべての文章は特別声明のほか、CC BY-NC-SA 3.0ライセンス契約を採用しています.転載は出典を明記してください!
    http://www.stevenke.com/2018/tensorflowjs-fitting.html