PlunkerでTensorFlow.js その26


概要

PlunkerでTensorFlow.jsを試してみた。
二次関数 y = ax² + bx + c のフィッティングをやってみた。

写真

サンプルコード



/**
 * Evaluates the quadratic model y = a*x^2 + b*x + c elementwise, using the
 * module-level trainable tensors `a`, `b` and `c`.
 *
 * Everything runs inside `tf.tidy`, so the intermediate tensors created here
 * are released and only the returned tensor survives.
 *
 * @param {tf.Tensor} x - Input values.
 * @returns {tf.Tensor} Model predictions for `x`.
 */
function predict(x) {
  return tf.tidy(() => {
    const quadraticTerm = a.mul(x.square());
    const linearTerm = b.mul(x);
    return quadraticTerm.add(linearTerm).add(c);
  });
}
/**
 * Mean squared error between model predictions and target values.
 *
 * @param {tf.Tensor} predictions - Model output.
 * @param {tf.Tensor} labels - Target values to compare against.
 * @returns {tf.Scalar} Mean of the squared residuals.
 */
function loss(predictions, labels) {
  const residuals = predictions.sub(labels);
  return residuals.square().mean();
}
/**
 * Runs `numIterations` steps of the module-level `optimizer`, each step
 * minimising `loss(predict(xs), ys)`.
 *
 * @param {tf.Tensor} xs - Training inputs.
 * @param {tf.Tensor} ys - Training targets.
 * @param {number} numIterations - Number of optimisation steps to take.
 */
function train(xs, ys, numIterations) {
  // The objective never changes between steps, so build the closure once.
  const objective = () => loss(predict(xs), ys);
  for (let step = 0; step < numIterations; step += 1) {
    optimizer.minimize(objective);
  }
}
/**
 * Generates noisy samples of the quadratic y = a*x^2 + b*x + c.
 *
 * x values are drawn with `tf.randomUniform` between `xMin` and `xMax`.
 * Gaussian noise with mean 0 and standard deviation `sigma` is added to
 * every y value.
 *
 * @param {number} numPoints - Number of samples to generate.
 * @param {{a: number, b: number, c: number}} coeff - True coefficients.
 * @param {number} [sigma=0.04] - Standard deviation of the additive noise.
 * @param {number} [xMin=-1] - Lower bound of the sampled x range.
 * @param {number} [xMax=1] - Upper bound of the sampled x range.
 * @returns {{xs: tf.Tensor, ys: tf.Tensor}} Sampled inputs and noisy targets.
 */
function generateData(numPoints, coeff, sigma = 0.04, xMin = -1, xMax = 1) {
  // tf.tidy releases the coefficient scalars and intermediates; the tensors
  // returned inside the result object are kept.
  return tf.tidy(() => {
    // Named trueA/trueB/trueC so they do not shadow the trainable
    // module-level variables `a`, `b` and `c`.
    const trueA = tf.scalar(coeff.a);
    const trueB = tf.scalar(coeff.b);
    const trueC = tf.scalar(coeff.c);
    const xs = tf.randomUniform([numPoints], xMin, xMax);
    const noise = tf.randomNormal([numPoints], 0, sigma);
    const ys = trueA.mul(xs.square()).add(trueB.mul(xs)).add(trueC).add(noise);
    return { xs, ys };
  });
}
// Trainable coefficients of the quadratic model, initialised from
// Math.random() (a value in [0, 1)).
const a = tf.variable(tf.scalar(Math.random()));
const b = tf.variable(tf.scalar(Math.random()));
const c = tf.variable(tf.scalar(Math.random()));
const learningRate = 0.5;
const optimizer = tf.train.sgd(learningRate);
const numIterations = 100;
// Ground-truth coefficients used to synthesise the training data.
const eff = {
  a: 5,
  b: 3,
  c: 2,
};
const r1 = generateData(20, eff);
train(r1.xs, r1.ys, numIterations);
// The training tensors are not used after fitting; free their memory.
r1.xs.dispose();
r1.ys.dispose();
// Read the fitted scalars back as plain numbers. dataSync() returns a typed
// array, so take its single element rather than coercing the whole array
// through Number() (which round-trips via its string form).
const aa = a.dataSync()[0];
const bb = b.dataSync()[0];
const cc = c.dataSync()[0];
// Sample both the true curve (y) and the fitted curve (pred) on 21 evenly
// spaced points covering [-1, 1]. x is derived from an integer step instead
// of repeatedly adding 0.1. Repeated addition drifts (e.g. about -1.4e-16
// instead of 0) and makes the number of samples depend on rounding.
const values = Array.from({ length: 21 }, (_, step) => {
  const x = (step - 10) / 10;
  return {
    x,
    y: eff.a * x * x + eff.b * x + eff.c,
    pred: aa * x * x + bb * x + cc,
  };
});
// Vega-Lite chart: the true curve as points and the fitted curve as a line.
const spec = {
  '$schema': 'https://vega.github.io/schema/vega-lite/v2.json',
  'width': 300,
  'height': 300,
  'data': {
    'values': values
  },
  'layer': [{
    // True function values.
    'mark': 'point',
    'encoding': {
      'x': {
        'field': 'x',
        'type': 'quantitative'
      },
      'y': {
        'field': 'y',
        'type': 'quantitative'
      }
    }
  }, {
    // Predictions of the trained model.
    'mark': 'line',
    'encoding': {
      'x': {
        'field': 'x',
        'type': 'quantitative'
      },
      'y': {
        'field': 'pred',
        'type': 'quantitative'
      },
      'color': {
        'value': 'tomato'
      }
    }
  }]
};
// vegaEmbed returns a Promise; report rendering failures instead of leaving
// an unhandled rejection.
vegaEmbed('#vis', spec).catch((err) => {
  console.error(err);
});






成果物

以上。