TensorFlow.js巻き取り神経ネット手書きデジタル認識


元の住所https://laboo.top/2018/11/21/tfjs-dr/
ソース
digit-recognizer
デモ
https://github-laziji.github.io/digit-recognizer/デモ開始時には約100Mの訓練データをロードしますので、少々お待ちください.
トレーニングセットのサイズを調整して、テスト結果の正確さを観察します.
データソース
データソースとhttps://www.kaggle.com の一つのテーマdigit-recognizerは、これらのテストデータにラベルを付けるように要求される42000条のトレーニングデータ(写真とラベルを含む)と28000条のテストデータ(写真のみを含む)を与える.
ウェブサイトの中には他の機械学習のテーマやデータもたくさんあります.とてもいいトレーナーのところです.
実現する
ここではTensorFlow.jsを使ってこのプロジェクトを実現します.
モデルを作成
畳み込み神経ネットワークの第1層は、入力層であり実行層であり、受信IMAGE_H * IMAGE_Wサイズの白黒画素の最終層は出力層であり、10個の出力ユニットがあり、0-9の10個の値の確率分布、例えばLabel=2であり、[0.02,0.01,0.9,...,0.01]の出力である.
function createConvModel() {
  const model = tf.sequential();

  model.add(tf.layers.conv2d({
    inputShape: [IMAGE_H, IMAGE_W, 1],
    kernelSize: 3,
    filters: 16,
    activation: 'relu'
  }));

  model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
  model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
  model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
  model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
  model.add(tf.layers.flatten({}));

  model.add(tf.layers.dense({ units: 64, activation: 'relu' }));
  model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));

  return model;
}
トレーニングモデル
適切な最適化器と損失関数を選択してモデルをコンパイルした.
async function train() {

  ui.trainLog('Create model...');
  model = createConvModel();
  
  ui.trainLog('Compile model...');
  const optimizer = 'rmsprop';
  model.compile({
    optimizer,
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });
  const trainData = Data.getTrainData(ui.getTrainNum());
  
  ui.trainLog('Training model...');
  await model.fit(trainData.xs, trainData.labels, {});

  ui.trainLog('Completed!');
  ui.trainCompleted();
}
テスト
ここでテストデータのセットをテストして、対応するラベルを返します.つまり、10個の出力ユニットの中で確率が一番高い下付きです.
function testOne(xs){
  if(!model){
    ui.viewLog('Need to train the model first');
    return;
  }
  ui.viewLog('Testing...');
  let output = model.predict(xs);
  ui.viewLog('Completed!');
  output.print();
  const axis = 1;
  const predictions = output.argMax(axis).dataSync();
  return predictions[0];
}
私のブログの公式番号に注目してください.