EI328 Final Project Review (I)

 

  In this project, we use the Java LIBLINEAR library to solve a large-scale machine learning problem: a two-class classification of Japanese patents at the section level (type A or not). The following picture depicts the structure of the programs for tasks 1-5.

 

  The Makefile of the whole project is as follows:

 1 all: ./bin/AbstractTask.class ./bin/Basic.class ./bin/MinMax.class ./bin/DIY.class
 2 
 3 ./bin/AbstractTask.class: ./src/AbstractTask.java
 4     javac -d ./bin/ -classpath ./bin/:./bin/liblinear-java-1.95.jar ./src/AbstractTask.java
 5 
 6 ./bin/Basic.class: ./src/Basic.java ./src/AbstractTask.java
 7     javac -d ./bin/ -classpath ./bin/:./bin/liblinear-java-1.95.jar ./src/Basic.java ./src/AbstractTask.java
 8 
 9 ./bin/MinMax.class: ./src/MinMax.java ./src/AbstractTask.java
10     javac -d ./bin/ -classpath ./bin/:./bin/liblinear-java-1.95.jar ./src/MinMax.java ./src/AbstractTask.java
11 
12 ./bin/DIY.class: ./src/DIY.java ./src/MinMax.java
13     javac -d ./bin/ -classpath ./bin/:./bin/liblinear-java-1.95.jar ./src/DIY.java ./src/MinMax.java
14 
15 clean:
16     rm ./bin/*.class

 

  To begin with, we create a base class (./src/AbstractTask.java) that handles file preprocessing and model processing:

  1 import de.bwaldvogel.liblinear.*;
  2 import java.util.concurrent.atomic.*;
  3 import java.util.concurrent.*;
  4 import java.util.*;
  5 import java.io.*;
  6 
  7 public class AbstractTask extends Thread {
  8     protected static Parameter param;
  9     protected static Problem train, test;
 10     protected static FileWriter rout;
 11     protected static PrintWriter out;
 12     protected static long start, end;
 13     
 14     protected static void init() {
 15         try {
 16             param = new Parameter(SolverType.L2R_L2LOSS_SVC_DUAL,1.0,0.01);
 17             System.out.println();
 18             if (!(new File("./data/new_train.txt")).exists()) {
 19                 System.out.println("Preprocessing: please wait a minute!");
 20                 preproc("./data/","train.txt","");
 21                 preproc("./data/","test.txt","5001:1.00");
 22                 System.out.println("Ready!");
 23             }
 24             System.out.println("Reading the files ...\tWait please ~ ~");
 25             train = Train.readProblem(new File("./data/new_train.txt"),1);
 26             test = Train.readProblem(new File("./data/new_test.txt"),1);
 27             rout = new FileWriter("./result/ROC.out");
 28             rout.close();
 29         } catch (Exception e) {
 30             System.out.println("INIT Error: "+e);
 31         }
 32     }
 33     protected static void print(String str) {
 34         try {
 35             System.out.print(str);
 36             out.print(str);
 37         } catch (Exception e) {
 38             System.err.println("PRINT Error: "+e);
 39         }
 40     }
 41     protected static void println(String str) {
 42         try {
 43             System.out.println(str);
 44             out.println(str);
 45         } catch (Exception e) {
 46             System.err.println("PRINTLN Error: "+e);
 47         }
 48     }
 49     protected static void printTime(long time,boolean train) {
 50         try {
 51             if (train) {
 52                 println("\tTraining:\t"+time+"ms elapsed");
 53             } else {
 54                 println("\tTesting:\t"+time+"ms elapsed");
 55             }
 56         } catch (Exception e) {
 57             System.err.println("PRINTTIME Error: "+e);
 58         }
 59     }
 60     protected static void rocWrtLine(String str) {
 61         try {
 62             rout = new FileWriter("./result/ROC.out",true);
 63             rout.write(str+"\n");
 64             rout.close();
 65         } catch (Exception e) {
 66             System.err.println("ROCWRTLINE Error: "+e);
 67         }
 68     }
 69     protected static void stats(int[] res) {
 70         try {
 71             int truePos=0, falsePos=0, falseNeg=0, trueNeg=0;
 72             for (int i=0;i<test.l;i++) {
 73                 if (test.y[i]>0) {
 74                     if (res[i]>0) {
 75                         truePos++;
 76                     } else {
 77                         falseNeg++;
 78                     }
 79                 } else {
 80                     if (res[i]>0) {
 81                         falsePos++;
 82                     } else {
 83                         trueNeg++;
 84                     }
 85                 }
 86             }
 87             //System.out.println("\t"+truePos+"\t"+falsePos+"\t"+falseNeg+"\t"+trueNeg);
 88             double acc = (truePos+trueNeg+.0)/test.l;
 89             println("\tacc\t= "+acc);
 90             double p = (truePos+.0)/(truePos+falsePos);
 91             double r = (truePos+.0)/(truePos+falseNeg);
 92             double f1 = 2*r*p/(r+p);
 93             println("\tF1\t= "+f1);
 94             double tpr = (truePos+.0)/(truePos+falseNeg);
 95             double fpr = (falsePos+.0)/(falsePos+trueNeg);
 96             println("\tTPR\t= "+tpr);
 97             println("\tFPR\t= "+fpr);
 98             println("");
 99             rocWrtLine(fpr+" "+tpr);
100         } catch (Exception e) {
101             System.err.println("STATS Error: "+e);
102         }
103     }
104     protected static void preproc(String dir,String filename,String tail) throws IOException {
105         Scanner fin = new Scanner(new FileInputStream(dir+filename));
106         PrintWriter fout = new PrintWriter(new FileOutputStream(dir+"new_"+filename));
107         int cnt=0, total=0;
108         while (fin.hasNextLine()) {
109             StringTokenizer tok = new StringTokenizer(fin.nextLine());
110             if (tok.nextToken().charAt(0)=='A') {
111                 fout.print("1 ");
112                 cnt++;
113             } else {
114                 fout.print("-1 ");
115             }
116             while (tok.hasMoreTokens()) {
117                 fout.print(tok.nextToken()+" ");
118             }
119             fout.println(tail);
120             total++;
121         }
122         System.out.println(filename+":\t"+cnt+"/"+total);
123         fout.close();
124         fin.close();
125     }
126     protected static Model modModify(Model model,double threshold) {
127         try {
128             model.save(new File("./data/min.txt"));
129             BufferedReader fin = new BufferedReader(new FileReader("./data/min.txt"));
130             PrintWriter fout = new PrintWriter(new FileWriter("./data/mout.txt"));
131             for (int i=0;i<3;i++) {
132                 fout.println(fin.readLine());
133             }
134             StringTokenizer tok = new StringTokenizer(fin.readLine());
135             fout.print(tok.nextToken()+" ");
136             int num = Integer.parseInt(tok.nextToken())+1;
137             fout.println(num);
138             for (int i=0;i<num+1;i++) {
139                 fout.println(fin.readLine());
140             }
141             fout.println(-threshold+" ");
142             fout.println(fin.readLine());
143             fout.close();
144             fin.close();
145             model = Model.load(new File("./data/mout.txt"));
146         } catch (Exception e) {
147             System.err.println("Error: "+e);
148         }
149         return model;
150     }
151 }

 

  Now we give an example of directly using Liblinear in Java to solve the problem without parallelism (./src/Basic.java):

 1 /**
 2 * Command Example:
 3 * java -classpath ./bin/:./bin/liblinear-java-1.95.jar Basic 1 -2.5 2.5 0.5
 4 * Argument Format:
 5 * <task # = 1> <min threshold> <max threshold> <threshold step>
 6 **/
 7 
 8 import de.bwaldvogel.liblinear.*;
 9 import java.util.*;
10 import java.io.*;
11 
12 public class Basic extends AbstractTask {
13     public static void main_thread(double threshold) {
14         try {
15             start = System.currentTimeMillis();
16             Model model = Linear.train(train, param);
17             end = System.currentTimeMillis();
18             printTime(end-start,true);
19             model = modModify(model,threshold);
20             start = System.currentTimeMillis();
21             int[] res = new int[test.l];
22             Arrays.fill(res,-1);
23             for (int i=0;i<test.l;i++) {
24                 if (Linear.predict(model,test.x[i])>0){
25                     res[i] = 1;
26                 }
27             }
28             end = System.currentTimeMillis();
29             printTime(end-start,false);
30             stats(res);
31         } catch (Exception e) {
32             System.err.println("MAIN_THREAD Error: "+e);
33         }
34     }
35     public static void main(String[] args) {
36         try {
37             init();
38             out = new PrintWriter(new FileWriter("./result/task"+args[0]+"_result.out"));
39             double tmin = Double.parseDouble(args[1]);
40             double tmax = Double.parseDouble(args[2]);
41             double tstep = Double.parseDouble(args[3]);
42             for (double t=tmin;t<tmax+tstep/2;t+=tstep) {
43                 println("Threshold = "+t);
44                 main_thread(t);
45             }
46             rocWrtLine("_next_");
47             rocWrtLine("_exit_");
48             out.close();
49         } catch (Exception e) {
50             System.err.println("MAIN Error: "+e);
51         }
52     }
53 }

 

  In task 4, we are required to plot ROC curves of this classifier, so we wrote a Python script (./src/ROC_plot) and ran the classification at different threshold values.

 1 #! /usr/bin/python
 2 import pylab as pl
 3 import string
 4 
 5 if __name__=='__main__':
 6     fin = open("./result/ROC.out",'r')
 7     line = fin.readline()
 8     while line[0:6]!='_exit_':
 9         x = []
10         y = []
11         while line[0:6]!="_next_":
12             lst = line.split()
13             x.append(string.atof(lst[0]))
14             y.append(string.atof(lst[1]))
15             line = fin.readline()
16         pl.plot(x,y)
17         line = fin.readline()
18     fin.close();
19     pl.title('ROC Curves')
20     pl.xlabel('False Positive Rate')
21     pl.ylabel('True Positive Rate')
22     pl.xlim(-0.1,1.1)
23     pl.ylim(0.0,1.1)
24     pl.show()
25     

  The ROC curves will be plotted only when an argument is passed to the shell script, such as "$ bash task1 1" and the like:

#! /bin/bash
# Builds the project, runs task 1 over a range of thresholds and, when at
# least one command-line argument is given (e.g. "bash task1 1"), plots the
# resulting ROC curves.
reset

# Stop here if the build fails rather than running stale or missing classes.
make || exit 1

# arguments: <task # = 1> <min threshold> <max threshold> <threshold step>
java -classpath ./bin/:./bin/liblinear-java-1.95.jar Basic 1 -0.2 0.2 0.1 || exit 1

# Any argument at all enables plotting; its value is not inspected.
if [ "$#" -gt 0 ]; then
    python ./src/ROC_plot
fi

 

 

posted on 2015-05-20 10:29  DevinZ  阅读(220)  评论(0编辑  收藏  举报

导航