UnityとF#で機械学習③:線形回帰と結果のプロット


はじめに

"F# for Machine Learning Essentials" (Sudipta Mukherjee 著 以下「原著」) の第2章 "Linear Regression"(線形回帰)で出てくる、線形回帰と結果のプロットの話です。ここでは、線形回帰とは二つの変数の関係を最もよく表す直線を見つける事とお考えください。

F#での線形回帰の計算はMath.NET for Numericsを使えばよいのですが、結果のプロットはUnityの標準機能ではできません。そこで、今回はGraph Makerという有料のアセットを使ってみます。F#のところだけなら有料アセット無しでも実行できるので、アセットを買いたくない方は、そちらだけお試しください。

参考資料

原著のコード こちらにある。

Wikipedia 線形回帰

F#側のコード

原著のコードがわかりづらい表記だったので、シンプルに書き直しました。SimpleRegression.Fitの使い方が原著と異なるのに注意してください(原著のままだと、F#3.0ではエラーになる)。

namespace LinearRegressionLibrary

open MathNet.Numerics.LinearAlgebra
open MathNet.Numerics.LinearRegression
open System.IO

module LinearRegress = 

    let rnd = System.Random()
    let genX n =
        List.init n (fun _ -> float(rnd.Next (0,100))) //nで指定される数だけランダムに0~100の値を作る
    let xList = genX 20 //ランダムな値を20個用意
    let yArray =  xList |> List.map ( fun t -> float (t*2.0+float(rnd.Next(0,20))*0.1))
                            |> List.toArray //作ったxListを用いて、 yArray = x*2+noise[0,2] という形でyArrayを作る
    let xArray = xList |>List.toArray
    let B0= SimpleRegression.Fit(xArray,yArray) //直線 y = a + b*x をフィットさせる. B0の最初のアイテムがa, 次のが bに相当。 
    let regressionPairs = xArray |> Array.map ( fun t -> (t, B0.Item1 + B0.Item2* t )) // フィットさせた直線のパラメータをもとに、xArrayの値からの計算を行う。

フィットさせるデータは y = 2x に 0~2 (平均すれば1)のノイズを加算したものを使うので、フィットさせた式は y = 2x + 1に近くなるはずです。

Unity側のコード

計算結果を見てみる

とりあえず、以下の形で計算結果を見てみます。

ShowResult.cs
using UnityEngine;
using LinearRegressionLibrary;


public class ShowResult : MonoBehaviour {
    // Use this for initialization
    void Start () {
        Debug.Log(LinearRegress.B0.Item1);
        Debug.Log(LinearRegress.B0.Item2);
    }
}

元データーに乱数を使ってるので、結果は毎回変わりますが、大体こんな感じになります。

グラフで確認

原著では計算結果をFsPlotというF#用のラッパーを使ってグラフを表示するのですが、Unityでは有料アセットのGraph Makerをつかいます。かなり癖がありマニュアルもイマイチで使いづらいのですが、他に選択肢がないのでしょうがないですね。

まず、新しいシーンを作ってください。シーンにCanvasオブジェクトを作成し、それにGraph_Maker/Prefabs/GraphsのScatterPlotプレハブを子オブジェクトとして配置します。ここで、とりあえずPlayします。以下のようなグラフがゲーム画面に出るはずです。

ここのデータは、プレハブにもともとあるデータなので、これを置き換えてF#からの結果を表示できるようにします。

そのために、まず、Canvasの下に置いたScatterPlotオブジェクトの下層にあるSeries1, Series2を修正します。ここでは、Series1のほうでは元データを点として表示し、Series2のほうではフィットさせた直線を計算したパラメータから描画するようにしましょう。

Series1,2いずれもWMG_Seriesというコンポーネントがあり、その中のPoint Valuesというのがプロットする値をListの形で持ってます。ただ、これを直接触るとうまくいかないので、以下の方法が必要です。

Series1オブジェクトに、WMG_Data_Sourceコンポーネントをアタッチします。そして、WMG_SeriesのMiscタブを押し、その中のPoint Value Data SourceにWMG_Data_Sourceをドラッグ&ドロップします。そして、WMG_Data_Sourceの設定を下図のようにしてください。

重要なのはVariable Namesです。ここで指定した変数名の変数に、プロットしたいデータを格納する必要があります。ここでは、ObservedDataというフィールドにSeries1のためのデータ(Vector2型)を格納することを宣言しています。

Series2オブジェクトについても、同様に、以下の通り設定します。

こちらでは、FittedDataというフィールドにSeries2のためのデータ(Vector2型)を格納することにします。

次に、F#側の結果を保存し、WMG_Data_Sourceへ値を渡すためのスクリプトとして、PlotResultScript.csを作ります。スクリプトは(どこでもよいのですが)ここではScatterPlotにアタッチしてください。

PlotResultScript.cs
using UnityEngine;
using System;
using System.Linq;
using System.Collections.Generic;
using LinearRegressionLibrary;

public class PlotResultScript : MonoBehaviour {

    public WMG_Series _WMG_Series;
    public WMG_Data_Source _WMG_Data_Source;

    public WMG_Series _WMG_Series2; 
    public WMG_Data_Source _WMG_Data_Source2;

    public List<Vector2> ObservedData;
    public List<Vector2> FittedData;

    void Start () {

        ObservedData = LinearRegress.xArray.Select((t,index) => new Vector2(Convert.ToSingle(t),Convert.ToSingle(LinearRegress.yArray[index]))).ToList();

        _WMG_Data_Source.setDataProvider(this);
        _WMG_Series.UpdateFromDataSource();

        FittedData = new List<Vector2> { new Vector2(0, Convert.ToSingle(LinearRegress.B0.Item1)), new Vector2(100, Convert.ToSingle(LinearRegress.B0.Item1)+100* Convert.ToSingle(LinearRegress.B0.Item2)) };

        _WMG_Data_Source2.setDataProvider(this);
        _WMG_Series2.UpdateFromDataSource();
    }
}

WMG_Series, WMG_Data_Sourceについてはインスペクタ上でSeries1オブジェクトを, WMG_Series2, WMG_Data_Source2についてはSeries2オブジェクトをドラッグ&ドロップしてください。

_WMG_Data_Source.setDataProvider(this) で、データを保持しているPlotResultScriptコンポーネント(つまりthis)をDataProviderとしてWMG_Data_Sourceに渡し、WMG_Series が新しいデータを使ってグラフをアップデートするようUpdateFromDataSource()をコールします。

フィットさせるデータ(Series1)については点でプロットさせます。フィットさせた数式の直線(Series2)のほうは、は、数式をそのままグラフにできないためx=0, X=100の時の数式の値を得られたパラメータをもとに計算して、二点で直線を引きます。

このままだとプロットする点がはみ出してしまうので、さらに、以下の通り、ScatterPlotオブジェクトのWMG_Axis_Graphで、Axis Max Valueを200に増やしておきます。

これで実行すると、点と直線がプロットされます

ちなみに、グラフに点を出すか、線を出すか、はWMG_Seriesの"Hide Points", "Hide Lines"で変更できます。

最後に、ScatterPlotをプレハブ化しておくとよいでしょう。