jsdoでtensorflow.js その11


概要

jsdoでtensorflow.jsやってみた。
Keras風の高レベルAPI(Layers API)は使わず、Core APIと呼ばれる低レベルのAPIでやってみた。
7segmentLED問題、やってみた。

写真

LED

7segmentは、以下の配置

   a
f     b
   g
e     c
   d 

学習

バッチ数: 10
input: 10
隠れ層: 1
ユニット: 40
活性化関数: tanh
output: 7
活性化関数: softmax
オプティマイザー: adam
ロス: softmaxCrossEntropy
エポック数: 6000

サンプルコード

// Look up the page elements used by this demo.
var canvas = document.getElementById('canvas');
var out = document.getElementById('out');

// Size the drawing surface to 500x200, then obtain its 2D drawing context.
Object.assign(canvas, { width: 500, height: 200 });
var context = canvas.getContext('2d');

/**
 * Draws one seven-segment LED digit (plus an optional decimal point) into a
 * rectangular cell of a canvas.
 *
 * The cell is first filled with the background colour, then every segment
 * whose flag is truthy is filled with the font colour. Each segment is drawn
 * as its own path, so a segment is painted exactly once even when the font
 * colour is translucent.
 *
 * @param {Object} properties
 * @param {CanvasRenderingContext2D} properties.context - 2D context to draw on.
 * @param {number} properties.x - Left edge of the cell.
 * @param {number} properties.y - Top edge of the cell.
 * @param {number} properties.width - Cell width.
 * @param {number} properties.height - Cell height.
 * @param {string} properties.backColor - Fill style for the cell background.
 * @param {string} properties.fontColor - Fill style for lit segments.
 * @param {Object} properties.seg - Flags `a`..`g` and `dp`; truthy means lit.
 */
function draw7SegLED(properties) {
    const context = properties.context;
    const x = properties.x;
    const y = properties.y;
    const w = properties.width;
    const h = properties.height;
    const seg = properties.seg;

    const canvasMargin = 5;
    const barMargin = 5;
    const barWeight = 10;
    // Length of the pointed tip at both ends of a segment bar.
    const bevel = 5;

    // Fills a closed polygon as an independent path; fill() closes it implicitly.
    function fillPolygon(points) {
        context.beginPath();
        points.forEach(function (point, index) {
            if (index === 0) {
                context.moveTo(point[0], point[1]);
            } else {
                context.lineTo(point[0], point[1]);
            }
        });
        context.fill();
    }

    // Horizontal hexagonal bar whose bounding box starts at (left, top).
    function drawHorizontal(left, top, width, height) {
        const half = height / 2;
        const mid = top + half;
        fillPolygon([
            [left, mid],
            [left + bevel, mid - half],
            [left + width - bevel, mid - half],
            [left + width, mid],
            [left + width - bevel, mid + half],
            [left + bevel, mid + half]
        ]);
    }

    // Vertical hexagonal bar whose bounding box starts at (left, top).
    function drawVertical(left, top, width, height) {
        const half = width / 2;
        const center = left + half;
        fillPolygon([
            [center, top],
            [center + half, top + bevel],
            [center + half, top + height - bevel],
            [center, top + height],
            [center - half, top + height - bevel],
            [center - half, top + bevel]
        ]);
    }

    // Decimal point: a filled circle centred on (left, top).
    function drawDp(left, top, radius) {
        context.beginPath();
        context.arc(left, top, radius, 0, Math.PI * 2, false);
        context.fill();
    }

    // Background first, then switch to the segment colour.
    context.fillStyle = properties.backColor;
    context.fillRect(x, y, w, h);
    context.fillStyle = properties.fontColor;

    // Anchor positions; segments that share a coordinate reuse these values.
    const a = {
        x: x + canvasMargin,
        y: y + canvasMargin,
        w: w - (canvasMargin * 2) - (barWeight / 2) - barMargin
    };
    const b = {
        x: x + w - canvasMargin - (barMargin * 2) - (barWeight / 2),
        y: y + (barMargin * 2) + canvasMargin,
        h: (h / 2) - (barWeight * 2) - (canvasMargin * 2) + (barMargin * 3) - (barMargin / 2)
    };
    const c = {
        y: y + (barMargin * 2) + (h / 2) - (barWeight * 2) + (barMargin * 3) + (barMargin / 2)
    };
    const e = {
        x: x + canvasMargin - barMargin
    };
    const dp = {
        x: x + w - (barWeight / 2),
        y: y + h - (barWeight / 2),
        radius: barWeight / 2
    };

    if (seg.a) drawHorizontal(a.x, a.y, a.w, barWeight);
    if (seg.b) drawVertical(b.x, b.y, barWeight, b.h);
    if (seg.c) drawVertical(b.x, c.y, barWeight, b.h);
    if (seg.d) drawHorizontal(a.x, y + h - canvasMargin * 2, a.w, barWeight);
    if (seg.e) drawVertical(e.x, c.y, barWeight, b.h);
    if (seg.f) drawVertical(e.x, y + canvasMargin + barMargin * 2, barWeight, b.h);
    if (seg.g) drawHorizontal(a.x, y + h / 2 - barMargin / 2, a.w, barWeight);
    if (seg.dp) drawDp(dp.x, dp.y, dp.radius);
}
// Inputs: the 10x10 identity matrix, i.e. one one-hot row per training sample.
const xt = tf.tensor2d(
    Array.from({ length: 10 }, (_, row) =>
        Array.from({ length: 10 }, (_, col) => (row === col ? 1 : 0))),
    [10, 10]);

// Targets: one row per sample, one 0/1 flag per segment. The columns are in
// the order g, f, e, d, c, b, a, which is the order the rendering loop reads
// them in. The trailing comment is the digit each segment pattern shows.
const yt = tf.tensor2d([
    '1101111', // 9
    '1111111', // 8
    '0100111', // 7
    '1111101', // 6
    '1101101', // 5
    '1100110', // 4
    '1001111', // 3
    '1011011', // 2
    '0000110', // 1
    '0111111'  // 0
].map((bits) => Array.from(bits, Number)), [10, 7]);
// Number of units in the hidden layer.
var num = 40;

// Trainable parameters, initialised from a normal distribution.
const w1 = tf.variable(tf.randomNormal([10, num]));
const b1 = tf.variable(tf.randomNormal([num]));
// NOTE(review): w2/b2 are not referenced by the forward pass below (the
// network has a single hidden layer). They are kept only so that the set of
// globals defined by this script does not change.
const w2 = tf.variable(tf.randomNormal([num, num]));
const b2 = tf.variable(tf.randomNormal([num]));
const w3 = tf.variable(tf.randomNormal([num, 7]));
const b3 = tf.variable(tf.randomNormal([7]));

/**
 * Computes the raw (pre-softmax) output scores of the network.
 *
 * @param {tf.Tensor} x - Input rows; the last dimension must match w1 (10).
 * @returns {tf.Tensor} One row of 7 unnormalised scores per input row.
 */
function logits(x) {
    const h1 = tf.tanh(x.matMul(w1).add(b1));
    return h1.matMul(w3).add(b3);
}

/**
 * Forward pass used for prediction: softmax over the 7 output scores.
 *
 * @param {tf.Tensor} x - Input rows, as for logits().
 * @returns {tf.Tensor} One row of 7 probabilities per input row.
 */
function func(x) {
    return tf.softmax(logits(x));
}

/**
 * Softmax cross entropy between network scores and target rows.
 *
 * tf.losses.softmaxCrossEntropy expects (onehotLabels, logits) and applies
 * the softmax itself, so `pred` must be the raw scores from logits(), not
 * the probabilities returned by func().
 *
 * @param {tf.Tensor} pred - Raw output scores (logits).
 * @param {tf.Tensor} ypred - Target rows with the same shape as `pred`.
 * @returns {tf.Tensor} Scalar loss.
 */
function loss(pred, ypred) {
    return tf.losses.softmaxCrossEntropy(ypred, pred).mean();
}

const optimizer = tf.train.adam(0.01);

// Holds the cost tensor of the most recent training step.
var cc;
// 6000 full-batch steps, matching the epoch count stated in the write-up.
for (let i = 0; i < 6000; i++)
{
    const cost = optimizer.minimize(() => loss(logits(xt), yt), true);
    // Release the previous step's cost so only the latest one stays alive.
    if (cc) cc.dispose();
    cc = cost;
}
// Evaluate the network on the training inputs and read the result back as a
// flat array holding 7 values per sample.
var pre = func(xt);
var p = pre.dataSync();
var l = p.length / 7;

for (let sample = 0; sample < l; sample++)
{
    // Output column k drives the segment at position k of 'gfedcba'
    // (column 0 -> g ... column 6 -> a). A segment is lit when its value
    // exceeds 0.1; the decimal point is never lit.
    const lit = { a: 0, b: 0, c: 0, d: 0, e: 0, f: 0, g: 0, dp: 0 };
    Array.from('gfedcba').forEach((segment, column) => {
        if (p[sample * 7 + column] > 0.1) {
            lit[segment] = 1;
        }
    });

    // Samples are drawn left to right in 50x100 cells along the top edge.
    draw7SegLED({
        context: context,
        fontColor: 'rgba(255, 255, 0, 1)',
        backColor: 'rgba(34, 34, 34, 1)',
        x: sample * 50,
        y: 0,
        width: 50,
        height: 100,
        seg: lit
    });
}

成果物

以上。