EI328 Final Project Review (III)
At the end of this project, we are required to implement a basic classifier on our own and make use of it in the Min-Max Modular Neural Network. Here we implemented a Logistic Regression classifier with a naive batch gradient descent optimization method.
The discriminant function is a sigmoid function as follows:
$y(\vec x)=\{1+\exp\{-(\vec w^T\cdot\vec x+b)\}\}^{-1}$
The objective function is the negative logarithm of likelihood:
$-\ln L=\sum_{n=1}^l \ln\{1+\exp\{-t_n\cdot(\vec w^T\cdot\vec x_n+b)\}\}$
The gradient can be calculated as follows:
$\nabla_{\vec w}(-\ln L)=\sum_{n=1}^l\{y(\vec x_n)-\frac{t_n+1}{2}\}\cdot \vec x_n$
Here is the source code of ./src/DIY.java:
/**
 * Argument Format:
 * <task # = 6-2 or 6-3> <number of threads> <min group size> <max group size> <group size step>
 **/

import java.util.concurrent.atomic.*;
import java.util.concurrent.*;
import java.util.*;
import java.io.*;

/**
 * One sparse feature of a sample: 1-based feature index plus its value
 * (the "idx:val" pair of the LIBSVM text format).
 */
class Node {
    public int idx;
    public double val;

    public Node(int idx,double val) {
        this.idx = idx;
        this.val = val;
    }
}

/**
 * Sparse data set. l = number of samples, n = highest feature index seen,
 * x = sparse feature rows, y = labels (+1 / -1 after preprocessing).
 */
class DataSet {
    /** This class mimics the Problem class in LIBLINEAR **/
    public int l, n;
    public List<List<Node>> x;
    public List<Integer> y;

    // Empty data set with a fixed sample count / feature dimension;
    // rows are appended by the caller (see DIY.train()).
    public DataSet(int l,int n) {
        x = new ArrayList<List<Node>>();
        y = new ArrayList<Integer>();
        this.l = l;
        this.n = n;
    }
    // Parses a LIBSVM-format file: "<label> <idx:val> <idx:val> ...".
    // l and n start at 0 and grow while reading.
    public DataSet(String path) {
        x = new ArrayList<List<Node>>();
        y = new ArrayList<Integer>();
        StringTokenizer tok = null;
        StringTokenizer tok1 = null;
        Scanner in = null;
        try {
            in = new Scanner(new File(path));
            for (;in.hasNextLine();l++) {
                List<Node> tmp = new ArrayList<Node>();
                tok = new StringTokenizer(in.nextLine());
                y.add(Integer.valueOf(tok.nextToken()));
                while (tok.hasMoreTokens()) {
                    tok1 = new StringTokenizer(tok.nextToken(),":");
                    int idx = Integer.parseInt(tok1.nextToken());
                    double val = Double.parseDouble(tok1.nextToken());
                    tmp.add(new Node(idx,val));
                    if (idx>n) {
                        n = idx;    // track dimensionality as the max index seen
                    }
                }
                x.add(tmp);
            }
            in.close();
        } catch (Exception e) {
            // NOTE(review): error is only logged; l/x/y may be partially filled.
            System.err.println("DATASET Error: "+e);
        }
    }
    /** Label of sample idx; throws on an out-of-range index. */
    public int getY(int idx) {
        if (idx<0||idx>=l) {
            throw new RuntimeException("DataSet: IndexOutOfBound");
        }
        return y.get(idx).intValue();
    }
}

/**
 * Logistic-regression classifier trained by plain batch gradient descent.
 * Discriminant: y(x) = 1 / (1 + exp(-(w^T x + b))); training minimizes the
 * negative log-likelihood (see the accompanying write-up).
 */
class Augur {
    /** This class mimics the Model class in LIBLINEAR **/
    private static int id = 0;      // global trained-group counter.
                                    // NOTE(review): ++id below is executed by several
                                    // worker threads without synchronization - may race.
    private final double lamb = 0;  // L2 regularization weight (disabled)
    private DataSet prob;           // training subproblem
    private double[] w;             // weights; w[i] belongs to feature index i+1
    private double b;               // bias term
    private int n;                  // feature dimension (= prob.n)

    public Augur() {}               // sentinel-only instance (see DIY.modSentinel)
    // Trains on prob with a fixed 1000 iterations of batch gradient descent.
    public Augur(DataSet prob) {
        this.prob = prob;
        n = prob.n;
        w = new double[n];
        double[] g = new double[n+1];   // g[0..n-1] = dL/dw, g[n] = dL/db
        for (int itr=0;itr<1000;itr++) {
            calGrad(g);
            double len = norm(g);
            if (itr%20==0) {
                System.out.println("itr "+itr+": norm(g) = "+len);
            }
            //double step = optStep(g,len);
            // Step size is the gradient "norm" itself; the golden-section line
            // search above is disabled. NOTE(review): norm() actually returns
            // the fourth root of sum(g^2) - see square()/norm() below.
            double step = len;
            for (int i=0;i<n;i++) {
                w[i] -= step*g[i];
            }
            b -= step*g[n];
        }
        System.out.println("\tGroup "+(++id)+" Done");
    }
    /** Hard prediction: +1 if sigmoid > 0.5 (i.e. w^T x + b > 0), else -1. */
    public int predict(List<Node> x) {
        return (sigmoid(x)>0.5)? 1:-1;
    }
    // Golden-section line search for a step size on [0,len]; currently
    // unused (its call site in the constructor is commented out).
    private double optStep(double[] g,double len) {
        double[] val = new double[2];
        double p = 0, r = len;
        double q1 = .382*r-.618*p;  // NOTE(review): interior points of [p,r] are usually
        double q2 = .618*r+.382*p;  // p+.382*(r-p) and p+.618*(r-p); verify the signs here.
        val[0] = error(q1,g);
        val[1] = error(q2,g);
        for (int i=0;i<5;i++) {     // 5 interval-shrinking rounds
            if (val[0]<val[1]) {
                r = q2;
                q2 = q1;
                q1 = .382*r-.618*p;
                val[1] = val[0];
                val[0] = error(q1,g);
            } else {
                p = q1;
                q1 = q2;
                q2 = .618*r+.382*p;
                val[0] = val[1];
                val[1] = error(q2,g);
            }
        }
        return (val[0]<val[1])? q1:q2;
    }
    // sigmoid(w^T x + b). Features with index > n are skipped: test data may
    // contain indices never seen during training (e.g. the appended 5001:1.00).
    private double sigmoid(List<Node> x) {
        double val = b;
        for (Node node: x) {
            if (node.idx<=n) {
                val += w[node.idx-1]*node.val;
            }
        }
        return 1./(1+Math.exp(-val));
    }
    // Objective value after taking a step of size `step` along -g:
    // sum_i ln(1+exp(-t_i*(w'^T x_i + b'))) plus the (disabled) L2 term.
    private double error(double step,double[] g) {
        double val = .5*lamb*square(w,b);   // NOTE(review): square() returns a norm, not
                                            // a squared norm; harmless only while lamb==0.
        for (int i=0;i<prob.l;i++) {
            double tmp = b-step*g[n];
            for (Node node:prob.x.get(i)) {
                tmp += (w[node.idx-1]-step*g[node.idx-1])*node.val;
            }
            val += Math.log(1+Math.exp(-prob.getY(i)*tmp));
        }
        return val;
    }
    // Gradient into g: g[i] = sum_n (y(x_n) - (t_n+1)/2) * x_n[i] (+ L2 term);
    // g[n] holds the bias component.
    private void calGrad(double[] g) {
        for (int i=0;i<n;i++) {
            g[i] *= w[i]*lamb;  // NOTE(review): "*=" only works because lamb==0
                                // (anything*0 zeroes g[i]); for lamb>0 this should be "=".
        }
        g[n] = b*lamb;
        for (int i=0;i<prob.l;i++) {
            // (t+1)/2 maps label -1 -> 0 and +1 -> 1; integer division is exact here.
            double delt = sigmoid(prob.x.get(i))-(prob.getY(i)+1)/2;
            for (Node node:prob.x.get(i)) {
                g[node.idx-1] += node.val*delt;
            }
            g[prob.n] += delt;  // prob.n == n, so this is the bias slot g[n]
        }
    }
    // Despite the name, returns sqrt(cst^2 + sum(vect[i]^2)) - the 2-norm of
    // (vect,cst), NOT its square. NOTE(review): likely a naming/logic slip.
    private double square(double[] vect,double cst) {
        double val = cst*cst;
        for (int i=0;i<vect.length;i++) {
            val += vect[i]*vect[i];
        }
        return Math.sqrt(val);
    }
    // NOTE(review): since square() already applies a sqrt, this is the fourth
    // root of sum(vect^2), not the Euclidean norm. Verify the intended step size.
    private double norm(double[] vect) {
        return Math.sqrt(square(vect,0));
    }
}

/**
 * Shared state and utilities for the experiment: data loading/preprocessing,
 * tee-style printing (console + result file), phase timing, and
 * confusion-matrix statistics.
 */
class AbstractDIY extends Thread {
    protected static DataSet train, test;   // loaded once by init()
    protected static PrintWriter out;       // result file, opened in DIY.main()
    protected static long start, end;       // wall-clock timing markers

    // Converts the raw data files to LIBSVM format on first run, then loads them.
    protected static void init() {
        try {
            System.out.println();
            if (!(new File("./data/new_train.txt")).exists()) {
                System.out.println("Preprocessing: please wait a minute!");
                preproc("./data/","train.txt","");
                // Test rows get an extra constant feature "5001:1.00" appended
                // (presumably to pad the feature space - TODO confirm intent).
                preproc("./data/","test.txt","5001:1.00");
                System.out.println("Ready!");
            }
            System.out.println("Reading the files ...\tWait please ~ ~");
            train = new DataSet("./data/new_train.txt");
            test = new DataSet("./data/new_test.txt");
        } catch (Exception e) {
            System.out.println("INIT Error: "+e);
        }
    }
    /** Prints to both stdout and the result file (no trailing newline). */
    protected static void print(String str) {
        try {
            System.out.print(str);
            out.print(str);
        } catch (Exception e) {
            System.err.println("PRINT Error: "+e);
        }
    }
    /** Prints a line to both stdout and the result file. */
    protected static void println(String str) {
        try {
            System.out.println(str);
            out.println(str);
        } catch (Exception e) {
            System.err.println("PRINTLN Error: "+e);
        }
    }
    /** Reports elapsed milliseconds for the training (true) or testing (false) phase. */
    protected static void printTime(long time,boolean train) {
        try {
            if (train) {
                println("\tTraining:\t"+time+"ms elapsed");
            } else {
                println("\tTesting:\t"+time+"ms elapsed");
            }
        } catch (Exception e) {
            System.err.println("PRINTTIME Error: "+e);
        }
    }
    // Confusion matrix plus accuracy, F1, TPR and FPR for the +1/-1
    // predictions in res, evaluated against the test labels.
    protected static void stats(int[] res) {
        try {
            int truePos=0, falsePos=0, falseNeg=0, trueNeg=0;
            for (int i=0;i<test.l;i++) {
                if (test.getY(i)>0) {
                    if (res[i]>0) {
                        truePos++;
                    } else {
                        falseNeg++;
                    }
                } else {
                    if (res[i]>0) {
                        falsePos++;
                    } else {
                        trueNeg++;
                    }
                }
            }
            System.out.println("\t"+truePos+"\t"+falsePos+"\t"+falseNeg+"\t"+trueNeg);
            double acc = (truePos+trueNeg+.0)/test.l;
            println("\tacc\t= "+acc);
            double p = (truePos+.0)/(truePos+falsePos);     // precision
            double r = (truePos+.0)/(truePos+falseNeg);     // recall
            double f1 = 2*r*p/(r+p);
            println("\tF1\t= "+f1);
            double tpr = (truePos+.0)/(truePos+falseNeg);
            double fpr = (falsePos+.0)/(falsePos+trueNeg);
            println("\tTPR\t= "+tpr);
            println("\tFPR\t= "+fpr);
            println("");
        } catch (Exception e) {
            System.err.println("STATS Error: "+e);
        }
    }
    // Rewrites dir/filename as dir/new_filename in LIBSVM form: the first token
    // becomes label "1" when it starts with 'A', otherwise "-1"; `tail` is
    // appended to every output line. Prints the positive/total sample count.
    protected static void preproc(String dir,String filename,String tail) throws IOException {
        Scanner fin = new Scanner(new FileInputStream(dir+filename));
        PrintWriter fout = new PrintWriter(new FileOutputStream(dir+"new_"+filename));
        int cnt=0, total=0;
        while (fin.hasNextLine()) {
            StringTokenizer tok = new StringTokenizer(fin.nextLine());
            if (tok.nextToken().charAt(0)=='A') {
                fout.print("1 ");
                cnt++;
            } else {
                fout.print("-1 ");
            }
            while (tok.hasMoreTokens()) {
                fout.print(tok.nextToken()+" ");
            }
            fout.println(tail);
            total++;
        }
        System.out.println(filename+":\t"+cnt+"/"+total);
        fout.close();
        fin.close();
    }
}


/**
 * Min-Max Modular network driver. The training set is split into
 * posGrpNum x negGrpNum subproblems (each pairs one positive group with one
 * negative group); every subproblem trains an independent Augur classifier.
 * At test time, modules sharing a positive group are combined with MIN, and
 * the groups are combined with MAX. Worker threads consume subproblems from
 * probBuf and trained models from modBuf (producer/consumer queues with
 * dedicated sentinel objects as end-of-stream markers).
 * NOTE(review): Subprob and DataItem are declared elsewhere in ./src.
 */
public class DIY extends AbstractDIY {
    private static int NUM;             // max samples per group (current sweep value)
    private static int NUM_OF_THREAD;

    private static Semaphore waitTask;  // workers signal "training done"
    private static Semaphore nextTask;  // main thread signals "testing may start"
    private static Semaphore timeMutex; // protects the end-of-training timestamp
    private static Subprob probSentinel;    // end-of-stream marker for probBuf
    private static Augur modSentinel;       // end-of-stream marker for modBuf

    private static BlockingQueue<Subprob> probBuf;  // pending training subproblems
    private static BlockingQueue<Augur> modBuf;     // trained models awaiting testing
    private static BlockingQueue<Integer> minIdx;   // positive-group index per model

    // minResult[g][i]: MIN-combined vote of positive group g on test sample i
    private static AtomicInteger[][] minResult;

    static {
        try {
            init();
            waitTask = new Semaphore(0);
            nextTask = new Semaphore(0);
            timeMutex = new Semaphore(1);
            probSentinel = new Subprob(-1);
            modSentinel = new Augur();
            probBuf = new LinkedBlockingQueue<Subprob>();
            modBuf = new LinkedBlockingQueue<Augur>();
            minIdx = new LinkedBlockingQueue<Integer>();
        } catch (Exception e) {
            System.err.println("STATIC Error: "+e);
        }
    }
    private static void separate(List<Integer> poslst,List<Integer> neglst,boolean prior) {
        // List the positive samples and negative samples:
        // Precondition: poslst and neglst are non-null empty lists
        // Postcondition: poslst and neglst are filled with indices of positive
        // and negative training data respectively
        if (!prior) {
            // Random decomposition (task 6-2): shuffle both index lists.
            for (int i=0;i<train.l;i++) {
                if (train.getY(i)>0) {
                    poslst.add(new Integer(i));
                } else {
                    neglst.add(new Integer(i));
                }
            }
            Random rand = new Random();
            Collections.shuffle(poslst,rand);
            Collections.shuffle(neglst,rand);
        } else {
            // Prior-knowledge decomposition (task 6-3): order samples by the
            // first 3 characters of the raw training line, so similar samples
            // land in the same group.
            try {
                System.out.println("Gathering Prior Knowledge ... ");
                BufferedReader pin = new BufferedReader(new FileReader("./data/train.txt"));
                List<DataItem> posData = new ArrayList<DataItem>();
                List<DataItem> negData = new ArrayList<DataItem>();
                for (int i=0;i<train.l;i++) {
                    String line = pin.readLine();
                    if (train.getY(i)>0) {
                        posData.add(new DataItem(i,line.substring(0,3)));
                    } else {
                        negData.add(new DataItem(i,line.substring(0,3)));
                    }
                }
                pin.close();
                //System.in.read();
                Collections.sort(posData);
                Collections.sort(negData);
                for (int i=0;i<posData.size();i++) {
                    poslst.add(posData.get(i).getValue());
                }
                for (int i=0;i<negData.size();i++) {
                    neglst.add(negData.get(i).getValue());
                }
                System.out.println("Ready!");
            } catch (Exception e) {
                System.err.println("SEPARATE Error: "+e);
            }
        }
    }
    // Builds the posGrpNum x negGrpNum subproblems, enqueues them on probBuf
    // (followed by the stop sentinel), and allocates the MIN result matrix.
    private static void distribute(boolean prior) {
        List<Integer> poslst = new ArrayList<Integer>();
        List<Integer> neglst = new ArrayList<Integer>();
        separate(poslst,neglst,prior);
        // Group the positive samples and negative samples:
        int posGrpNum = (poslst.size()+NUM-1)/NUM;  // ceil(size/NUM)
        int negGrpNum = (neglst.size()+NUM-1)/NUM;
        int[] posGrps = new int[posGrpNum+1];       // group boundary offsets
        int[] negGrps = new int[negGrpNum+1];
        System.out.println("\tTotally "+posGrpNum*negGrpNum+" groups:");
        for (int i=0;i<posGrpNum;i++) {
            // Balanced split: each group gets floor or ceil of size/posGrpNum.
            posGrps[i+1] = posGrps[i]+(poslst.size()+i)/posGrpNum;
        }
        for (int i=0;i<negGrpNum;i++) {
            negGrps[i+1] = negGrps[i]+(neglst.size()+i)/negGrpNum;
        }
        try { // Add tasks to the buffer:
            for (int i=0;i<posGrpNum;i++) {
                for (int j=0;j<negGrpNum;j++) {
                    Subprob sub = new Subprob(i);   // tagged with its positive-group index
                    for (int k=posGrps[i];k<posGrps[i+1];k++) {
                        sub.add(poslst.get(k));
                    }
                    for (int k=negGrps[j];k<negGrps[j+1];k++) {
                        sub.add(neglst.get(k));
                    }
                    probBuf.put(sub);
                }
            }
        } catch (Exception e) {
            System.err.println("DISTRIBUTE Error 1: "+e);
        }
        try { // Prepare for the MIN modules:
            minResult = new AtomicInteger[posGrpNum][test.l];
            for (int i=0;i<posGrpNum;i++) {
                for (int j=0;j<test.l;j++) {
                    minResult[i][j] = new AtomicInteger(1);     // start optimistic (+1)
                }
            }
            probBuf.put(probSentinel); // "STOP TRAINING" signal
        } catch (Exception e) {
            System.err.println("DISTRIBUTE Error 2: "+e);
        }
    }
    // Runs one full train+test pass with NUM_OF_THREAD workers and reports
    // timing and statistics. MAX combination over the MIN rows happens here.
    private static void main_thread(boolean prior) {
        try {
            start = System.currentTimeMillis(); // start training
            end = -1; // still pending
            DIY[] tsks = new DIY[NUM_OF_THREAD];
            for (int i=0;i<NUM_OF_THREAD;i++) {
                tsks[i] = new DIY();
                tsks[i].start();
            }
            distribute(prior);
            for (int i=0;i<NUM_OF_THREAD;i++) {
                waitTask.acquire(); // awaiting sub-threads
            }
            for (int i=0;i<NUM_OF_THREAD;i++) {
                nextTask.release(); // testing enabled
            }
            start = System.currentTimeMillis(); // start testing
            for (int i=0;i<NUM_OF_THREAD;i++) {
                tsks[i].join();
            }
            // MAX module: a test sample is positive if ANY positive group's
            // MIN-combined vote is still positive.
            int[] res = new int[test.l];
            Arrays.fill(res,-1);
            for (int i=0;i<minResult.length;i++) {
                for (int j=0;j<test.l;j++) {
                    if (minResult[i][j].get()>0) {
                        res[j] = 1; // MAX Module
                    }
                }
            }
            end = System.currentTimeMillis(); // finish testing
            printTime(end-start,false);
            stats(res);
        } catch (Exception e) {
            System.err.println("MAIN_THRAD Error: "+e);
        }
    }
    // Entry point. Sweeps the group size from args[2] to args[3] in steps of
    // args[4]; args[0] picks random (6-2) vs prior-knowledge (6-3) grouping.
    public static void main(String[] args) {
        try {
            out = new PrintWriter(new FileWriter("./result/task"+args[0]+"_result.out"));
            NUM_OF_THREAD = Integer.parseInt(args[1]);
            int numMin = Integer.parseInt(args[2]);
            int numMax = Integer.parseInt(args[3]);
            int numStep = Integer.parseInt(args[4]);
            for (NUM=numMin;NUM<=numMax;NUM+=numStep) {
                print("Group Size = "+NUM+",\t");
                println(NUM_OF_THREAD+" Threads");
                // Drop leftover sentinels/tasks from the previous sweep value.
                probBuf.clear();
                modBuf.clear();
                minIdx.clear();
                if (args[0].equals("6-2")) {
                    main_thread(false);
                } else {
                    main_thread(true);
                }
            }
            out.close();
        } catch (Exception e) {
            System.err.println("MAIN Error: "+e);
        }
    }

    // Worker life cycle: train until the sentinel arrives, rendezvous with the
    // main thread (waitTask/nextTask), then test the trained models.
    public void run() {
        try {
            train();
            waitTask.release();
            nextTask.acquire();
            test();
        } catch (Exception e) {
            System.err.println("RUN Error: "+e);
        }
    }
    // Consumes subproblems from probBuf, trains one Augur per subproblem and
    // publishes it (plus its positive-group index) for the testing phase.
    private void train() {
        Subprob sub = null;
        try {
            while (true) {
                sub = probBuf.take();
                if (sub==probSentinel) { // signal of termination
                    timeMutex.acquire();
                    if (end<0) { // finish training
                        // Only the first worker to see the sentinel stamps the time.
                        end = System.currentTimeMillis();
                        printTime(end-start,true);
                    }
                    timeMutex.release();
                    probBuf.put(sub);   // re-publish so the other workers stop too
                    break;
                }
                DataSet prob = new DataSet(sub.size(),train.n);
                for (int i=0;i<sub.size();i++) {
                    int pos = sub.getItem(i).intValue();
                    prob.x.add(train.x.get(pos));
                    prob.y.add(train.y.get(pos));
                }
                modBuf.put(new Augur(prob));
                minIdx.put(sub.getIndex());
            }
            modBuf.put(modSentinel);    // one sentinel per worker
        } catch (Exception e) {
            System.err.println("TRAIN Error: "+e);
        }
    }
    // Consumes trained models and folds their predictions into the MIN rows:
    // any negative prediction forces the group's vote for that sample to -1.
    private void test() {
        Augur model = null;
        int idx = -1;
        try {
            while (true) {
                model = modBuf.take();
                if (model==modSentinel) { // signal of termination
                    modBuf.put(model);
                    break;
                }
                // NOTE(review): modBuf and minIdx are separate queues; with several
                // producers/consumers a model may be paired with another model's
                // group index, and poll() may even return null (NPE) - verify.
                idx = minIdx.poll().intValue();
                for (int i=0;i<test.l;i++) {
                    if (model.predict(test.x.get(i))<0) {
                        minResult[idx][i].getAndSet(-1); // MIN Modules
                    }
                }
            }
        } catch (Exception e) {
            System.err.println("TEST Error:"+e);
        }
    }
}
The final result is quite reassuring in terms of F1 value and accuracy, whereas time performance leaves much to be desired: