Implementing the KNN Algorithm in C#
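This C# implementation of KNN builds on the algorithm from my previous post, a C# program that creates a KD tree and was designed to imitate MATLAB's KDTree.
This time, following the approach in Li Hang's "Statistical Learning Methods", the dimension used to split each KD-tree node is not chosen round-robin but from the range of the data: the axis along which the points are most spread out (largest max - min) is used.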
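To make the difference between the two rules concrete, here is a small, self-contained sketch of both ways of picking the splitting dimension. It is an illustration only, under the assumption that points are plain double[] arrays; the class SplitAxisSketch and its methods RoundRobinAxis and WidestRangeAxis are hypothetical names, not part of the program below.

```csharp
using System.Linq;

// Hypothetical sketch: two rules for choosing a KD-tree splitting dimension.
// Assumption: points are double[] arrays of equal length (the post uses a Point class).
public static class SplitAxisSketch
{
    // Round-robin: cycle through the dimensions by tree depth (axis = depth mod k).
    public static int RoundRobinAxis(int depth, int k) => depth % k;

    // Range-based: choose the dimension whose values spread the widest (max - min).
    // On ties the lowest axis index wins.
    public static int WidestRangeAxis(double[][] points)
    {
        int k = points[0].Length;
        int best = 0;
        double bestRange = double.NegativeInfinity;
        for (int axis = 0; axis < k; axis++)
        {
            double min = points.Min(p => p[axis]);
            double max = points.Max(p => p[axis]);
            if (max - min > bestRange)
            {
                bestRange = max - min;
                best = axis;
            }
        }
        return best;
    }
}
```

For the points (0, 0, 0), (0.1, 5, 1) and (0.2, -1, 2) the ranges are 0.2, 6 and 2, so WidestRangeAxis picks axis 1 (Y), whereas round-robin would use depth % 3 regardless of the data.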
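using System;
using System.Collections.Generic;
using System.Linq;

namespace KNNSearch
{
    /// <summary>
    /// Description of KNN.
    /// </summary>
    public class Knn
    {
        /// <summary>
        /// Maximum number of points stored in a leaf node
        /// </summary>
        private int leafnum = 1;

        /// <summary>
        /// Candidate node names (the RemoveAt calls that consume them are commented out, so every node is named "AA")
        /// </summary>
        private List<string> _nodeNames = new List<string>
        {
            "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M",
            "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"
        };

        /// <summary>
        /// Generate num random points in the unit cube (fixed seed, so every run uses the same data)
        /// </summary>
        private List<Point> GeneralRawData(int num)
        {
            List<Point> rawData = new List<Point>();
            Random r = new Random(1);
            for (var i = 0; i < num; i++)
            {
                rawData.Add(new Point() { X = r.NextDouble(), Y = r.NextDouble(), Z = r.NextDouble(), ID = i });
            }
            //PrintListData(rawData);
            return rawData;
        }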
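        /// <summary>
        /// Recursively build a KD tree from data
        /// </summary>
        /// <param name="data">Points belonging to this subtree</param>
        /// <returns>Root of the subtree, or null if data is empty</returns>
        private Node CreateKdTree(List<Point> data)
        {
            // Every node keeps all the points of its subtree
            Node root = new Node { NodeData = data };
            // Leaf: at most leafnum points, so no further split
            if (data.Count <= leafnum)
            {
                if (data.Count == 0)
                {
                    return null;
                }
                root.LeftNode = null;
                root.RightNode = null;
                root.Point = data[0];
                root.Splitaxis = -1;
                root.Name = "AA";
                //_nodeNames.RemoveAt(0);
                //Console.WriteLine(" {0}, {1}",root.Name, root.NodeData[0].ID);
                return root;
            }
            // Choose the splitting axis (the one with the largest range)
            int splitAxis = GetSplitAxis(data);
            // Split at the median along that axis
            Tuple<Point, List<Point>, List<Point>> dataSplit = GetSplitNum(data, splitAxis);
            root.Splitaxis = splitAxis;
            root.Point = dataSplit.Item1;
            root.Name = "AA";
            //_nodeNames.RemoveAt(0);
            root.LeftNode = CreateKdTree(dataSplit.Item2);
            root.RightNode = CreateKdTree(dataSplit.Item3);
            return root;
        }

        /// <summary>
        /// Sort data along splitAxis and split it at the median.
        /// Returns (median point, points before it, points after it).
        /// </summary>
        private Tuple<Point, List<Point>, List<Point>> GetSplitNum(List<Point> data, int splitAxis)
        {
            // Sort by the coordinate on splitAxis
            var data0 = data.OrderBy(x => Dict[splitAxis](x)).ToList();
            int half = data0.Count / 2;
            List<Point> leftdata = new List<Point>();
            List<Point> rightdata = new List<Point>();
            for (int i = 0; i < data0.Count; i++)
            {
                if (i < half)
                {
                    leftdata.Add(data0[i]);
                }
                else if (i > half)
                {
                    rightdata.Add(data0[i]);
                }
            }
            //Console.WriteLine("Split Axis: {0}", splitAxis);
            //PrintListData(data0);
            return new Tuple<Point, List<Point>, List<Point>>(data0[half], leftdata, rightdata);
        }

        /// <summary>
        /// Return the axis (0 = X, 1 = Y, 2 = Z) along which data has the largest range (max - min)
        /// </summary>
        /// <param name="data">Points to examine</param>
        /// <returns>Index of the splitting axis</returns>
        private int GetSplitAxis(List<Point> data)
        {
            // Range of the data on each of the three axes
            List<double> ranges = new List<double>();
            for (int i = 0; i < 3; i++)
            {
                int axis = i;
                double[] values = data.Select(item => Dict[axis](item)).ToArray();
                ranges.Add(values.Max() - values.Min());
            }
            // Index of the widest axis (the first one on ties)
            return ranges.IndexOf(ranges.Max());
        }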
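        /// <summary>
        /// Nearest-neighbour search on the KD tree, following the backtracking
        /// procedure in Li Hang's "Statistical Learning Methods"
        /// </summary>
        /// <param name="tree">Root of the KD tree</param>
        /// <param name="target">Query point</param>
        /// <returns>The node whose Point is closest to target (null for an empty tree)</returns>
        private Node KdTreeFindNearest(Node tree, Point target)
        {
            // Stack of nodes still to be checked while backtracking
            List<Node> searchPath = new List<Node>();
            // (1) Starting at the root, go left or right by comparing the target with
            //     each node's split coordinate until the path ends at a leaf
            DescendToLeaf(tree, target, searchPath);
            // (2) The leaf reached in (1) is popped first and becomes the current nearest node
            Node nearestPoint = null;
            double dist = double.MaxValue;
            // (3) Back up towards the root
            while (searchPath.Count > 0)
            {
                Node backNode = searchPath[searchPath.Count - 1];
                searchPath.RemoveAt(searchPath.Count - 1);
                // (a) If this node's point is closer than the current nearest, it becomes the new nearest
                double d = Distance(backNode.Point, target);
                if (d < dist)
                {
                    nearestPoint = backNode;
                    dist = d;
                }
                // (b) A leaf has no other subtree to check
                int splitaxis = backNode.Splitaxis;
                if (splitaxis < 0)
                {
                    continue;
                }
                // Distance from the target to this node's splitting plane
                double distTargetToBound = Math.Abs(Dict[splitaxis](target) - Dict[splitaxis](backNode.Point));
                // If the sphere around the target with radius dist crosses the plane,
                // the other subtree may contain a closer point, so search it as well
                if (distTargetToBound < dist)
                {
                    Node otherNode = Dict[splitaxis](target) <= Dict[splitaxis](backNode.Point) ? backNode.RightNode : backNode.LeftNode;
                    DescendToLeaf(otherNode, target, searchPath);
                }
            }
            return nearestPoint;
        }

        /// <summary>
        /// Walk from node down towards a leaf, pushing every node visited onto path
        /// </summary>
        private void DescendToLeaf(Node node, Point target, List<Node> path)
        {
            while (node != null)
            {
                path.Add(node);
                int axis = node.Splitaxis;
                node = axis < 0 ? null
                     : Dict[axis](target) <= Dict[axis](node.Point) ? node.LeftNode : node.RightNode;
            }
        }

        /// <summary>
        /// Euclidean distance between two points
        /// </summary>
        private static double Distance(Point a, Point b)
        {
            return Math.Sqrt(Math.Pow(a.X - b.X, 2) + Math.Pow(a.Y - b.Y, 2) + Math.Pow(a.Z - b.Z, 2));
        }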
        /// <summary>
        /// Coordinate accessors: 0 -> X, 1 -> Y, 2 -> Z
        /// </summary>
        private static readonly Dictionary<int, Func<Point, double>> Dict = new Dictionary<int, Func<Point, double>>
        {
            { 0, p => p.X },
            { 1, p => p.Y },
            { 2, p => p.Z },
        };

        public List<string> NodeNames { get => _nodeNames; set => _nodeNames = value; }

        /// <summary>
        /// Smallest Euclidean distance from target to any point in nodeData (brute force)
        /// </summary>
        /// <param name="nodeData">Points to search</param>
        /// <param name="target">Query point</param>
        /// <returns>Minimum distance</returns>
        private double NearestDist(List<Point> nodeData, Point target)
        {
            return nodeData.Select(item => Math.Sqrt(Math.Pow(item.X - target.X, 2) +
                                                     Math.Pow(item.Y - target.Y, 2) +
                                                     Math.Pow(item.Z - target.Z, 2))).Min();
        }

        private void PrintListData(List<Point> data)
        {
            Console.WriteLine("****************");
            foreach (Point point in data)
            {
                Console.WriteLine(point);
            }
        }

        public Knn()
        {
            List<Point> rawData = GeneralRawData(180);
            Node node = CreateKdTree(rawData);
            Point target = new Point() { X = 0.5, Y = 0.5, Z = 0.5 };
            Node nd = KdTreeFindNearest(node, target);
            // Compare the KD-tree result with a brute-force scan over all points
            double nearestDistFromKnn = NearestDist(new List<Point> { nd.Point }, target);
            Console.WriteLine("Nearest distance (KD-tree search): {0:F3}", nearestDistFromKnn);
            double nearestDistFromLoop = NearestDist(rawData, target);
            Console.WriteLine("Nearest distance (brute force):    {0:F3}", nearestDistFromLoop);
        }
    }
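    /// <summary>
    /// Description of Node.
    /// </summary>
    public class Node
    {
        /// <summary>
        /// Node name (for debugging)
        /// </summary>
        public string Name;
        /// <summary>
        /// Point stored at this node: the median for an internal node, the single point for a leaf
        /// </summary>
        public Point Point;
        /// <summary>
        /// Left subtree (coordinates on Splitaxis not greater than Point's)
        /// </summary>
        public Node LeftNode;
        /// <summary>
        /// Right subtree (coordinates on Splitaxis not less than Point's)
        /// </summary>
        public Node RightNode;
        /// <summary>
        /// All points in this subtree
        /// </summary>
        public List<Point> NodeData;
        /// <summary>
        /// Splitting axis (0 = X, 1 = Y, 2 = Z); -1 for a leaf
        /// </summary>
        public int Splitaxis;
    }

    public class Point
    {
        public double X;
        public double Y;
        public double Z;
        public int ID; // debug
        public override string ToString()
        {
            return $"({X},{Y},{Z},{ID})";
        }
    }
}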
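As a cross-check of the backtracking step, here is a compact recursive sketch of the same kind of search: build the tree by splitting at the median on the widest axis, check each node's point, search the side containing the target first, and visit the other side only when the splitting plane is closer than the best distance found so far. It assumes points are double[3] arrays; KdSketch, KdNode, Build, Nearest, Query and Dist are hypothetical names, and recursion stands in for the explicit search path used in the listing above.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical recursive sketch of KD-tree 1-nearest-neighbour search.
// Assumption: every point is a double[] of length 3.
public static class KdSketch
{
    public sealed class KdNode
    {
        public double[] Point;    // median point (internal node) or the single point (leaf)
        public int Axis;          // splitting axis, -1 for a leaf
        public KdNode Left, Right;
    }

    // Split on the axis with the largest range, at the median along that axis.
    public static KdNode Build(List<double[]> pts)
    {
        if (pts.Count == 0) return null;
        if (pts.Count == 1) return new KdNode { Point = pts[0], Axis = -1 };
        int axis = Enumerable.Range(0, 3)
            .OrderByDescending(a => pts.Max(p => p[a]) - pts.Min(p => p[a]))
            .First();
        var sorted = pts.OrderBy(p => p[axis]).ToList();
        int half = sorted.Count / 2;
        return new KdNode
        {
            Point = sorted[half],
            Axis = axis,
            Left = Build(sorted.GetRange(0, half)),
            Right = Build(sorted.GetRange(half + 1, sorted.Count - half - 1))
        };
    }

    // Euclidean distance.
    public static double Dist(double[] a, double[] b) =>
        Math.Sqrt(a.Zip(b, (x, y) => (x - y) * (x - y)).Sum());

    // Check this node's point, search the side containing the target, then the
    // far side only if the splitting plane is closer than the best distance so far.
    public static void Nearest(KdNode node, double[] target, ref double[] best, ref double bestDist)
    {
        if (node == null) return;
        double d = Dist(node.Point, target);
        if (d < bestDist) { best = node.Point; bestDist = d; }
        if (node.Axis < 0) return;
        double diff = target[node.Axis] - node.Point[node.Axis];
        KdNode near = diff <= 0 ? node.Left : node.Right;
        KdNode far = diff <= 0 ? node.Right : node.Left;
        Nearest(near, target, ref best, ref bestDist);
        if (Math.Abs(diff) < bestDist) Nearest(far, target, ref best, ref bestDist);
    }

    // Returns the stored point closest to target (null for an empty tree).
    public static double[] Query(KdNode root, double[] target)
    {
        double[] best = null;
        double bestDist = double.PositiveInfinity;
        Nearest(root, target, ref best, ref bestDist);
        return best;
    }
}
```

Comparing Query against a brute-force minimum over all points, as the Knn constructor does for a single target, is an easy way to test such a search on many random queries.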