<code> Linear classification
1 import java.util.ArrayList; 2 import java.util.Random; 3 4 5 public class LinearClassifier { 6 7 private Vector V; 8 private double vn; 9 private double eta;//learning rate, the 3st input parameter 10 11 int interval=10;//the 4st input parameter args[3], how often to check to stop 12 double test_percent; 13 int []test_pids; 14 Vector[] testpoints; 15 static int most=10000; 16 17 double train_percent; 18 int []train_pids; 19 Vector[] trainpoints; 20 double percent=0;//the percentage of all the data 21 private Vector median; 22 23 public LinearClassifier(int N) { 24 V=new Vector(N); 25 median=new Vector(N); 26 eta=0.4; 27 vn=0; 28 } 29 /** 30 * set V and vn, such as they represent an hyperplane of origin C and normal N. 31 */ 32 private void set_weights(final Vector C, final Vector N){ 33 V=new Vector(N); 34 vn=-V.dot(C); 35 } 36 /** 37 * returns the signed distance of point X to the hyperplane represented by (V,vn). 38 */ 39 private double signed_dist(final Vector X){ 40 return vn+X.dot(V); 41 } 42 /** 43 * returns true if X is on the positive side of the hyperplane, 44 * false if X is on the negative side of the hyperplane. 45 */ 46 public boolean classify(final Vector X){ 47 if(signed_dist(X)>0)return true; 48 else return false; 49 } 50 /** 51 * updates (V,vn) 52 * 53 * inSide is true if X, a point from a dataset, should be on the positive side of the hyperplane. 54 * inSide is false if X, a point from a dataset, should be on the negative side of the hyperplane. 55 * update weights implements one iteration of the stochastic gradient descent. The learning rate is eta. 56 * 57 */ 58 void update_weights(final Vector X, boolean inSide){ 59 double delt_v=0; 60 double Fx=Math.tanh(signed_dist(X)); 61 Fx=1-Fx*Fx; 62 // System.out.print("***"+Fx+" "+signed_dist(X)+" ------- "+inSide+" "); 63 64 double tempvn=vn; 65 Vector tempv=new Vector(V); 66 67 double z=0,t=0; 68 if(inSide) 69 t=1; 70 else t=-1; 71 z=Math.tanh(signed_dist(X)); 72 73 double error=0.5*(t-z)*(t-z); 74 75 for(int i=0;i<V.get_length();i++){ 76 delt_v=eta*(t-z)*X.get(i)*Fx; 77 V.set(i, V.get(i)+delt_v); 78 } 79 vn+=eta*(t-z)*Fx; 80 81 z=Math.tanh(signed_dist(X)); 82 double errornew=0.5*(t-z)*(t-z); 83 if(error<errornew) 84 {System.out.println("!!!!!!"+signed_dist(X)); 85 V=tempv; 86 vn=tempvn; 87 } 88 89 90 91 } 92 93 public void reset(Random rd){ 94 Vector N=new Vector(V.get_length()); 95 N.fill(0); 96 for(int i=0;i<N.get_length();i++) 97 N.set(i,rd.nextGaussian()); 98 //normalize the vector N 99 N.mul(1/N.norm()); //N.printvec(); 100 set_weights(new Vector(V.get_length()),N); 101 //set_weights(median, N); 102 } 103 /** 104 * to test the 1st and 2st dateset, each of them only have 4 Vector 105 */ 106 void test1(Random rd,boolean[] inSide,Vector[] test_point){ 107 reset(rd); 108 int i=0; 109 //check the symbol are all the same or all the different 110 //while(!(check_equil(inSide, test_point)||check_equiloppose(inSide, test_point))){ 111 while(!check_equil(inSide, test_point)||i>10000){ 112 update_weights(test_point[i%4], inSide[i%4]); 113 i++; 114 } 115 System.out.println("eta= "+eta+" iteration="+i); 116 } 117 /** 118 * check if the symbol are all the same 119 */ 120 boolean check_equil(boolean[] inSide,Vector[] point){ 121 for(int i=0;i<point.length;i++) 122 if(inSide[i]!=classify(point[i])) 123 return false; 124 return true; 125 } 126 /** 127 * check if the symbol are all opposite/on the contrary 128 */ 129 boolean check_equiloppose(boolean[] inSide,Vector[] point){ 130 for(int i=0;i<point.length;i++) 131 if(inSide[i]==classify(point[i])) 132 return false; 133 return true; 134 } 135 /** 136 * Training the dataset by the train points, 137 * using test points to determine when to stop learning 138 */ 139 void train(){ 140 Random rd =new Random(); 141 reset(rd); 142 int i=0; 143 setpercent(); 144 145 double oldtest=test_percent; 146 double oldtrain=train_percent; 147 while(!stop_learning(oldtest,oldtrain)){ 148 //while(true){ 149 int a=rd.nextInt(trainpoints.length); 150 if(train_pids[a]>0) 151 update_weights(trainpoints[a], true); 152 else 153 update_weights(trainpoints[a], false); 154 i++; 155 //if(i % interval==0) 156 setpercent(); 157 if(i%interval==interval/2){//update old value in different time 158 oldtest=test_percent; 159 oldtrain=train_percent; 160 } 161 162 // System.out.println(oldtest+"---"+oldtrain); 163 if(i>most)break; 164 } 165 166 System.out.println("iteration= "+i); 167 168 } 169 /** 170 * set the value from 0-size-2 to be the "Vector", size-1 be the "side" 171 * Separate the dataset into two parts: test points and train points 172 */ 173 void get_vector(ArrayList<Vector> p,Random rd){ 174 int size=p.size(); 175 int dim=p.get(0).get_length(); 176 177 int test_size=(int) (size*0.3); 178 int train_size=size-test_size; 179 180 train_pids=new int[train_size]; 181 trainpoints=new Vector[train_size]; 182 183 test_pids=new int[test_size]; 184 testpoints=new Vector[test_size]; 185 186 for(int i=0;i<test_size;i++){ 187 int j=rd.nextInt(size); 188 testpoints[i]=Vector.get_sub_vector(p.get(j), 0, dim-2);//0~size-2 189 test_pids[i]=(int) p.get(j).get(dim-1); 190 p.remove(j); 191 size--; 192 testpoints[i].sub(median); 193 } 194 195 for(int i=0;i<train_size;i++){ 196 trainpoints[i]=Vector.get_sub_vector(p.get(i), 0, dim-2);//0~size-2 197 train_pids[i]=(int) p.get(i).get(dim-1); 198 trainpoints[i].sub(median); 199 } 200 201 202 } 203 /** 204 * stop learning, according to the percentage of accuracy of testing and training points 205 * this function got executed every 10, 100, or more times after doing update 206 */ 207 boolean stop_learning(double oldTest,double oldTrain){ 208 209 double d1=Math.abs(oldTest-oldTrain);//old delta value 210 double d2=Math.abs(train_percent-test_percent);//new delta value 211 double d3=train_percent+test_percent; 212 //Guarantee least correct, some parameters that I guess, cann't fit to any dataset 213 if(((d2 >2*d1)||(d2 <0.00001)||train_percent>0.85) && train_percent >0.75 &&test_percent>0.72 &&d3>1.6) 214 return true; 215 return false; 216 } 217 218 void setpercent(){ 219 test_percent=0; 220 train_percent=0; 221 for(int i=0;i<testpoints.length;i++){ 222 if(classify(testpoints[i])==true&&test_pids[i]==1) 223 test_percent++; 224 if(classify(testpoints[i])==false&&test_pids[i]==0) 225 test_percent++; 226 } 227 228 229 for(int i=0;i<trainpoints.length;i++){ 230 if(classify(trainpoints[i])==true&&train_pids[i]==1) 231 train_percent++; 232 if(classify(trainpoints[i])==false&&train_pids[i]==0) 233 train_percent++; 234 } 235 236 percent=test_percent+train_percent; 237 percent/=testpoints.length+trainpoints.length; 238 239 test_percent/=testpoints.length; 240 train_percent/=trainpoints.length; 241 242 //System.out.println("testPercent: "+test_percent+" trainPercent: "+train_percent); 243 } 244 245 public static void main(String[] args) { 246 // System.out.println("test 1: --------------------------"); 247 //------ test one --------------------------------------------------------------/ 248 Vector test_point[]=new Vector[4]; 249 test_point[0]=new Vector(2);test_point[1]=new Vector(2); 250 test_point[2]=new Vector(2);test_point[3]=new Vector(2); 251 252 test_point[0].set(0, -1);test_point[0].set(1, 1); 253 test_point[1].set(0, 1);test_point[1].set(1, 1); 254 test_point[2].set(0, -1);test_point[2].set(1, -1); 255 test_point[3].set(0, 1);test_point[3].set(1, -1); 256 257 boolean [] inSide=new boolean[4]; 258 inSide[0]=true; 259 inSide[1]=true; 260 inSide[2]=false; 261 inSide[3]=false; 262 263 LinearClassifier test1=new LinearClassifier(2); 264 265 test1.eta=0.5; 266 test1.test1(new Random(),inSide,test_point);test1.V.printvec(); 267 test1.eta=0.01; 268 test1.test1(new Random(),inSide,test_point);test1.V.printvec(); 269 test1.eta=0.001; 270 test1.test1(new Random(),inSide,test_point);test1.V.printvec(); 271 //------ test two --------------------------------------------------------------/ 272 System.out.println("test 2: =========================="); 273 test_point[0].set(0, 100);test_point[0].set(1, 101); 274 test_point[1].set(0, 101);test_point[1].set(1, 101); 275 test_point[2].set(0, 100);test_point[2].set(1, 100); 276 test_point[3].set(0, 101);test_point[3].set(1, 100); 277 System.out.println("no Optimization:"+"\n too much time 。。。"+"\nuse median to do the optimization"); 278 // test1.eta=2.5;test1.test1(new Random(),inSide,test_point); 279 ArrayList<Vector> points=new ArrayList<Vector>(); 280 points.add(test_point[0]);points.add(test_point[1]); 281 points.add(test_point[2]);points.add(test_point[3]); 282 283 test1.median=new Vector(Vector.vector_median(points)); 284 test1.median.printvec(); 285 for(int i=0;i<4;i++) 286 test_point[i].sub(test1.median); 287 288 test1.eta=0.5; 289 test1.test1(new Random(),inSide,test_point); 290 test1.eta=0.01; 291 test1.test1(new Random(),inSide,test_point); 292 test1.V.add(test1.median); 293 test1.V.printvec(); 294 295 for(int i=0;i<4;i++) 296 test_point[i].add(test1.median); 297 //------ test three --------------------------------------------------------------/ 298 System.out.println("test 3: =========================="+ 299 "\nseperate the dataset into 30% testing part and 70% training part, " + 300 "\nusing the percentage to determine when to stop the learning,the max iteration is "+most); 301 points.clear(); 302 if(args.length==0) 303 points=Vector.read_data("dataset-2");//the dataset 304 else 305 points=Vector.read_data(args[0]); 306 307 int size=points.get(0).get_length(); 308 309 LinearClassifier lc=new LinearClassifier(size-1); 310 lc.median=Vector.get_sub_vector(Vector.vector_median(points), 0, size-2);; 311 lc.eta=0.4; 312 313 if(args.length>3) 314 lc.interval=new Integer(args[3]); 315 if(args.length>=3) 316 lc.eta=new Double(args[2]); 317 318 lc.get_vector(points,new Random()); 319 lc.train(); 320 lc.setpercent(); 321 System.out.println(" percentage: "+lc.percent+" test: "+lc.test_percent+" train: "+lc.train_percent); 322 323 int idtest[]=new int[lc.testpoints.length]; 324 for(int i=0;i<lc.testpoints.length;i++){ 325 if(lc.classify(lc.testpoints[i])) 326 idtest[i]=0; 327 else idtest[i]=1; 328 lc.testpoints[i].add(lc.median); 329 } 330 331 int idtrain[] =new int[lc.trainpoints.length]; 332 for(int i=0;i<lc.trainpoints.length;i++){ 333 if(lc.classify(lc.trainpoints[i])) 334 idtrain[i]=0; 335 else idtrain[i]=1; 336 lc.trainpoints[i].add(lc.median); 337 } 338 339 lc.V.add(lc.median); 340 lc.V.printvec(); 341 if(args.length<2) 342 { 343 Vector.write_data_withID("out-dataset", lc.testpoints, idtest); 344 Vector.write_data_withID("out-dataset", lc.trainpoints, idtrain,true); 345 } 346 else { 347 Vector.write_data_withID(args[1], lc.testpoints, idtest); 348 Vector.write_data_withID(args[1], lc.trainpoints, idtrain,true); 349 } 350 351 } 352 353 }