jsdoでtensorflow.js その10


概要

jsdoでtensorflow.jsやってみた。
keras風、高級APIを使わないで、coreとか言われるレベルでやってみた。
xor問題、やってみた。

写真

サンプルコード

const xt = tf.tensor2d([[1, 0], [0, 1], [1, 1], [0, 0]], [4, 2]);
const yt = tf.tensor2d([[1, 0], [1, 0], [0, 1], [0, 1]], [4, 2]);

const w1 = tf.variable(tf.randomNormal([2, 6]));
const b1 = tf.variable(tf.zeros([6]));
const w2 = tf.variable(tf.randomNormal([6, 2]));
const b2 = tf.variable(tf.zeros([2]));

function func(x) {
    const h = tf.relu(x.matMul(w1).add(b1));
    return tf.softmax(h.matMul(w2).add(b2));
}

function loss(pred, label) {
    return tf.losses.softmaxCrossEntropy(pred, label).mean();
}

const optimizer = tf.train.sgd(0.1);

for (let i = 0; i < 3000; i++)
{
    const cost = optimizer.minimize(() => loss(func(xt), yt), true);
}

alert(func(xt));

成果物

以上。