Linear classification

46719 characters

  1 import java.util.ArrayList;

  2 import java.util.Random;

  3 

  4 

  5 public class LinearClassifier {

  6 

  7     private Vector V;

  8     private double vn;

  9     private double eta;//learning rate, the 3st input parameter

 10     

 11     int interval=10;//the 4st input parameter args[3], how often to check to stop

 12     double test_percent;

 13     int []test_pids;

 14     Vector[] testpoints;

 15     static int most=10000;

 16     

 17     double train_percent;

 18     int []train_pids;

 19     Vector[] trainpoints;

 20     double percent=0;//the percentage of all the data

 21     private Vector median;

 22     

 23     public LinearClassifier(int N) {

 24         V=new Vector(N);

 25         median=new Vector(N);

 26         eta=0.4;

 27         vn=0;

 28     }

 29     /**

 30      *  set V and vn, such as they represent an hyperplane of origin C and normal N.

 31      */

 32      private void set_weights(final Vector C, final Vector N){

 33          V=new Vector(N);

 34          vn=-V.dot(C);

 35      }

 36     /**

 37      * returns the signed distance of point X to the hyperplane represented by (V,vn).

 38      */

 39     private  double signed_dist(final Vector X){

 40         return vn+X.dot(V);

 41     }

 42     /**

 43      *  returns true if X is on the positive side of the hyperplane, 

 44      *  false if X is on the negative side of the hyperplane.

 45      */

 46     public boolean classify(final Vector X){

 47         if(signed_dist(X)>0)return true;

 48         else return false;    

 49     }

 50     /**

 51      * updates (V,vn)

 52      * 

 53      * inSide is true if X, a point from a dataset, should be on the positive side of the hyperplane.

 54      * inSide is false if X, a point from a dataset, should be on the negative side of the hyperplane.

 55      * update weights implements one iteration of the stochastic gradient descent. The learning rate is eta.

 56      * 

 57      */

 58     void update_weights(final Vector X, boolean inSide){

 59         double delt_v=0;

 60         double Fx=Math.tanh(signed_dist(X));

 61          Fx=1-Fx*Fx;

 62     //     System.out.print("***"+Fx+"    "+signed_dist(X)+"  -------  "+inSide+" ");    

 63         

 64          double tempvn=vn;

 65          Vector tempv=new Vector(V);

 66          

 67         double z=0,t=0;

 68         if(inSide)

 69             t=1;

 70         else t=-1;

 71         z=Math.tanh(signed_dist(X));

 72         

 73         double error=0.5*(t-z)*(t-z);    

 74         

 75         for(int i=0;i<V.get_length();i++){

 76             delt_v=eta*(t-z)*X.get(i)*Fx;

 77             V.set(i, V.get(i)+delt_v);

 78         }        

 79         vn+=eta*(t-z)*Fx;

 80         

 81         z=Math.tanh(signed_dist(X));

 82         double errornew=0.5*(t-z)*(t-z);

 83         if(error<errornew)

 84         {System.out.println("!!!!!!"+signed_dist(X));

 85             V=tempv;

 86             vn=tempvn;

 87         }

 88         

 89         

 90         

 91     }

 92     

 93     public void reset(Random rd){

 94         Vector N=new Vector(V.get_length());

 95         N.fill(0);    

 96         for(int i=0;i<N.get_length();i++)

 97             N.set(i,rd.nextGaussian());    

 98         //normalize the vector N

 99         N.mul(1/N.norm());    //N.printvec();

100         set_weights(new Vector(V.get_length()),N);

101         //set_weights(median, N);

102     }

103     /**

104      * to test the 1st and 2st dateset, each of them only have 4 Vector

105      */

106     void test1(Random rd,boolean[] inSide,Vector[] test_point){        

107         reset(rd);

108         int i=0;

109         //check the symbol are all the same or all the different

110         //while(!(check_equil(inSide, test_point)||check_equiloppose(inSide, test_point))){

111             while(!check_equil(inSide, test_point)||i>10000){

112             update_weights(test_point[i%4], inSide[i%4]);        

113             i++;

114         }

115         System.out.println("eta= "+eta+" iteration="+i);

116     }

117     /**

118      * check if the symbol are all the same

119      */

120     boolean check_equil(boolean[] inSide,Vector[] point){

121             for(int i=0;i<point.length;i++)

122             if(inSide[i]!=classify(point[i]))

123                 return false;            

124         return true;

125     }

126     /**

127      * check if the symbol are all opposite/on the contrary

128      */

129     boolean check_equiloppose(boolean[] inSide,Vector[] point){

130         for(int i=0;i<point.length;i++)

131         if(inSide[i]==classify(point[i]))

132             return false;            

133     return true;

134 }

135     /**

136      * Training the dataset by the train points,

137      * using test points to determine when to stop learning

138      */

139     void train(){

140         Random rd =new Random();

141         reset(rd);    

142         int i=0;

143         setpercent();

144         

145         double oldtest=test_percent;

146         double oldtrain=train_percent;

147         while(!stop_learning(oldtest,oldtrain)){

148         //while(true){

149             int a=rd.nextInt(trainpoints.length);

150             if(train_pids[a]>0)

151                 update_weights(trainpoints[a], true);

152             else 

153                 update_weights(trainpoints[a], false);

154             i++;

155             //if(i % interval==0)

156                 setpercent();

157             if(i%interval==interval/2){//update old value in different time

158                 oldtest=test_percent;

159                 oldtrain=train_percent;

160             }

161             

162         //    System.out.println(oldtest+"---"+oldtrain);    

163             if(i>most)break;

164         }

165         

166         System.out.println("iteration= "+i);

167         

168     }

169     /**

170      * set the value from 0-size-2 to be the "Vector", size-1 be the "side"

171      * Separate the dataset into two parts: test points and train points

172      */

173     void get_vector(ArrayList<Vector> p,Random rd){

174         int size=p.size();

175         int dim=p.get(0).get_length();    

176         

177         int test_size=(int) (size*0.3);

178         int train_size=size-test_size;

179         

180         train_pids=new int[train_size];

181         trainpoints=new Vector[train_size];

182         

183         test_pids=new int[test_size];

184         testpoints=new Vector[test_size];

185         

186         for(int i=0;i<test_size;i++){

187             int j=rd.nextInt(size);

188             testpoints[i]=Vector.get_sub_vector(p.get(j), 0, dim-2);//0~size-2

189             test_pids[i]=(int) p.get(j).get(dim-1);

190             p.remove(j);

191             size--;

192             testpoints[i].sub(median);

193         }

194         

195         for(int i=0;i<train_size;i++){

196             trainpoints[i]=Vector.get_sub_vector(p.get(i), 0, dim-2);//0~size-2

197             train_pids[i]=(int) p.get(i).get(dim-1);

198             trainpoints[i].sub(median);

199         }

200         

201         

202     }    

203     /**

204      * stop learning, according to the percentage of accuracy of testing and training points

205      * this function got executed every 10, 100, or more times after doing update

206      */

207     boolean stop_learning(double oldTest,double oldTrain){

208         

209         double d1=Math.abs(oldTest-oldTrain);//old delta value

210         double d2=Math.abs(train_percent-test_percent);//new delta value

211         double d3=train_percent+test_percent;

212         //Guarantee least correct, some parameters that I guess, cann't fit to any dataset

213         if(((d2 >2*d1)||(d2 <0.00001)||train_percent>0.85) && train_percent >0.75 &&test_percent>0.72 &&d3>1.6)

214             return  true;

215         return false;

216     }

217     

218     void setpercent(){

219         test_percent=0;

220         train_percent=0;

221         for(int i=0;i<testpoints.length;i++){

222             if(classify(testpoints[i])==true&&test_pids[i]==1)

223                 test_percent++;

224             if(classify(testpoints[i])==false&&test_pids[i]==0)

225                 test_percent++;

226         }

227         

228 

229         for(int i=0;i<trainpoints.length;i++){

230             if(classify(trainpoints[i])==true&&train_pids[i]==1)

231                 train_percent++;

232             if(classify(trainpoints[i])==false&&train_pids[i]==0)

233                 train_percent++;

234         }

235         

236         percent=test_percent+train_percent;

237         percent/=testpoints.length+trainpoints.length;

238         

239         test_percent/=testpoints.length;

240         train_percent/=trainpoints.length;

241         

242         //System.out.println("testPercent: "+test_percent+"  trainPercent: "+train_percent);    

243     }

244     

245     public static void main(String[] args) {

246 //        System.out.println("test 1: --------------------------");

247 //------ test one --------------------------------------------------------------/

248         Vector test_point[]=new Vector[4];

249         test_point[0]=new Vector(2);test_point[1]=new Vector(2);

250         test_point[2]=new Vector(2);test_point[3]=new Vector(2);

251 

252         test_point[0].set(0, -1);test_point[0].set(1, 1);

253         test_point[1].set(0, 1);test_point[1].set(1, 1);

254         test_point[2].set(0, -1);test_point[2].set(1, -1);

255         test_point[3].set(0, 1);test_point[3].set(1, -1);

256         

257         boolean [] inSide=new boolean[4];

258         inSide[0]=true;

259         inSide[1]=true;

260         inSide[2]=false;

261         inSide[3]=false;

262         

263         LinearClassifier test1=new LinearClassifier(2);

264         

265         test1.eta=0.5;

266         test1.test1(new Random(),inSide,test_point);test1.V.printvec();

267         test1.eta=0.01;

268         test1.test1(new Random(),inSide,test_point);test1.V.printvec();

269         test1.eta=0.001;

270         test1.test1(new Random(),inSide,test_point);test1.V.printvec();

271 //------ test two --------------------------------------------------------------/

272         System.out.println("test 2: ==========================");    

273         test_point[0].set(0, 100);test_point[0].set(1, 101);

274         test_point[1].set(0, 101);test_point[1].set(1, 101);

275         test_point[2].set(0, 100);test_point[2].set(1, 100);

276         test_point[3].set(0, 101);test_point[3].set(1, 100);

277         System.out.println("no Optimization:"+"
too much time 。。。"+"
use median to do the optimization"); 278 // test1.eta=2.5;test1.test1(new Random(),inSide,test_point); 279 ArrayList<Vector> points=new ArrayList<Vector>(); 280 points.add(test_point[0]);points.add(test_point[1]); 281 points.add(test_point[2]);points.add(test_point[3]); 282 283 test1.median=new Vector(Vector.vector_median(points)); 284 test1.median.printvec(); 285 for(int i=0;i<4;i++) 286 test_point[i].sub(test1.median); 287 288 test1.eta=0.5; 289 test1.test1(new Random(),inSide,test_point); 290 test1.eta=0.01; 291 test1.test1(new Random(),inSide,test_point); 292 test1.V.add(test1.median); 293 test1.V.printvec(); 294 295 for(int i=0;i<4;i++) 296 test_point[i].add(test1.median); 297 //------ test three --------------------------------------------------------------/ 298 System.out.println("test 3: =========================="+ 299 "
seperate the dataset into 30% testing part and 70% training part, " + 300 "
using the percentage to determine when to stop the learning,the max iteration is "+most); 301 points.clear(); 302 if(args.length==0) 303 points=Vector.read_data("dataset-2");//the dataset 304 else 305 points=Vector.read_data(args[0]); 306 307 int size=points.get(0).get_length(); 308 309 LinearClassifier lc=new LinearClassifier(size-1); 310 lc.median=Vector.get_sub_vector(Vector.vector_median(points), 0, size-2);; 311 lc.eta=0.4; 312 313 if(args.length>3) 314 lc.interval=new Integer(args[3]); 315 if(args.length>=3) 316 lc.eta=new Double(args[2]); 317 318 lc.get_vector(points,new Random()); 319 lc.train(); 320 lc.setpercent(); 321 System.out.println(" percentage: "+lc.percent+" test: "+lc.test_percent+" train: "+lc.train_percent); 322 323 int idtest[]=new int[lc.testpoints.length]; 324 for(int i=0;i<lc.testpoints.length;i++){ 325 if(lc.classify(lc.testpoints[i])) 326 idtest[i]=0; 327 else idtest[i]=1; 328 lc.testpoints[i].add(lc.median); 329 } 330 331 int idtrain[] =new int[lc.trainpoints.length]; 332 for(int i=0;i<lc.trainpoints.length;i++){ 333 if(lc.classify(lc.trainpoints[i])) 334 idtrain[i]=0; 335 else idtrain[i]=1; 336 lc.trainpoints[i].add(lc.median); 337 } 338 339 lc.V.add(lc.median); 340 lc.V.printvec(); 341 if(args.length<2) 342 { 343 Vector.write_data_withID("out-dataset", lc.testpoints, idtest); 344 Vector.write_data_withID("out-dataset", lc.trainpoints, idtrain,true); 345 } 346 else { 347 Vector.write_data_withID(args[1], lc.testpoints, idtest); 348 Vector.write_data_withID(args[1], lc.trainpoints, idtrain,true); 349 } 350 351 } 352 353 }