jsdoでtensorflow.js その12


概要

jsdoでtensorflow.jsやってみた。
keras風、高級APIを使わないで、coreとか言われるレベルでやってみた。
九九問題、やってみた。

写真

学習

バッチ数: 81
input: 8
隠れ層: 1
ユニット: 60
活性化関数: tanh
output: 7
活性化関数: softmax
オプチマイザー: adam
ロス: softmaxCrossEntropy
エポック数: 9000

サンプルコード

function tob(i, j) {
    var ary = new Array;
    var v = (i + 1) * (j + 1);
    var b = v & 1;
    if (b > 0)
    {
        ary.push(1);
    }
    else
    {
        ary.push(0);
    }
    for (var k = 0 ; k < 6; k++)
    {        
        var b = v & (2 << k);
        if (b > 0)
        {
            ary.push(1);
        }
        else
        {
            ary.push(0);
        }
    }
    return ary;
}
const buffer2 = tf.buffer([81, 7]);
for (var i = 0; i < 9; i++) 
{
    for (var j = 0; j < 9; j++) 
    {
        var l = i * 9 + j;
        var x = tob(i, j);
        for (var k = 0; k < 7; k++)
        {
            buffer2.set(x[k], l, k);
        }
    }
}
const yt = buffer2.toTensor();
function toa(i, j) {
    var ary = new Array;
    var v = i + 1;
    var b = v & 1;
    if (b > 0)
    {
        ary.push(1);
    }
    else
    {
        ary.push(0);
    }
    for (var k = 0 ; k < 3; k++)
    {        
        var b = v & (2 << k);
        if (b > 0)
        {
            ary.push(1);
        }
        else
        {
            ary.push(0);
        }
    }
    var v = j + 1;
    var b = v & 1;
    if (b > 0)
    {
        ary.push(1);
    }
    else
    {
        ary.push(0);
    }
    for (var k = 0 ; k < 3; k++)
    {        
        var b = v & (2 << k);
        if (b > 0)
        {
            ary.push(1);
        }
        else
        {
            ary.push(0);
        }
    }
    return ary;
}
const buffer = tf.buffer([81, 8]);
for (var i = 0; i < 9; i++) 
{
    for (var j = 0; j < 9; j++) 
    {
        var l = i * 9 + j;
        var x = toa(i, j);
        for (var k = 0; k < 8; k++)
        {
            buffer.set(x[k], l, k);
        }
    }

}
const xt = buffer.toTensor();
var num = 60;
const w1 = tf.variable(tf.randomNormal([8, num]));
const b1 = tf.variable(tf.randomNormal([num]));
const w3 = tf.variable(tf.randomNormal([num, 7]));
const b3 = tf.variable(tf.randomNormal([7]));
function func(x) {
    const h1 = tf.tanh(x.matMul(w1).add(b1));
    return tf.softmax(h1.matMul(w3).add(b3));
}
function loss(pred, ypred) {
    return tf.losses.softmaxCrossEntropy(pred, ypred).mean();
}
const optimizer = tf.train.adam(0.01);
var cc;
for (let i = 0; i < 9001; i++)
{
    const cost = optimizer.minimize(() => loss(func(xt), yt), true);
    cc = cost;    
}
//document.write(func(xt));
var pre = func(xt);
var p = pre.dataSync();
var col;
var row;
document.write('<table>');
var lim = 0.15;
for (row = 0; row < 10; row++)
{
    document.write('<tr>');
    for (col = 0; col < 10; col++)
    {
        if (col === 0 && row === 0)
        {
            document.write('<th>&nbsp;<\/th>');
        }    
        else if (col === 0 && row !== 0)
        {
            document.write('<th>' + row + '<\/th>');
        }
        else if (row === 0)
        {
            document.write('<th>' + col + '<\/th>'); 
        }
        else
        {
            var i = (row - 1) * 9 + (col - 1);
            var v = 0;
            if (p[i * 7 + 0] > lim) v += 1;
            if (p[i * 7 + 1] > lim) v += 2;
            if (p[i * 7 + 2] > lim) v += 4;
            if (p[i * 7 + 3] > lim) v += 8;
            if (p[i * 7 + 4] > lim) v += 16;
            if (p[i * 7 + 5] > lim) v += 32;
            if (p[i * 7 + 6] > lim) v += 64;
            document.write('<td>' + v + '<\/td>'); 
        }
    }
    document.write('<\/tr>');
}
document.write('<\/table>');



成果物

以上。