// K近隣アルゴリズム (k-nearest neighbours algorithm)
//
// 13835 ワード (13835 words)

/**
 * Command-line driver for {@code KNN}: loads a training file and a test file, classifies every
 * test row, and prints the row's values followed by the predicted label rounded to a whole number.
 */
public class KnnTest
{
    /**
     * Reads a whitespace-separated numeric file and appends one {@code List<Double>} per non-blank
     * line to {@code list}, in file order.
     *
     * <p>An {@link IOException} is reported on standard error and reading stops; rows parsed before
     * the failure remain in {@code list}. A token that is not a number raises
     * {@link NumberFormatException}. The reader is closed in every case.
     *
     * @param path file to read
     * @param list destination the parsed rows are appended to
     */
    public static void readFileToList(String path, List<List<Double>> list)
    {
        BufferedReader br = null;
        try {
            br = new BufferedReader(new FileReader(path));
            String line;
            // readLine() returning null is the end-of-file signal; ready() only reports whether the
            // next read is guaranteed not to block, which is not the same thing.
            while ((line = br.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue;
                }
                // Split on any run of whitespace so repeated spaces or tabs do not produce empty
                // tokens, which Double.parseDouble would reject.
                String[] tokens = trimmed.split("\\s+");
                List<Double> box = new ArrayList<Double>(tokens.length);
                for (String num : tokens) {
                    box.add(Double.parseDouble(num));
                }
                list.add(box);
            }
        }
        catch (IOException ex) {
            ex.printStackTrace();
        }
        finally {
            closeQuietly(br);
        }
    }

    /** Closes {@code reader} if it is non-null, reporting (not propagating) a failure to close. */
    private static void closeQuietly(BufferedReader reader)
    {
        if (reader == null) {
            return;
        }
        try {
            reader.close();
        }
        catch (IOException ex) {
            ex.printStackTrace();
        }
    }

    /**
     * Classifies each row of the test file against the training file and prints one line per row:
     * the row's values, then the predicted label rounded to the nearest integer.
     *
     * <p>Optional arguments override the defaults: {@code args[0]} is the training file (default
     * {@code data.txt}), {@code args[1]} the test file (default {@code test.txt}), and
     * {@code args[2]} the number of neighbours k (default 2). Any failure is reported on standard
     * error.
     *
     * @param args optional overrides as described above
     */
    public static void main(String[] args)
    {
        String dataFile = args.length > 0 ? args[0] : "data.txt";
        String testFile = args.length > 1 ? args[1] : "test.txt";

        KNN knn = new KNN();

        try {
            int k = args.length > 2 ? Integer.parseInt(args[2]) : 2;

            List<List<Double>> dataList = new ArrayList<List<Double>>();
            List<List<Double>> testList = new ArrayList<List<Double>>();

            readFileToList(dataFile, dataList);
            readFileToList(testFile, testList);

            for (List<Double> test : testList) {
                StringBuilder row = new StringBuilder();
                for (Double d : test) {
                    row.append(d).append(' ');
                }

                // Labels come back as Double.toString text (e.g. "1.0"); print them as integers.
                String category = knn.knn(dataList, test, k);
                row.append(Math.round(Double.parseDouble(category)));
                System.out.println(row);
            }
        }
        catch (Exception ex) {
            // Top-level boundary of a command-line tool: report and exit.
            ex.printStackTrace();
        }
    }
}





/**
 * A minimal k-nearest-neighbours classifier over numeric vectors.
 *
 * <p>Each training row is a {@code List<Double>} whose last element is the row's label. The
 * elements before it are compared against the test vector.
 */
class KNN
{
    /**
     * Orders nodes by descending distance. A {@link PriorityQueue} built with it therefore keeps the
     * farthest of the current candidates at its head, so that candidate can be evicted cheaply.
     * {@link Double#compare} keeps the ordering consistent with the {@link Comparator} contract.
     */
    private static final Comparator<Node> comparator = new Comparator<Node>()
    {
        @Override
        public int compare(Node n1, Node n2)
        {
            return Double.compare(n2.getDistans(), n1.getDistans());
        }
    };

    /**
     * Predicts the label of {@code test} by majority vote among its {@code k} nearest training rows,
     * using the squared Euclidean distance computed by {@link #calDistans}.
     *
     * <p>Only the first {@code test.size()} coordinates of each row are compared. If {@code k}
     * exceeds the number of rows, every row votes. A tied vote goes to the label of the closest tied
     * neighbour. When two rows are equally distant from {@code test}, the earlier row is kept.
     *
     * @param example training rows, each ending in its label
     * @param test    vector to classify
     * @param k       number of neighbours that vote; must be positive
     * @return the winning label, formatted by {@link Double#toString(double)} (e.g. {@code "1.0"})
     * @throws IllegalArgumentException if {@code k <= 0} or {@code example} is null or empty
     */
    public String knn(List<List<Double>> example, List<Double> test, int k)
    {
        if (k <= 0) {
            throw new IllegalArgumentException("k must be positive, got " + k);
        }
        if (example == null || example.isEmpty()) {
            throw new IllegalArgumentException("example must contain at least one training row");
        }

        // Bounded max-heap of the k best candidates seen so far. Each row is examined exactly once,
        // so no row can occupy two slots.
        int capacity = Math.min(k, example.size());
        PriorityQueue<Node> pq = new PriorityQueue<Node>(capacity, comparator);

        for (int i = 0; i < example.size(); i++) {
            List<Double> row = example.get(i);
            double distans = calDistans(test, row);
            if (pq.size() < k) {
                pq.add(new Node(i, distans, labelOf(row)));
            } else if (pq.peek().getDistans() > distans) {
                // Strictly closer than the farthest kept candidate: replace it.
                pq.remove();
                pq.add(new Node(i, distans, labelOf(row)));
            }
        }

        return getMostCategory(pq);
    }

    /** Returns a training row's label, i.e. its last element rendered with {@code toString()}. */
    private static String labelOf(List<Double> row)
    {
        return row.get(row.size() - 1).toString();
    }

    /**
     * Drains {@code pq} and returns the most frequent category among its nodes. A tie goes to the
     * category whose nearest member is closest.
     *
     * @param pq candidates ordered by {@link #comparator} (farthest at the head); emptied by this call
     * @return the winning category
     * @throws IllegalStateException if {@code pq} is empty
     */
    private String getMostCategory(PriorityQueue<Node> pq)
    {
        if (pq.isEmpty()) {
            throw new IllegalStateException("no neighbours to vote on");
        }

        // Capture the size once before draining, because size() shrinks with every remove().
        // Nodes come out farthest-first, so filling the array from the back leaves it nearest-first.
        Node[] nearestFirst = new Node[pq.size()];
        for (int i = nearestFirst.length - 1; i >= 0; i--) {
            nearestFirst[i] = pq.remove();
        }

        Map<String, Integer> votes = new HashMap<String, Integer>();
        for (Node node : nearestFirst) {
            Integer seen = votes.get(node.getCategory());
            votes.put(node.getCategory(), seen == null ? 1 : seen + 1);
        }

        // Visit nearest-first and require a strictly larger count. Ties then resolve
        // deterministically toward the closer neighbour instead of depending on HashMap order.
        String best = null;
        int bestVotes = 0;
        for (Node node : nearestFirst) {
            int count = votes.get(node.getCategory());
            if (count > bestVotes) {
                best = node.getCategory();
                bestVotes = count;
            }
        }

        return best;
    }

    /**
     * Returns the squared Euclidean distance between two vectors over the first
     * {@code list1.size()} coordinates. Elements of {@code list2} beyond that length, such as a
     * trailing label, are ignored.
     *
     * <p>The square root is omitted; it does not change which neighbours are nearest.
     *
     * @param list1 first vector; its length sets how many coordinates are compared
     * @param list2 second vector; must have at least {@code list1.size()} elements
     * @return the sum of squared coordinate differences
     */
    public double calDistans(List<Double> list1, List<Double> list2)
    {
        double result = 0.00;

        for (int i = 0; i < list1.size(); i++) {
            double diff = list1.get(i) - list2.get(i);
            result += diff * diff;
        }

        return result;
    }

    /** A candidate neighbour: its row position, distance to the test vector, and label. */
    static class Node
    {
        /** Position of the training row in the list passed to {@code knn}. */
        private int index;

        /** Squared distance from the test vector. */
        private double distans;

        /** Label of the training row. */
        private String category;

        public Node(int index, double distans, String category)
        {
            this.index = index;
            this.distans = distans;
            this.category = category;
        }

        public int getIndex()
        {
            return index;
        }

        public void setIndex(int index)
        {
            this.index = index;
        }

        public double getDistans()
        {
            return distans;
        }

        public void setDistans(double distans)
        {
            this.distans = distans;
        }

        public String getCategory()
        {
            return category;
        }

        public void setCategory(String category)
        {
            this.category = category;
        }
    }
}