EI328 Final Project Review (III)

 

  In the final part of this project, we were required to implement a basic classifier ourselves and use it inside the Min-Max Modular Neural Network. We built a Logistic Regression classifier optimized with naive batch gradient descent.

  The discriminant function is the sigmoid function:

      $y(\vec x)=\left\{1+\exp\left[-(\vec w^T\vec x+b)\right]\right\}^{-1}$

  The objective function is the negative log-likelihood:

      $-\ln L=\sum_{n=1}^{l}\ln\left\{1+\exp\left[-t_n(\vec w^T\vec x_n+b)\right]\right\}$, where $t_n\in\{-1,+1\}$ is the label of the $n$-th training sample.

  The gradients can be calculated as follows:

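      $\nabla_{\vec w}(-\ln L)=\sum_{n=1}^{l}\left\{y(\vec x_n)-\frac{t_n+1}{2}\right\}\vec x_n,\qquad \frac{\partial(-\ln L)}{\partial b}=\sum_{n=1}^{l}\left\{y(\vec x_n)-\frac{t_n+1}{2}\right\}$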

 

  Here is the source code of ./src/DIY.java:

  1 /**
  2 * Argument Format:
  3 * <task # = 6-2 or 6-3> <number of threads> <min group size> <max group size> <group size step>
  4 **/
  5 
  6 import java.util.concurrent.atomic.*;
  7 import java.util.concurrent.*;
  8 import java.util.*;
  9 import java.io.*;
 10 
 /**
  * One (index, value) entry of a sparse feature vector.
  */
 class Node {
     /** Feature index; 1-based, since the classifier maps it to weight slot idx-1. */
     public int idx;
     /** Feature value. */
     public double val;

     /**
      * @param idx feature index
      * @param val feature value
      */
     public Node(int idx,double val) {
         this.val = val;
         this.idx = idx;
     }
 }
 20 
 21 class DataSet {
 22     /** This class mimic Problem class in LIBLINEAR **/
 23     public int l, n;
 24     public List<List<Node>> x;
 25     public List<Integer> y;
 26     
 27     public DataSet(int l,int n) {
 28         x = new ArrayList<List<Node>>();
 29         y = new ArrayList<Integer>();
 30         this.l = l;
 31         this.n = n;
 32     }
 33     public DataSet(String path) {
 34         x = new ArrayList<List<Node>>();
 35         y = new ArrayList<Integer>();
 36         StringTokenizer tok = null;
 37         StringTokenizer tok1 = null;
 38         Scanner in = null;
 39         try {
 40             in = new Scanner(new File(path));
 41             for (;in.hasNextLine();l++) {
 42                 List<Node> tmp = new ArrayList<Node>();
 43                 tok = new StringTokenizer(in.nextLine());
 44                 y.add(Integer.valueOf(tok.nextToken()));
 45                 while (tok.hasMoreTokens()) {
 46                     tok1 = new StringTokenizer(tok.nextToken(),":");
 47                     int idx = Integer.parseInt(tok1.nextToken());
 48                     double val = Double.parseDouble(tok1.nextToken());
 49                     tmp.add(new Node(idx,val));
 50                     if (idx>n) {
 51                         n = idx;
 52                     }
 53                 }
 54                 x.add(tmp);
 55             }
 56             in.close();
 57         } catch (Exception e) {
 58             System.err.println("DATASET Error: "+e);
 59         }
 60     }
 61     public int getY(int idx) {
 62         if (idx<0||idx>=l) {
 63             throw new RuntimeException("DataSet: IndexOutOfBound");
 64         }
 65         return y.get(idx).intValue();
 66     }
 67 }
 68 
 69 class Augur {
 70     /** This class mimic Model class in LIBLINEAR **/
 71     private static int id = 0;
 72     private final double lamb = 0;
 73     private DataSet prob;
 74     private double[] w;
 75     private double b;
 76     private int n;
 77     
 78     public Augur() {}
 79     public Augur(DataSet prob) {
 80         this.prob = prob;
 81         n = prob.n;
 82         w = new double[n];
 83         double[] g = new double[n+1];
 84         for (int itr=0;itr<1000;itr++) {
 85             calGrad(g);
 86             double len = norm(g);
 87             if (itr%20==0) {
 88                 System.out.println("itr "+itr+": norm(g) = "+len);
 89             }
 90             //double step = optStep(g,len);
 91             double step = len;
 92             for (int i=0;i<n;i++) {
 93                 w[i] -= step*g[i];
 94             }
 95             b -= step*g[n];
 96         }
 97         System.out.println("\tGroup "+(++id)+" Done");
 98     }
 99     public int predict(List<Node> x) {
100         return (sigmoid(x)>0.5)? 1:-1;
101     }
102     private double optStep(double[] g,double len) {
103         double[] val = new double[2];
104         double p = 0, r = len;
105         double q1 = .382*r-.618*p;
106         double q2 = .618*r+.382*p;
107         val[0] = error(q1,g);
108         val[1] = error(q2,g);
109         for (int i=0;i<5;i++) {
110             if (val[0]<val[1]) {
111                 r = q2;
112                 q2 = q1;
113                 q1 = .382*r-.618*p;
114                 val[1] = val[0];
115                 val[0] = error(q1,g);
116             } else {
117                 p = q1;
118                 q1 = q2;
119                 q2 = .618*r+.382*p;
120                 val[0] = val[1];
121                 val[1] = error(q2,g);
122             }
123         }
124         return (val[0]<val[1])? q1:q2;
125     }
126     private double sigmoid(List<Node> x) {
127         double val = b;
128         for (Node node: x) {
129             if (node.idx<=n) {
130                 val += w[node.idx-1]*node.val;
131             }
132         }
133         return 1./(1+Math.exp(-val));
134     }
135     private double error(double step,double[] g) {
136         double val = .5*lamb*square(w,b);
137         for (int i=0;i<prob.l;i++) {
138             double tmp = b-step*g[n];
139             for (Node node:prob.x.get(i)) {
140                 tmp += (w[node.idx-1]-step*g[node.idx-1])*node.val;
141             }
142             val += Math.log(1+Math.exp(-prob.getY(i)*tmp));
143         }
144         return val;
145     }
146     private void calGrad(double[] g) {
147         for (int i=0;i<n;i++) {
148             g[i] *= w[i]*lamb;
149         }
150         g[n] = b*lamb;
151         for (int i=0;i<prob.l;i++) {
152             double delt = sigmoid(prob.x.get(i))-(prob.getY(i)+1)/2;
153             for (Node node:prob.x.get(i)) {
154                 g[node.idx-1] += node.val*delt;
155             }
156             g[prob.n] += delt;
157         }
158     }
159     private double square(double[] vect,double cst) {
160         double val = cst*cst;
161         for (int i=0;i<vect.length;i++) {
162             val += vect[i]*vect[i];
163         }
164         return Math.sqrt(val);
165     }
166     private double norm(double[] vect) {
167         return Math.sqrt(square(vect,0));
168     }
169 }
170 
171 class AbstractDIY extends Thread {
172     protected static DataSet train, test;
173     protected static PrintWriter out;
174     protected static long start, end;
175     
176     protected static void init() {
177         try {
178             System.out.println();
179             if (!(new File("./data/new_train.txt")).exists()) {
180                 System.out.println("Preprocessing: please wait a minute!");
181                 preproc("./data/","train.txt","");
182                 preproc("./data/","test.txt","5001:1.00");
183                 System.out.println("Ready!");
184             }
185             System.out.println("Reading the files ...\tWait please ~ ~");
186             train = new DataSet("./data/new_train.txt");
187             test = new DataSet("./data/new_test.txt");
188         } catch (Exception e) {
189             System.out.println("INIT Error: "+e);
190         }
191     }
192     protected static void print(String str) {
193         try {
194             System.out.print(str);
195             out.print(str);
196         } catch (Exception e) {
197             System.err.println("PRINT Error: "+e);
198         }
199     }
200     protected static void println(String str) {
201         try {
202             System.out.println(str);
203             out.println(str);
204         } catch (Exception e) {
205             System.err.println("PRINTLN Error: "+e);
206         }
207     }
208     protected static void printTime(long time,boolean train) {
209         try {
210             if (train) {
211                 println("\tTraining:\t"+time+"ms elapsed");
212             } else {
213                 println("\tTesting:\t"+time+"ms elapsed");
214             }
215         } catch (Exception e) {
216             System.err.println("PRINTTIME Error: "+e);
217         }
218     }
219     protected static void stats(int[] res) {
220         try {
221             int truePos=0, falsePos=0, falseNeg=0, trueNeg=0;
222             for (int i=0;i<test.l;i++) {
223                 if (test.getY(i)>0) {
224                     if (res[i]>0) {
225                         truePos++;
226                     } else {
227                         falseNeg++;
228                     }
229                 } else {
230                     if (res[i]>0) {
231                         falsePos++;
232                     } else {
233                         trueNeg++;
234                     }
235                 }
236             }
237             System.out.println("\t"+truePos+"\t"+falsePos+"\t"+falseNeg+"\t"+trueNeg);
238             double acc = (truePos+trueNeg+.0)/test.l;
239             println("\tacc\t= "+acc);
240             double p = (truePos+.0)/(truePos+falsePos);
241             double r = (truePos+.0)/(truePos+falseNeg);
242             double f1 = 2*r*p/(r+p);
243             println("\tF1\t= "+f1);
244             double tpr = (truePos+.0)/(truePos+falseNeg);
245             double fpr = (falsePos+.0)/(falsePos+trueNeg);
246             println("\tTPR\t= "+tpr);
247             println("\tFPR\t= "+fpr);
248             println("");
249         } catch (Exception e) {
250             System.err.println("STATS Error: "+e);
251         }
252     }
253     protected static void preproc(String dir,String filename,String tail) throws IOException {
254         Scanner fin = new Scanner(new FileInputStream(dir+filename));
255         PrintWriter fout = new PrintWriter(new FileOutputStream(dir+"new_"+filename));
256         int cnt=0, total=0;
257         while (fin.hasNextLine()) {
258             StringTokenizer tok = new StringTokenizer(fin.nextLine());
259             if (tok.nextToken().charAt(0)=='A') {
260                 fout.print("1 ");
261                 cnt++;
262             } else {
263                 fout.print("-1 ");
264             }
265             while (tok.hasMoreTokens()) {
266                 fout.print(tok.nextToken()+" ");
267             }
268             fout.println(tail);
269             total++;
270         }
271         System.out.println(filename+":\t"+cnt+"/"+total);
272         fout.close();
273         fin.close();
274     }
275 }
276 
277 
278 public class DIY extends AbstractDIY {
279     private static int NUM;
280     private static int NUM_OF_THREAD;
281     
282     private static Semaphore waitTask;
283     private static Semaphore nextTask;
284     private static Semaphore timeMutex;
285     private static Subprob probSentinel;
286     private static Augur modSentinel;
287     
288     private static BlockingQueue<Subprob> probBuf;
289     private static BlockingQueue<Augur> modBuf;
290     private static BlockingQueue<Integer> minIdx;
291     
292     private static AtomicInteger[][] minResult;
293     
294     static {
295         try {
296             init();
297             waitTask = new Semaphore(0);
298             nextTask = new Semaphore(0);
299             timeMutex = new Semaphore(1);
300             probSentinel = new Subprob(-1);
301             modSentinel = new Augur();
302             probBuf = new LinkedBlockingQueue<Subprob>();
303             modBuf = new LinkedBlockingQueue<Augur>();
304             minIdx = new LinkedBlockingQueue<Integer>();
305         } catch (Exception e) {
306             System.err.println("STATIC Error: "+e);
307         }
308     }
309     private static void separate(List<Integer> poslst,List<Integer> neglst,boolean prior) {
310         // List the positive samples and negative samples:
311         // Precondition: poslst and neglst are non-null empty lists
312         // Postcondition: poslst and neglst are filled with indices of positive
313         //            and negative training data respectively
314         if (!prior) {
315             for (int i=0;i<train.l;i++) {
316                 if (train.getY(i)>0) {
317                     poslst.add(new Integer(i));
318                 } else {
319                     neglst.add(new Integer(i));
320                 }
321             }
322             Random rand = new Random();
323             Collections.shuffle(poslst,rand);
324             Collections.shuffle(neglst,rand);
325         } else {
326             try {
327                 System.out.println("Gathering Prior Knowledge ... ");
328                 BufferedReader pin = new BufferedReader(new FileReader("./data/train.txt"));
329                 List<DataItem> posData = new ArrayList<DataItem>();
330                 List<DataItem> negData = new ArrayList<DataItem>();
331                 for (int i=0;i<train.l;i++) {
332                     String line = pin.readLine();
333                     if (train.getY(i)>0) {
334                         posData.add(new DataItem(i,line.substring(0,3)));
335                     } else {
336                         negData.add(new DataItem(i,line.substring(0,3)));
337                     }
338                 }
339                 pin.close();
340                 //System.in.read();
341                 Collections.sort(posData);
342                 Collections.sort(negData);
343                 for (int i=0;i<posData.size();i++) {
344                     poslst.add(posData.get(i).getValue());
345                 }
346                 for (int i=0;i<negData.size();i++) {
347                     neglst.add(negData.get(i).getValue());
348                 }
349                 System.out.println("Ready!");
350             } catch (Exception e) {
351                 System.err.println("SEPARATE Error: "+e);
352             }
353         }
354     }
355     private static void distribute(boolean prior) {
356         List<Integer> poslst = new ArrayList<Integer>();
357         List<Integer> neglst = new ArrayList<Integer>();
358         separate(poslst,neglst,prior);
359         // Group the positive samples and negative samples:
360         int posGrpNum = (poslst.size()+NUM-1)/NUM;
361         int negGrpNum = (neglst.size()+NUM-1)/NUM;
362         int[] posGrps = new int[posGrpNum+1];
363         int[] negGrps = new int[negGrpNum+1];
364         System.out.println("\tTotally "+posGrpNum*negGrpNum+" groups:");
365         for (int i=0;i<posGrpNum;i++) {
366             posGrps[i+1] = posGrps[i]+(poslst.size()+i)/posGrpNum;
367         }
368         for (int i=0;i<negGrpNum;i++) {
369             negGrps[i+1] = negGrps[i]+(neglst.size()+i)/negGrpNum;
370         }
371         try {    // Add tasks to the buffer:
372             for (int i=0;i<posGrpNum;i++) {
373                 for (int j=0;j<negGrpNum;j++) {
374                     Subprob sub = new Subprob(i);
375                     for (int k=posGrps[i];k<posGrps[i+1];k++) {
376                         sub.add(poslst.get(k));
377                     }
378                     for (int k=negGrps[j];k<negGrps[j+1];k++) {
379                         sub.add(neglst.get(k));
380                     }
381                     probBuf.put(sub);
382                 }
383             }
384         } catch (Exception e) {
385             System.err.println("DISTRIBUTE Error 1: "+e);
386         }
387         try {    // Prepare for the MIN modules:
388             minResult = new AtomicInteger[posGrpNum][test.l];
389             for (int i=0;i<posGrpNum;i++) {
390                 for (int j=0;j<test.l;j++) {
391                     minResult[i][j] = new AtomicInteger(1);
392                 }
393             }
394             probBuf.put(probSentinel);    // "STOP TRAINING" signal
395         } catch (Exception e) {
396             System.err.println("DISTRIBUTE Error 2: "+e);
397         }
398     }
399     private static void main_thread(boolean prior) {
400         try {
401             start = System.currentTimeMillis();        // start training
402             end = -1;                                // still pending
403             DIY[] tsks = new DIY[NUM_OF_THREAD];
404             for (int i=0;i<NUM_OF_THREAD;i++) {
405                 tsks[i] = new DIY();
406                 tsks[i].start();
407             }
408             distribute(prior);
409             for (int i=0;i<NUM_OF_THREAD;i++) {
410                 waitTask.acquire();                    // awaiting sub-threads
411             }
412             for (int i=0;i<NUM_OF_THREAD;i++) {
413                 nextTask.release();                    // testing enabled
414             }
415             start = System.currentTimeMillis();        // start testing
416             for (int i=0;i<NUM_OF_THREAD;i++) {
417                 tsks[i].join();
418             }
419             int[] res = new int[test.l];
420             Arrays.fill(res,-1);
421             for (int i=0;i<minResult.length;i++) {
422                 for (int j=0;j<test.l;j++) {
423                     if (minResult[i][j].get()>0) {
424                         res[j] = 1;                    // MAX Module
425                     }
426                 }
427             }
428             end = System.currentTimeMillis();        // finish testing
429             printTime(end-start,false);
430             stats(res);
431         } catch (Exception e) {
432             System.err.println("MAIN_THRAD Error: "+e);
433         }
434     }
435     public static void main(String[] args) {
436         try {
437             out = new PrintWriter(new FileWriter("./result/task"+args[0]+"_result.out"));
438             NUM_OF_THREAD = Integer.parseInt(args[1]);
439             int numMin = Integer.parseInt(args[2]);
440             int numMax = Integer.parseInt(args[3]);
441             int numStep = Integer.parseInt(args[4]);
442             for (NUM=numMin;NUM<=numMax;NUM+=numStep) {
443                 print("Group Size = "+NUM+",\t");
444                 println(NUM_OF_THREAD+" Threads");
445                 probBuf.clear();
446                 modBuf.clear();
447                 minIdx.clear();
448                 if (args[0].equals("6-2")) {
449                     main_thread(false);
450                 } else {
451                     main_thread(true);
452                 }
453             }
454             out.close();
455         } catch (Exception e) {
456             System.err.println("MAIN Error: "+e);
457         }
458     }
459     
460     public void run() {
461         try {
462             train();
463             waitTask.release();
464             nextTask.acquire();
465             test();
466         } catch (Exception e) {
467             System.err.println("RUN Error: "+e);
468         }
469     }
470     private void train() {
471         Subprob sub = null;
472         try {
473             while (true) {
474                 sub = probBuf.take();
475                 if (sub==probSentinel) {    // signal of termination
476                     timeMutex.acquire();
477                     if (end<0) {            // finish training
478                         end = System.currentTimeMillis();
479                         printTime(end-start,true);
480                     }
481                     timeMutex.release();
482                     probBuf.put(sub);
483                     break;
484                 }
485                 DataSet prob = new DataSet(sub.size(),train.n);
486                 for (int i=0;i<sub.size();i++) {
487                     int pos = sub.getItem(i).intValue();
488                     prob.x.add(train.x.get(pos));
489                     prob.y.add(train.y.get(pos));
490                 }
491                 modBuf.put(new Augur(prob));
492                 minIdx.put(sub.getIndex());
493             }
494             modBuf.put(modSentinel);
495         } catch (Exception e) {
496             System.err.println("TRAIN Error: "+e);
497         }
498     }
499     private void test() {
500         Augur model = null;
501         int idx = -1;
502         try {
503             while (true) {
504                 model = modBuf.take();
505                 if (model==modSentinel) {    // signal of termination
506                     modBuf.put(model);
507                     break;
508                 }
509                 idx = minIdx.poll().intValue();
510                 for (int i=0;i<test.l;i++) {
511                     if (model.predict(test.x.get(i))<0) {
512                         minResult[idx][i].getAndSet(-1);    // MIN Modules
513                     }
514                 }
515             }
516         } catch (Exception e) {
517             System.err.println("TEST Error:"+e);
518         }
519     }
520 }

 

  The final results are satisfactory in terms of accuracy and F1 score, whereas the running time leaves much to be desired:

 

posted on 2015-05-20 10:49  DevinZ  阅读(196)  评论(0)  编辑  收藏  举报

导航