k-meansアルゴリズムの簡単な実装

6600 ワード

以下は k-means アルゴリズムの簡単な Java 実装です。コードは次のとおりです:

package com.lele;

import java.util.ArrayList;
import java.util.List;

/**
 * Minimal k-means clustering demo over a fixed set of 20 two-dimensional points.
 *
 * <p>The first {@code K} sample points seed the cluster centers. Each iteration assigns every
 * point to its nearest center ({@link #order()}) and then recomputes every center as the mean of
 * its members ({@link #newCenter(int)}). Iteration stops once no center moved during a single
 * iteration, or after {@code MAX_ITERATIONS} iterations as a safety net.
 */
public class Kmeans {

    /** Number of clusters. */
    private static final int K = 5;

    /** Number of sample points. */
    private static final int TOTAL = 20;

    /** Hard cap on iterations so the demo can never loop forever. */
    private static final int MAX_ITERATIONS = 100;

    /** Sample points to cluster. */
    private final Point[] unknown = new Point[TOTAL];
    /** Cluster index currently assigned to each sample point. */
    private final int[] type = new int[TOTAL];
    /** Current cluster centers. */
    private final Point[] z = new Point[K];
    /** Centers from the previous iteration, used to detect convergence. */
    private final Point[] z0 = new Point[K];
    /** Number of completed iterations. */
    private int iterations = 0;

    /** Creates a new instance of Kmeans loaded with the built-in 20 sample points. */
    public Kmeans() {
        unknown[0] = new Point(5,5);
        unknown[1] = new Point(7,9);
        unknown[2] = new Point(0,1);
        unknown[3] = new Point(1,1);
        unknown[4] = new Point(2,1);
        unknown[5] = new Point(1,2);
        unknown[6] = new Point(2,2);
        unknown[7] = new Point(3,2);
        unknown[8] = new Point(6,6);
        unknown[9] = new Point(7,6);
        unknown[10] = new Point(8,6);
        unknown[11] = new Point(6,7);
        unknown[12] = new Point(7,7);
        unknown[13] = new Point(8,7);
        unknown[14] = new Point(9,7);
        unknown[15] = new Point(7,8);
        unknown[16] = new Point(8,8);
        unknown[17] = new Point(9,8);
        unknown[18] = new Point(8,9);
        unknown[19] = new Point(9,9);

        // type[] is zero-initialized by Java. Seed the centers with copies of the first K samples
        // so that no center object aliases a data point. Record the seeds as the "previous"
        // centers, so convergence is measured against where each center actually was.
        for (int i = 0; i < K; i++) {
            z[i] = new Point(unknown[i].getX(), unknown[i].getY());
            z0[i] = new Point(unknown[i].getX(), unknown[i].getY());
        }
    }

    /**
     * Recomputes the center of cluster {@code m} as the mean of the points currently assigned to
     * it, and prints the cluster index and its member count.
     *
     * @param m cluster index, expected in {@code [0, K)}
     * @return a new Point holding the mean of the members; if the cluster currently has no
     *     members, a copy of its current center
     */
    public Point newCenter(int m) {
        int n = 0;
        double sumX = 0;
        double sumY = 0;
        for (int i = 0; i < TOTAL; i++) {
            if (type[i] == m) {
                sumX += unknown[i].getX();
                sumY += unknown[i].getY();
                n += 1;
            }
        }
        System.out.println(" " + m + "  " + n + " ");
        if (n == 0) {
            // Empty cluster: keep the previous center instead of dividing 0/0 into a NaN center,
            // which would poison every later distance comparison against it.
            return new Point(z[m].getX(), z[m].getY());
        }
        return new Point(sumX / n, sumY / n);
    }

    /**
     * Returns whether two points have bit-for-bit identical coordinates, printing both coordinate
     * pairs as a debugging trace.
     *
     * <p>Exact comparison is sufficient for convergence here. Once the assignments stop changing,
     * each center is recomputed from the same members in the same order, and so it reproduces
     * the same bits.
     */
    public boolean isEqual(Point p1, Point p2) {
        System.out.println(p1.getX() + "**********" + p2.getX());
        System.out.println(p1.getY() + "**********" + p2.getY());
        return Double.doubleToLongBits(p1.getX()) == Double.doubleToLongBits(p2.getX())
                && Double.doubleToLongBits(p1.getY()) == Double.doubleToLongBits(p2.getY());
    }

    /**
     * Returns the SQUARED Euclidean distance between two points.
     *
     * <p>The square root is omitted because only relative ordering matters when choosing the
     * nearest center.
     */
    public static double distance(Point p1, Point p2) {
        double dx = p1.getX() - p2.getX();
        double dy = p1.getY() - p2.getY();
        return dx * dx + dy * dy;
    }

    /**
     * Assigns each of the TOTAL sample points to its nearest center, storing the index in
     * {@code type}. Prints each point with its cluster. Ties go to the lowest cluster index.
     */
    public void order() {
        for (int i = 0; i < TOTAL; i++) {
            // Start from cluster 0 for every point, so that tie-breaking does not depend on the
            // previous point's assignment.
            int nearest = 0;
            double best = distance(unknown[i], z[0]);
            for (int j = 1; j < K; j++) {
                double d = distance(unknown[i], z[j]);
                if (d < best) {
                    best = d;
                    nearest = j;
                }
            }
            type[i] = nearest;
            System.out.println(unknown[i].toString() + "   " + nearest);
        }
    }

    /**
     * Runs k-means on the built-in sample and prints the progress of every iteration.
     *
     * <p>Stops when all K centers are unchanged within the same iteration, or after
     * MAX_ITERATIONS iterations.
     */
    public void main() {
        System.out.println("         :");
        for (int i = 0; i < TOTAL; i++) {
            System.out.println(unknown[i]);
        }
        for (int i = 0; i < K; i++) {
            System.out.println("   , " + i + "   :" + z[i].toString());
        }

        // Number of centers that did not move during the most recent iteration. It is reset
        // every iteration, because convergence means all K centers are stable in the SAME pass.
        int test = 0;
        while (iterations < MAX_ITERATIONS) {
            System.out.println("current test:" + test);
            order();
            test = 0;
            for (int i = 0; i < K; i++) {
                z[i] = newCenter(i);
                System.out.println(" " + i + "     :" + z[i].toString());
                if (isEqual(z[i], z0[i])) {
                    test += 1;
                } else {
                    z0[i] = z[i];
                }
            }
            iterations += 1;
            printClusters();
            if (test == K) {
                break; // every center is stable, so the next assignment would be identical
            }
        }
    }

    /** Prints the iteration number followed by the members of each cluster. */
    private void printClusters() {
        System.out.println("    " + iterations + "   ");
        System.out.println("    :");
        for (int j = 0; j < K; j++) {
            System.out.println(" " + j + "    : ");
            for (int i = 0; i < TOTAL; i++) {
                if (type[i] == j) {
                    System.out.println(unknown[i].toString());
                }
            }
        }
    }

    /**
     * Entry point: runs the demo clustering.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        new Kmeans().main();
    }
}




package com.lele;
/**
 * A mutable two-dimensional point with {@code double} coordinates.
 *
 * @author zhaole609
 */
public class Point {

    private double x;
    private double y;

    /**
     * Creates a point at ({@code x}, {@code y}).
     *
     * @param x the x coordinate
     * @param y the y coordinate
     */
    public Point(double x, double y) {
        // Assign the fields directly rather than through the overridable setters, so that a
        // subclass override cannot run against a partially constructed object.
        this.x = x;
        this.y = y;
    }

    /** Returns the x coordinate. */
    public double getX() {
        return x;
    }

    /** Sets the x coordinate. */
    public void setX(double x) {
        this.x = x;
    }

    /** Returns the y coordinate. */
    public double getY() {
        return y;
    }

    /** Sets the y coordinate. */
    public void setY(double y) {
        this.y = y;
    }

    /** Returns the point formatted as {@code [x,y]}, for example {@code [3.0,4.0]}. */
    @Override
    public String toString() {
        return "[" + x + "," + y + "]";
    }
}
 コードは基本的に元のものと同じで、一部を少し変更しただけです。