TensorFlow.jsのGetting Startedをやってみた


前回
* Jupyter Notebookのようなもの(あくまで個人の感想です)をメモ帳・ブラウザ・gitで実現する
で作成した、Notebookを使って、TensorFlow.jsのGetting Startedを行ってみました。

このサンプルは、一次関数の線形回帰です。
オリジナルは、DevToolsのログに出力していましたが、ビジュアライザーに渡せるように変更しました。

10回の試行では(当たり前ですが)十分に収束しません。概ね300回程度で近い値を返すようになります。
ファイルの拡張子はhtmlで作成しています。

実行結果

コード

タイトル・章

<!DOCTYPE html><head><meta charset="UTF-8">
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]"></script>
</head><body><script type="text/markdown">

# JS Notebook

## TensorFlow.js Getting Started

[https://js.tensorflow.org/#getting-started](https://js.tensorflow.org/#getting-started)

数式の表示

### #1 Linear regression

%%%(Math """
y = ax + b
""")
ビジュアライザーの定義

%%%($last ($let visualizer-1 (-> (values)

"""Markdown
| x                 | y                 |
|------------------:|------------------:| %%%($=for values """
| %%%($get $data 0) | %%%($get $data 1) | """)

%%%($local (
        (xy ($map values (-> (r) (# (x ($get r 0)) (y ($get r 1)) ))))
        (get-color (-> (i op)
            ($let p ($to-string op))
            ($let c ($list ($concat "rgba(255,  99, 132, " p ")")
                           ($concat "rgba( 54, 162, 235, " p ")")
                           ($concat "rgba(255, 206,  86, " p ")")
                           ($concat "rgba( 75, 192, 192, " p ")")
                           ($concat "rgba(153, 102, 255, " p ")")
                           ($concat "rgba(255, 159,  64, " p ")") ))
            ($get c ($mod i ($length c))) )))
(Chart (@ (width 800)
             (height 400)
             (unit "px")
             (asImgTag)
             (settings (#
    (type "scatter")
    (data (#
        (label "values")
        (datasets ($list (#
            (label "points")
            (data xy)
            (backgroundColor (get-color 0 0.2))
            (borderColor     (get-color 0 1.0))
            (borderWidth 1)
        )))
    ))
))) )) """)) nil)
コードブロックの定義開始

%%%(Notebook """Js@{(module "Tfjs01") (visualizer visualizer-1)}
```javascript
コードブロック
// Define a model for linear regression.
const model = tf.sequential();
model.add(tf.layers.dense({
    units: 1,         // dimensionality of the output space.
    inputShape: [1],  // dimension of the input.
}));

// Prepare the model for training: Specify the loss and the optimizer.
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});

const trainX = [1, 2, 3, 4];
const trainY = [1, 3, 5, 7];

// Generate some synthetic data for training.
const xs = tf.tensor2d(trainX, [trainX.length, 1]);    // Create a r4xc1 [[1],[2],[3],[4]] 2d tensor.
const ys = tf.tensor2d(trainY, [trainY.length, 1]);    // Create a r4xc1 [[1],[3],[5],[7]] 2d tensor.

// Train the model using the data.
//   train 'epochs' times...
return model.fit(xs, ys, {epochs: 10}).then(() => {
    // End training !
    // Use the model to do inference on a data point the model hasn't seen before.
    const inputX = [5, 6, 7, 8];
    return (
        model.predict(                                 // Generates output predictions for the input samples.
            tf.tensor2d(inputX, [inputX.length, 1])    // Create a r4c1 [[5],[6],[7],[8]] 2d tensor.
        ).data().then(a => {
            // a is an object with string keys "0", "1", ....
            const t = trainX.map((v, i) => [v, trainY[i]]);
            const r = [];
            for (let i in a) {
                if (a.hasOwnProperty(i)) {
                    r[Number(i)] = [inputX[Number(i)], a[i]];
                }
            }
            return t.concat(r);
        })
    );
});
コードブロック定義の終了
&#96;&#96;&#96;   <--- バッククォート × 3 です。
""")
数式表示のためのMathJax読み込み
### To view math formula in all browsers, load [MathJax](https://www.mathjax.org/) script file.

%%%(script (@ (src "https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-MML-AM_CHTML") (crossorigin "anonymous") (async)))

</script>
<script src="js/notebook.js"></script>
<script src="js/menneu.min.js" onload="start({title: 'My Notebook 1'})"></script>
</body>

感想

TensorFlowは、値の取り扱いが非同期処理であるため、苦労がありました。

前回記事公開のタイミングでは、結果が返るのがDOM構築の後になるため、非同期処理の結果を表示できませんでしたが、コードブロックのコンポーネントにビジュアライザーの関数を渡すことで解決できました。(非同期のリソースを待つメカニズム自体は元々、画像の1ファイルへのパッキング等のために持っていたので、そのまま使用できました。)