Linear classification
46719 words
1 import java.util.ArrayList;
2 import java.util.Random;
3
4
5 public class LinearClassifier {
6
7 private Vector V;
8 private double vn;
9 private double eta;//learning rate, the 3st input parameter
10
11 int interval=10;//the 4st input parameter args[3], how often to check to stop
12 double test_percent;
13 int []test_pids;
14 Vector[] testpoints;
15 static int most=10000;
16
17 double train_percent;
18 int []train_pids;
19 Vector[] trainpoints;
20 double percent=0;//the percentage of all the data
21 private Vector median;
22
23 public LinearClassifier(int N) {
24 V=new Vector(N);
25 median=new Vector(N);
26 eta=0.4;
27 vn=0;
28 }
29 /**
30 * set V and vn, such as they represent an hyperplane of origin C and normal N.
31 */
32 private void set_weights(final Vector C, final Vector N){
33 V=new Vector(N);
34 vn=-V.dot(C);
35 }
36 /**
37 * returns the signed distance of point X to the hyperplane represented by (V,vn).
38 */
39 private double signed_dist(final Vector X){
40 return vn+X.dot(V);
41 }
42 /**
43 * returns true if X is on the positive side of the hyperplane,
44 * false if X is on the negative side of the hyperplane.
45 */
46 public boolean classify(final Vector X){
47 if(signed_dist(X)>0)return true;
48 else return false;
49 }
50 /**
51 * updates (V,vn)
52 *
53 * inSide is true if X, a point from a dataset, should be on the positive side of the hyperplane.
54 * inSide is false if X, a point from a dataset, should be on the negative side of the hyperplane.
55 * update weights implements one iteration of the stochastic gradient descent. The learning rate is eta.
56 *
57 */
58 void update_weights(final Vector X, boolean inSide){
59 double delt_v=0;
60 double Fx=Math.tanh(signed_dist(X));
61 Fx=1-Fx*Fx;
62 // System.out.print("***"+Fx+" "+signed_dist(X)+" ------- "+inSide+" ");
63
64 double tempvn=vn;
65 Vector tempv=new Vector(V);
66
67 double z=0,t=0;
68 if(inSide)
69 t=1;
70 else t=-1;
71 z=Math.tanh(signed_dist(X));
72
73 double error=0.5*(t-z)*(t-z);
74
75 for(int i=0;i<V.get_length();i++){
76 delt_v=eta*(t-z)*X.get(i)*Fx;
77 V.set(i, V.get(i)+delt_v);
78 }
79 vn+=eta*(t-z)*Fx;
80
81 z=Math.tanh(signed_dist(X));
82 double errornew=0.5*(t-z)*(t-z);
83 if(error<errornew)
84 {System.out.println("!!!!!!"+signed_dist(X));
85 V=tempv;
86 vn=tempvn;
87 }
88
89
90
91 }
92
93 public void reset(Random rd){
94 Vector N=new Vector(V.get_length());
95 N.fill(0);
96 for(int i=0;i<N.get_length();i++)
97 N.set(i,rd.nextGaussian());
98 //normalize the vector N
99 N.mul(1/N.norm()); //N.printvec();
100 set_weights(new Vector(V.get_length()),N);
101 //set_weights(median, N);
102 }
103 /**
104 * to test the 1st and 2st dateset, each of them only have 4 Vector
105 */
106 void test1(Random rd,boolean[] inSide,Vector[] test_point){
107 reset(rd);
108 int i=0;
109 //check the symbol are all the same or all the different
110 //while(!(check_equil(inSide, test_point)||check_equiloppose(inSide, test_point))){
111 while(!check_equil(inSide, test_point)||i>10000){
112 update_weights(test_point[i%4], inSide[i%4]);
113 i++;
114 }
115 System.out.println("eta= "+eta+" iteration="+i);
116 }
117 /**
118 * check if the symbol are all the same
119 */
120 boolean check_equil(boolean[] inSide,Vector[] point){
121 for(int i=0;i<point.length;i++)
122 if(inSide[i]!=classify(point[i]))
123 return false;
124 return true;
125 }
126 /**
127 * check if the symbol are all opposite/on the contrary
128 */
129 boolean check_equiloppose(boolean[] inSide,Vector[] point){
130 for(int i=0;i<point.length;i++)
131 if(inSide[i]==classify(point[i]))
132 return false;
133 return true;
134 }
135 /**
136 * Training the dataset by the train points,
137 * using test points to determine when to stop learning
138 */
139 void train(){
140 Random rd =new Random();
141 reset(rd);
142 int i=0;
143 setpercent();
144
145 double oldtest=test_percent;
146 double oldtrain=train_percent;
147 while(!stop_learning(oldtest,oldtrain)){
148 //while(true){
149 int a=rd.nextInt(trainpoints.length);
150 if(train_pids[a]>0)
151 update_weights(trainpoints[a], true);
152 else
153 update_weights(trainpoints[a], false);
154 i++;
155 //if(i % interval==0)
156 setpercent();
157 if(i%interval==interval/2){//update old value in different time
158 oldtest=test_percent;
159 oldtrain=train_percent;
160 }
161
162 // System.out.println(oldtest+"---"+oldtrain);
163 if(i>most)break;
164 }
165
166 System.out.println("iteration= "+i);
167
168 }
169 /**
170 * set the value from 0-size-2 to be the "Vector", size-1 be the "side"
171 * Separate the dataset into two parts: test points and train points
172 */
173 void get_vector(ArrayList<Vector> p,Random rd){
174 int size=p.size();
175 int dim=p.get(0).get_length();
176
177 int test_size=(int) (size*0.3);
178 int train_size=size-test_size;
179
180 train_pids=new int[train_size];
181 trainpoints=new Vector[train_size];
182
183 test_pids=new int[test_size];
184 testpoints=new Vector[test_size];
185
186 for(int i=0;i<test_size;i++){
187 int j=rd.nextInt(size);
188 testpoints[i]=Vector.get_sub_vector(p.get(j), 0, dim-2);//0~size-2
189 test_pids[i]=(int) p.get(j).get(dim-1);
190 p.remove(j);
191 size--;
192 testpoints[i].sub(median);
193 }
194
195 for(int i=0;i<train_size;i++){
196 trainpoints[i]=Vector.get_sub_vector(p.get(i), 0, dim-2);//0~size-2
197 train_pids[i]=(int) p.get(i).get(dim-1);
198 trainpoints[i].sub(median);
199 }
200
201
202 }
203 /**
204 * stop learning, according to the percentage of accuracy of testing and training points
205 * this function got executed every 10, 100, or more times after doing update
206 */
207 boolean stop_learning(double oldTest,double oldTrain){
208
209 double d1=Math.abs(oldTest-oldTrain);//old delta value
210 double d2=Math.abs(train_percent-test_percent);//new delta value
211 double d3=train_percent+test_percent;
212 //Guarantee least correct, some parameters that I guess, cann't fit to any dataset
213 if(((d2 >2*d1)||(d2 <0.00001)||train_percent>0.85) && train_percent >0.75 &&test_percent>0.72 &&d3>1.6)
214 return true;
215 return false;
216 }
217
218 void setpercent(){
219 test_percent=0;
220 train_percent=0;
221 for(int i=0;i<testpoints.length;i++){
222 if(classify(testpoints[i])==true&&test_pids[i]==1)
223 test_percent++;
224 if(classify(testpoints[i])==false&&test_pids[i]==0)
225 test_percent++;
226 }
227
228
229 for(int i=0;i<trainpoints.length;i++){
230 if(classify(trainpoints[i])==true&&train_pids[i]==1)
231 train_percent++;
232 if(classify(trainpoints[i])==false&&train_pids[i]==0)
233 train_percent++;
234 }
235
236 percent=test_percent+train_percent;
237 percent/=testpoints.length+trainpoints.length;
238
239 test_percent/=testpoints.length;
240 train_percent/=trainpoints.length;
241
242 //System.out.println("testPercent: "+test_percent+" trainPercent: "+train_percent);
243 }
244
245 public static void main(String[] args) {
246 // System.out.println("test 1: --------------------------");
247 //------ test one --------------------------------------------------------------/
248 Vector test_point[]=new Vector[4];
249 test_point[0]=new Vector(2);test_point[1]=new Vector(2);
250 test_point[2]=new Vector(2);test_point[3]=new Vector(2);
251
252 test_point[0].set(0, -1);test_point[0].set(1, 1);
253 test_point[1].set(0, 1);test_point[1].set(1, 1);
254 test_point[2].set(0, -1);test_point[2].set(1, -1);
255 test_point[3].set(0, 1);test_point[3].set(1, -1);
256
257 boolean [] inSide=new boolean[4];
258 inSide[0]=true;
259 inSide[1]=true;
260 inSide[2]=false;
261 inSide[3]=false;
262
263 LinearClassifier test1=new LinearClassifier(2);
264
265 test1.eta=0.5;
266 test1.test1(new Random(),inSide,test_point);test1.V.printvec();
267 test1.eta=0.01;
268 test1.test1(new Random(),inSide,test_point);test1.V.printvec();
269 test1.eta=0.001;
270 test1.test1(new Random(),inSide,test_point);test1.V.printvec();
271 //------ test two --------------------------------------------------------------/
272 System.out.println("test 2: ==========================");
273 test_point[0].set(0, 100);test_point[0].set(1, 101);
274 test_point[1].set(0, 101);test_point[1].set(1, 101);
275 test_point[2].set(0, 100);test_point[2].set(1, 100);
276 test_point[3].set(0, 101);test_point[3].set(1, 100);
277 System.out.println("no Optimization:"+"
too much time 。。。"+"
use median to do the optimization");
278 // test1.eta=2.5;test1.test1(new Random(),inSide,test_point);
279 ArrayList<Vector> points=new ArrayList<Vector>();
280 points.add(test_point[0]);points.add(test_point[1]);
281 points.add(test_point[2]);points.add(test_point[3]);
282
283 test1.median=new Vector(Vector.vector_median(points));
284 test1.median.printvec();
285 for(int i=0;i<4;i++)
286 test_point[i].sub(test1.median);
287
288 test1.eta=0.5;
289 test1.test1(new Random(),inSide,test_point);
290 test1.eta=0.01;
291 test1.test1(new Random(),inSide,test_point);
292 test1.V.add(test1.median);
293 test1.V.printvec();
294
295 for(int i=0;i<4;i++)
296 test_point[i].add(test1.median);
297 //------ test three --------------------------------------------------------------/
298 System.out.println("test 3: =========================="+
299 "
seperate the dataset into 30% testing part and 70% training part, " +
300 "
using the percentage to determine when to stop the learning,the max iteration is "+most);
301 points.clear();
302 if(args.length==0)
303 points=Vector.read_data("dataset-2");//the dataset
304 else
305 points=Vector.read_data(args[0]);
306
307 int size=points.get(0).get_length();
308
309 LinearClassifier lc=new LinearClassifier(size-1);
310 lc.median=Vector.get_sub_vector(Vector.vector_median(points), 0, size-2);;
311 lc.eta=0.4;
312
313 if(args.length>3)
314 lc.interval=new Integer(args[3]);
315 if(args.length>=3)
316 lc.eta=new Double(args[2]);
317
318 lc.get_vector(points,new Random());
319 lc.train();
320 lc.setpercent();
321 System.out.println(" percentage: "+lc.percent+" test: "+lc.test_percent+" train: "+lc.train_percent);
322
323 int idtest[]=new int[lc.testpoints.length];
324 for(int i=0;i<lc.testpoints.length;i++){
325 if(lc.classify(lc.testpoints[i]))
326 idtest[i]=0;
327 else idtest[i]=1;
328 lc.testpoints[i].add(lc.median);
329 }
330
331 int idtrain[] =new int[lc.trainpoints.length];
332 for(int i=0;i<lc.trainpoints.length;i++){
333 if(lc.classify(lc.trainpoints[i]))
334 idtrain[i]=0;
335 else idtrain[i]=1;
336 lc.trainpoints[i].add(lc.median);
337 }
338
339 lc.V.add(lc.median);
340 lc.V.printvec();
341 if(args.length<2)
342 {
343 Vector.write_data_withID("out-dataset", lc.testpoints, idtest);
344 Vector.write_data_withID("out-dataset", lc.trainpoints, idtrain,true);
345 }
346 else {
347 Vector.write_data_withID(args[1], lc.testpoints, idtest);
348 Vector.write_data_withID(args[1], lc.trainpoints, idtrain,true);
349 }
350
351 }
352
353 }