jsdoでtensorflow.js その3


概要

jsdoでtensorflow.jsやってみた。
keras風、高級APIを使わないで、coreとか言われるレベルでやってみた。
sin問題、やってみた。
バイアスとウェイトを取り出してみた。

写真

バイアスとウェイト

Tensor
    [1.1656082, -2.0180476, -6.9387307, 3.3653135, -1.7032033, -2.8969588, 0.0565524, -9.5316219, -5.7146306, -2.8060255, -0.3085915, -6.3642092]
Tensor
     [[0.6203325, -1.027319, 1.8841791, -3.2817535, 1.2214272, 0.9679314, 1.4196726, 7.171823, 1.3714954, 3.7268815, 0.0598045, 1.1944624],]
Tensor
    [-0.2928229]
Tensor
    [[-1.2655445],
     [-1.0401167],
     [0.2071615 ],
     [-0.4695872],
     [-2.6633599],
     [-0.5835096],
     [1.2490041 ],
     [1.2825323 ],
     [-1.5088843],
     [0.6767025 ],
     [0.0224434 ],
     [1.6849496 ]]

サンプルコード

var canvas = document.getElementById("canvas");
var ctx = canvas.getContext("2d");
function draw(data, n) {
    var hc = n * 100 + 100;
    ctx.strokeStyle = "#f00";
    ctx.lineWidth = 1;
    ctx.moveTo(0, hc);
    for (var i = 1; i < 20; i++) 
    {
        ctx.lineTo(i * 10, hc - data[i] * 30);
    }
    ctx.stroke();
}
const buffer = tf.buffer([20, 1]);
const buffer2 = tf.buffer([20, 1]);
const tx = [];
const ty = [];
for (var i = 0; i < 20; i++) 
{
    var x = i / 3.0;    
    var y = Math.sin(x);
    tx.push(x);
    ty.push(y);
    buffer.set(x, i, 0);
    buffer2.set(y, i, 0);
}
draw(ty, 0);
const train_x = buffer.toTensor();
const train_y = buffer2.toTensor();
const w1 = tf.variable(tf.randomNormal([1, 12]));
const b1 = tf.variable(tf.zeros([12]));
const w2 = tf.variable(tf.randomNormal([12, 1]));
const b2 = tf.variable(tf.zeros([1]));
const optimizer = tf.train.adam(0.1);
function loss(pred, ypred) {
    return pred.sub(ypred).square().mean();
}
function func(x) {
    const h = tf.tanh(x.matMul(w1).add(b1));
    return tf.tanh(h.matMul(w2).add(b2));
}
for (let i = 0; i < 9000; i++) 
{
    optimizer.minimize(() => loss(func(train_x), train_y));
}
const preds = func(train_x).dataSync();
draw(preds, 1);
var out1 = document.getElementById('src');
out1.value += b1 + '\n';
out1.value += w1 + '\n';
out1.value += b2 + '\n';
out1.value += w2 + '\n';



成果物

以上。