C#によるKNNアルゴリズムの実装


本記事のKNNアルゴリズムのC#コードは、前回のブログ記事「C#でKDツリーを作成するプログラム」のアルゴリズムをベースにしており、MATLABのKDTreeを模した構成になっています。
今回は李航先生の「統計学習方法」の考え方に沿って実装していますが、KDツリーを分割する次元は順番に（ラウンドロビンで）選ぶのではなく、データの範囲（最大値と最小値の差）が最も大きい次元を選んでいます。
using System;
using System.Collections.Generic;
using System.Linq;


namespace KNNSearch
{
    /// <summary>
    /// Nearest-neighbour demo over a 3-D KD-tree.
    /// Constructing an instance generates 180 pseudo-random points (fixed seed),
    /// builds a KD-tree whose split axis at each node is the axis with the largest
    /// value range, looks up the point nearest to (0.5, 0.5, 0.5), and prints the
    /// distance found by the tree search next to the distance found by a brute-force scan.
    /// </summary>
    public class Knn
    {
        /// <summary>
        /// Number of coordinates per point (X, Y, Z).
        /// </summary>
        private const int Dimensions = 3;

        /// <summary>
        /// A subset with at most this many points becomes a leaf node.
        /// </summary>
        private readonly int leafnum = 1;

        /// <summary>
        /// Candidate node names "A" through "Z".
        /// </summary>
        private List<string> _nodeNames = Enumerable.Range('A', 26)
            .Select(c => ((char)c).ToString())
            .ToList();

        /// <summary>
        /// Coordinate selectors indexed by split axis: 0 = X, 1 = Y, 2 = Z.
        /// </summary>
        /// <remarks>
        /// Held in a static readonly field so the dictionary is built once instead of
        /// on every access (it is read once per element while sorting).
        /// </remarks>
        private static readonly Dictionary<int, Func<Point, double>> Dict = new Dictionary<int, Func<Point, double>>
        {
            { 0, p => p.X },
            { 1, p => p.Y },
            { 2, p => p.Z },
        };

        /// <summary>
        /// Gets or sets the list of candidate node names.
        /// </summary>
        public List<string> NodeNames { get => _nodeNames; set => _nodeNames = value; }

        /// <summary>
        /// Builds the tree, runs one nearest-neighbour query and prints the KD-tree
        /// result alongside a brute-force check of the same query.
        /// </summary>
        public Knn()
        {
            List<Point> rawData = GeneralRawData(180);
            Node tree = CreateKdTree(rawData);
            var target = new Point { X = 0.5, Y = 0.5, Z = 0.5 };

            Point nearest = KdTreeFindNearest(tree, target);
            double nearestDistFromKnn = Distance(nearest, target);
            Console.WriteLine("KD-tree search nearest distance: {0:F3}", nearestDistFromKnn);

            double nearestDistFromLoop = NearestDist(rawData, target);
            Console.WriteLine("Brute-force scan nearest distance: {0:F3}", nearestDistFromLoop);
        }

        /// <summary>
        /// Generates <paramref name="num"/> points with coordinates drawn from
        /// <see cref="Random.NextDouble"/>. The fixed seed makes every run produce
        /// the same data set.
        /// </summary>
        /// <param name="num">Number of points to generate.</param>
        /// <returns>The generated points; each point's ID equals its index.</returns>
        private static List<Point> GeneralRawData(int num)
        {
            var random = new Random(1);
            var rawData = new List<Point>();
            for (var i = 0; i < num; i++)
            {
                rawData.Add(new Point { X = random.NextDouble(), Y = random.NextDouble(), Z = random.NextDouble(), ID = i });
            }
            return rawData;
        }

        /// <summary>
        /// Recursively builds a KD-tree over <paramref name="data"/>.
        /// </summary>
        /// <param name="data">Points to index.</param>
        /// <returns>
        /// The subtree root, or null when <paramref name="data"/> is null or empty.
        /// A subset of at most <see cref="leafnum"/> points becomes a leaf with
        /// Splitaxis = -1. Every node's NodeData holds all points of its subtree.
        /// </returns>
        private Node CreateKdTree(List<Point> data)
        {
            if (data == null || data.Count == 0)
            {
                return null;
            }

            if (data.Count <= leafnum)
            {
                // Leaf: no split plane; Point is simply the first stored point.
                return new Node
                {
                    Name = "AA",
                    NodeData = data,
                    Point = data[0],
                    Splitaxis = -1,
                    LeftNode = null,
                    RightNode = null,
                };
            }

            int splitAxis = GetSplitAxis(data);
            Tuple<Point, List<Point>, List<Point>> split = GetSplitNum(data, splitAxis);
            return new Node
            {
                // Placeholder name: the 26 entries of _nodeNames are too few to name every node.
                Name = "AA",
                NodeData = data,
                Point = split.Item1,
                Splitaxis = splitAxis,
                LeftNode = CreateKdTree(split.Item2),
                RightNode = CreateKdTree(split.Item3),
            };
        }

        /// <summary>
        /// Sorts <paramref name="data"/> along <paramref name="splitAxis"/> and splits it at the median.
        /// </summary>
        /// <param name="data">Non-empty list of points to split.</param>
        /// <param name="splitAxis">Axis index: 0 = X, 1 = Y, 2 = Z.</param>
        /// <returns>
        /// (median point, points sorted before it, points sorted after it). Left points have
        /// coordinates &lt;= the median's on <paramref name="splitAxis"/>; right points have
        /// coordinates &gt;= it. The median itself is in neither list.
        /// </returns>
        private static Tuple<Point, List<Point>, List<Point>> GetSplitNum(List<Point> data, int splitAxis)
        {
            // OrderBy is a stable sort, so points with equal coordinates keep their input order.
            List<Point> sorted = data.OrderBy(Dict[splitAxis]).ToList();
            int half = sorted.Count / 2;
            List<Point> left = sorted.GetRange(0, half);
            List<Point> right = sorted.GetRange(half + 1, sorted.Count - half - 1);
            return Tuple.Create(sorted[half], left, right);
        }

        /// <summary>
        /// Chooses the axis along which <paramref name="data"/> has the largest range
        /// (maximum minus minimum). On a tie, the lowest axis index wins.
        /// </summary>
        /// <param name="data">Non-empty list of points.</param>
        /// <returns>Axis index: 0 = X, 1 = Y, 2 = Z.</returns>
        private static int GetSplitAxis(List<Point> data)
        {
            int bestAxis = 0;
            double bestRange = double.NegativeInfinity;
            for (int axis = 0; axis < Dimensions; axis++)
            {
                Func<Point, double> coordinate = Dict[axis];
                double range = data.Max(coordinate) - data.Min(coordinate);
                // Strict '>' keeps the first axis on ties.
                if (range > bestRange)
                {
                    bestRange = range;
                    bestAxis = axis;
                }
            }
            return bestAxis;
        }

        /// <summary>
        /// Finds the point in <paramref name="tree"/> nearest to <paramref name="target"/>
        /// (Euclidean distance). The search descends towards the target first, then
        /// backtracks. It visits the opposite subtree of a node only when the splitting
        /// plane is closer than the best distance found so far.
        /// </summary>
        /// <param name="tree">Root built by <see cref="CreateKdTree"/>; may be null.</param>
        /// <param name="target">Query point.</param>
        /// <returns>The nearest point, or null when the tree is empty.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="target"/> is null.</exception>
        private Point KdTreeFindNearest(Node tree, Point target)
        {
            if (target == null)
            {
                throw new ArgumentNullException(nameof(target));
            }

            Point best = null;
            double bestDist = double.PositiveInfinity;

            void Consider(Point candidate)
            {
                double d = Distance(candidate, target);
                if (d < bestDist)
                {
                    bestDist = d;
                    best = candidate;
                }
            }

            void Search(Node node)
            {
                if (node == null)
                {
                    return;
                }

                if (node.Splitaxis < 0)
                {
                    // Leaf: it has no split plane, so check every point it stores.
                    foreach (Point p in node.NodeData)
                    {
                        Consider(p);
                    }
                    return;
                }

                Func<Point, double> coordinate = Dict[node.Splitaxis];
                double offset = coordinate(target) - coordinate(node.Point);
                // Same tie rule as the build: coordinates equal to the split value may lie on the left.
                Node nearSide = offset <= 0 ? node.LeftNode : node.RightNode;
                Node farSide = offset <= 0 ? node.RightNode : node.LeftNode;

                Search(nearSide);
                Consider(node.Point);

                // Every point on the far side is at least |offset| away along this axis,
                // so that side can only help if the plane is closer than the current best.
                if (Math.Abs(offset) < bestDist)
                {
                    Search(farSide);
                }
            }

            Search(tree);
            return best;
        }

        /// <summary>
        /// Brute-force minimum Euclidean distance from <paramref name="target"/> to any point
        /// in <paramref name="nodeData"/>.
        /// </summary>
        /// <returns>The minimum distance, or positive infinity when the list is null or empty.</returns>
        private static double NearestDist(List<Point> nodeData, Point target)
        {
            if (nodeData == null || nodeData.Count == 0)
            {
                return double.PositiveInfinity;
            }
            return nodeData.Min(p => Distance(p, target));
        }

        /// <summary>
        /// Euclidean distance between two points in 3-D space.
        /// </summary>
        private static double Distance(Point a, Point b)
        {
            double dx = a.X - b.X;
            double dy = a.Y - b.Y;
            double dz = a.Z - b.Z;
            return Math.Sqrt(dx * dx + dy * dy + dz * dz);
        }

        /// <summary>
        /// Debug helper: prints a separator line followed by each point on its own line.
        /// </summary>
        private static void PrintListData(List<Point> data)
        {
            Console.WriteLine("****************");
            foreach (Point point in data)
            {
                Console.WriteLine(point);
            }
        }
    }

    /// <summary>
    /// A node of a 3-D KD-tree.
    /// </summary>
    public class Node
    {
        /// <summary>
        /// Display name of the node.
        /// </summary>
        public string Name;

        /// <summary>
        /// For an internal node, the median point whose coordinate on
        /// <see cref="Splitaxis"/> defines the splitting plane; for a leaf,
        /// the first point stored in <see cref="NodeData"/>.
        /// </summary>
        public Point Point;

        /// <summary>
        /// Subtree of points whose coordinate on <see cref="Splitaxis"/> is
        /// &lt;= the split value; null when that side is empty.
        /// </summary>
        public Node LeftNode;

        /// <summary>
        /// Subtree of points whose coordinate on <see cref="Splitaxis"/> is
        /// &gt;= the split value; null when that side is empty.
        /// </summary>
        public Node RightNode;

        /// <summary>
        /// All points contained in the subtree rooted at this node
        /// (including <see cref="Point"/>).
        /// </summary>
        public List<Point> NodeData;

        /// <summary>
        /// Split axis: 0 = X, 1 = Y, 2 = Z; -1 marks a leaf node.
        /// </summary>
        public int Splitaxis;
    }

    /// <summary>
    /// A mutable point in 3-D space, tagged with an integer identifier.
    /// </summary>
    public class Point
    {
        /// <summary>X coordinate.</summary>
        public double X;

        /// <summary>Y coordinate.</summary>
        public double Y;

        /// <summary>Z coordinate.</summary>
        public double Z;

        /// <summary>Identifier used as a debugging aid.</summary>
        public int ID;

        /// <summary>
        /// Formats the point as "(X,Y,Z,ID)".
        /// </summary>
        public override string ToString() => $"({X},{Y},{Z},{ID})";
    }
}