KelpNetの作法 その2


概要

KelpNetの作法を調べてみた。
題材としてFizzBuzzをやってみた。0〜99の整数を7ビットの2進数として入力し、「数値」「fizz」「buzz」「fizzbuzz」の4クラスに分類するネットワークを学習させる。

サンプルコード

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using KelpNet;

namespace ConsoleApp3
{
    /// <summary>
    /// Trains a small KelpNet multilayer perceptron to play FizzBuzz on the integers
    /// 0..N-1 (each given to the network as a fixed-width binary vector), then prints
    /// the network's prediction for every integer in that same range.
    /// </summary>
    class Program
    {
        const int EPOCH = 1000;
        const int N = 100;

        // Width of the hidden layer.
        const int HiddenUnits = 40;

        // Output classes: 0 = plain number, 1 = fizz, 2 = buzz, 3 = fizzbuzz.
        const int ClassCount = 4;

        // The mean training loss is printed once every this many epochs.
        const int LogInterval = 100;

        static void Main(string[] args)
        {
            // Enough bits to write N - 1 in binary (7 for N = 100). Deriving this from N
            // keeps inputs distinct if N is raised past 128, instead of silently aliasing.
            int inputBits = BitsFor(N - 1);

            Real[][] trainData = new Real[N][];
            Real[][] trainLabel = new Real[N][];
            for (int i = 0; i < N; i++)
            {
                trainData[i] = EncodeBinary(i, inputBits);
                // Labels are single class indices, which SoftmaxCrossEntropy is given below.
                trainLabel[i] = new Real[] { FizzBuzzClass(i) };
            }

            FunctionStack nn = new FunctionStack(
                new Linear(inputBits, HiddenUnits, name: "l1 Linear"),
                new Sigmoid(name: "l1 Sigmoid"),
                new Linear(HiddenUnits, ClassCount, name: "l2 Linear"));
            nn.SetOptimizer(new MomentumSGD());

            Console.WriteLine("Train Start...");
            for (int i = 0; i < EPOCH; i++)
            {
                // One pass over all N samples, one Trainer.Train call per sample.
                Real loss = 0;
                for (int j = 0; j < N; j++)
                {
                    loss += Trainer.Train(nn, trainData[j], trainLabel[j], new SoftmaxCrossEntropy());
                }
                if (i % LogInterval == 0)
                {
                    Console.WriteLine("loss:" + loss / N);
                }
            }

            // NOTE: the network is evaluated on its own training inputs, so this shows how
            // well it fitted the table, not how it generalizes to unseen numbers.
            Console.WriteLine("Test Start...");
            for (int j = 0; j < N; j++)
            {
                NdArray result = nn.Predict(trainData[j])[0];
                // The predicted class is the index of the largest output.
                int resultIndex = Array.IndexOf(result.Data, result.Data.Max());
                Console.Write(ClassToText(resultIndex, j));
                Console.Write(" ");
            }
            // End the prediction line so the prompt below starts on its own line.
            Console.WriteLine();

            Console.WriteLine("Press any key to exit.");
            Console.ReadKey();
        }

        /// <summary>
        /// Returns the number of binary digits needed to write <paramref name="value"/>
        /// (at least 1, so that 0 still gets one input).
        /// </summary>
        /// <param name="value">A non-negative integer.</param>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="value"/> is negative.</exception>
        static int BitsFor(int value)
        {
            if (value < 0)
            {
                // An arithmetic right shift of a negative number never reaches 0.
                throw new ArgumentOutOfRangeException(nameof(value));
            }
            int bits = 1;
            while ((value >> bits) != 0)
            {
                bits++;
            }
            return bits;
        }

        /// <summary>
        /// Encodes <paramref name="value"/> as <paramref name="bits"/> inputs of 0 or 1,
        /// most significant bit first.
        /// </summary>
        /// <param name="value">The integer to encode.</param>
        /// <param name="bits">How many low-order bits of <paramref name="value"/> to emit.</param>
        /// <returns>An array of length <paramref name="bits"/>.</returns>
        static Real[] EncodeBinary(int value, int bits)
        {
            Real[] encoded = new Real[bits];
            for (int k = 0; k < bits; k++)
            {
                encoded[k] = (Real)((value >> (bits - 1 - k)) & 1);
            }
            return encoded;
        }

        /// <summary>
        /// Returns the FizzBuzz class of <paramref name="value"/>: 3 if divisible by 15,
        /// 2 if divisible by 5, 1 if divisible by 3, otherwise 0. Note that 0 is class 3.
        /// </summary>
        static int FizzBuzzClass(int value)
        {
            if (value % 15 == 0)
            {
                return 3;
            }
            if (value % 5 == 0)
            {
                return 2;
            }
            if (value % 3 == 0)
            {
                return 1;
            }
            return 0;
        }

        /// <summary>
        /// Converts a predicted class index into its FizzBuzz word. Any index other than
        /// 1, 2 or 3 falls back to the number <paramref name="value"/> itself.
        /// </summary>
        static string ClassToText(int classIndex, int value)
        {
            switch (classIndex)
            {
                case 3:
                    return "fizzbuzz";
                case 2:
                    return "buzz";
                case 1:
                    return "fizz";
                default:
                    return value.ToString();
            }
        }
    }
}



結果

以上。