EI328 Final Project Review (I)
In this project, we use the Java port of the LIBLINEAR library to solve a large-scale machine learning problem: a two-class classification of Japanese patents at the section level (whether or not a patent belongs to section A). The following picture depicts the structure of the programs for tasks 1–5.
The Makefile for the whole project is as follows:
# Build all task classes into ./bin/ against the bundled LIBLINEAR jar.
JAVAC = javac
CP = ./bin/:./bin/liblinear-java-1.95.jar
JFLAGS = -d ./bin/ -classpath $(CP)

.PHONY: all clean

all: ./bin/AbstractTask.class ./bin/Basic.class ./bin/MinMax.class ./bin/DIY.class

./bin/AbstractTask.class: ./src/AbstractTask.java
	$(JAVAC) $(JFLAGS) ./src/AbstractTask.java

./bin/Basic.class: ./src/Basic.java ./src/AbstractTask.java
	$(JAVAC) $(JFLAGS) ./src/Basic.java ./src/AbstractTask.java

./bin/MinMax.class: ./src/MinMax.java ./src/AbstractTask.java
	$(JAVAC) $(JFLAGS) ./src/MinMax.java ./src/AbstractTask.java

# MinMax.java is always compiled together with AbstractTask.java (see above), so
# DIY, which is built with MinMax.java, gets AbstractTask.java as well. This keeps
# the target buildable on its own (or under make -j) and rebuilds it when
# AbstractTask.java changes.
./bin/DIY.class: ./src/DIY.java ./src/MinMax.java ./src/AbstractTask.java
	$(JAVAC) $(JFLAGS) ./src/DIY.java ./src/MinMax.java ./src/AbstractTask.java

# -f: do not fail when there is nothing to clean.
clean:
	rm -f ./bin/*.class
To begin with, we create a base class (./src/AbstractTask.java) that handles data-file preprocessing and model processing:
1 import de.bwaldvogel.liblinear.*; 2 import java.util.concurrent.atomic.*; 3 import java.util.concurrent.*; 4 import java.util.*; 5 import java.io.*; 6 7 public class AbstractTask extends Thread { 8 protected static Parameter param; 9 protected static Problem train, test; 10 protected static FileWriter rout; 11 protected static PrintWriter out; 12 protected static long start, end; 13 14 protected static void init() { 15 try { 16 param = new Parameter(SolverType.L2R_L2LOSS_SVC_DUAL,1.0,0.01); 17 System.out.println(); 18 if (!(new File("./data/new_train.txt")).exists()) { 19 System.out.println("Preprocessing: please wait a minute!"); 20 preproc("./data/","train.txt",""); 21 preproc("./data/","test.txt","5001:1.00"); 22 System.out.println("Ready!"); 23 } 24 System.out.println("Reading the files ...\tWait please ~ ~"); 25 train = Train.readProblem(new File("./data/new_train.txt"),1); 26 test = Train.readProblem(new File("./data/new_test.txt"),1); 27 rout = new FileWriter("./result/ROC.out"); 28 rout.close(); 29 } catch (Exception e) { 30 System.out.println("INIT Error: "+e); 31 } 32 } 33 protected static void print(String str) { 34 try { 35 System.out.print(str); 36 out.print(str); 37 } catch (Exception e) { 38 System.err.println("PRINT Error: "+e); 39 } 40 } 41 protected static void println(String str) { 42 try { 43 System.out.println(str); 44 out.println(str); 45 } catch (Exception e) { 46 System.err.println("PRINTLN Error: "+e); 47 } 48 } 49 protected static void printTime(long time,boolean train) { 50 try { 51 if (train) { 52 println("\tTraining:\t"+time+"ms elapsed"); 53 } else { 54 println("\tTesting:\t"+time+"ms elapsed"); 55 } 56 } catch (Exception e) { 57 System.err.println("PRINTTIME Error: "+e); 58 } 59 } 60 protected static void rocWrtLine(String str) { 61 try { 62 rout = new FileWriter("./result/ROC.out",true); 63 rout.write(str+"\n"); 64 rout.close(); 65 } catch (Exception e) { 66 System.err.println("ROCWRTLINE Error: "+e); 67 } 68 } 69 protected static 
void stats(int[] res) { 70 try { 71 int truePos=0, falsePos=0, falseNeg=0, trueNeg=0; 72 for (int i=0;i<test.l;i++) { 73 if (test.y[i]>0) { 74 if (res[i]>0) { 75 truePos++; 76 } else { 77 falseNeg++; 78 } 79 } else { 80 if (res[i]>0) { 81 falsePos++; 82 } else { 83 trueNeg++; 84 } 85 } 86 } 87 //System.out.println("\t"+truePos+"\t"+falsePos+"\t"+falseNeg+"\t"+trueNeg); 88 double acc = (truePos+trueNeg+.0)/test.l; 89 println("\tacc\t= "+acc); 90 double p = (truePos+.0)/(truePos+falsePos); 91 double r = (truePos+.0)/(truePos+falseNeg); 92 double f1 = 2*r*p/(r+p); 93 println("\tF1\t= "+f1); 94 double tpr = (truePos+.0)/(truePos+falseNeg); 95 double fpr = (falsePos+.0)/(falsePos+trueNeg); 96 println("\tTPR\t= "+tpr); 97 println("\tFPR\t= "+fpr); 98 println(""); 99 rocWrtLine(fpr+" "+tpr); 100 } catch (Exception e) { 101 System.err.println("STATS Error: "+e); 102 } 103 } 104 protected static void preproc(String dir,String filename,String tail) throws IOException { 105 Scanner fin = new Scanner(new FileInputStream(dir+filename)); 106 PrintWriter fout = new PrintWriter(new FileOutputStream(dir+"new_"+filename)); 107 int cnt=0, total=0; 108 while (fin.hasNextLine()) { 109 StringTokenizer tok = new StringTokenizer(fin.nextLine()); 110 if (tok.nextToken().charAt(0)=='A') { 111 fout.print("1 "); 112 cnt++; 113 } else { 114 fout.print("-1 "); 115 } 116 while (tok.hasMoreTokens()) { 117 fout.print(tok.nextToken()+" "); 118 } 119 fout.println(tail); 120 total++; 121 } 122 System.out.println(filename+":\t"+cnt+"/"+total); 123 fout.close(); 124 fin.close(); 125 } 126 protected static Model modModify(Model model,double threshold) { 127 try { 128 model.save(new File("./data/min.txt")); 129 BufferedReader fin = new BufferedReader(new FileReader("./data/min.txt")); 130 PrintWriter fout = new PrintWriter(new FileWriter("./data/mout.txt")); 131 for (int i=0;i<3;i++) { 132 fout.println(fin.readLine()); 133 } 134 StringTokenizer tok = new StringTokenizer(fin.readLine()); 135 
fout.print(tok.nextToken()+" "); 136 int num = Integer.parseInt(tok.nextToken())+1; 137 fout.println(num); 138 for (int i=0;i<num+1;i++) { 139 fout.println(fin.readLine()); 140 } 141 fout.println(-threshold+" "); 142 fout.println(fin.readLine()); 143 fout.close(); 144 fin.close(); 145 model = Model.load(new File("./data/mout.txt")); 146 } catch (Exception e) { 147 System.err.println("Error: "+e); 148 } 149 return model; 150 } 151 }
Next, we give an example that uses LIBLINEAR directly in Java to solve the problem without parallelism (./src/Basic.java):
1 /** 2 * Command Example: 3 * java -classpath ./bin/:./bin/liblinear-java-1.95.jar Basic 1 -2.5 2.5 0.5 4 * Argument Format: 5 * <task # = 1> <min threshold> <max threshold> <threshold step> 6 **/ 7 8 import de.bwaldvogel.liblinear.*; 9 import java.util.*; 10 import java.io.*; 11 12 public class Basic extends AbstractTask { 13 public static void main_thread(double threshold) { 14 try { 15 start = System.currentTimeMillis(); 16 Model model = Linear.train(train, param); 17 end = System.currentTimeMillis(); 18 printTime(end-start,true); 19 model = modModify(model,threshold); 20 start = System.currentTimeMillis(); 21 int[] res = new int[test.l]; 22 Arrays.fill(res,-1); 23 for (int i=0;i<test.l;i++) { 24 if (Linear.predict(model,test.x[i])>0){ 25 res[i] = 1; 26 } 27 } 28 end = System.currentTimeMillis(); 29 printTime(end-start,false); 30 stats(res); 31 } catch (Exception e) { 32 System.err.println("MAIN_THREAD Error: "+e); 33 } 34 } 35 public static void main(String[] args) { 36 try { 37 init(); 38 out = new PrintWriter(new FileWriter("./result/task"+args[0]+"_result.out")); 39 double tmin = Double.parseDouble(args[1]); 40 double tmax = Double.parseDouble(args[2]); 41 double tstep = Double.parseDouble(args[3]); 42 for (double t=tmin;t<tmax+tstep/2;t+=tstep) { 43 println("Threshold = "+t); 44 main_thread(t); 45 } 46 rocWrtLine("_next_"); 47 rocWrtLine("_exit_"); 48 out.close(); 49 } catch (Exception e) { 50 System.err.println("MAIN Error: "+e); 51 } 52 } 53 }
Task 4 requires us to plot the ROC curves of this classifier. We therefore run the classification with different threshold values and wrote a Python script (./src/ROC_plot) that plots the resulting curves:
#!/usr/bin/env python
"""Plot the ROC curves recorded in ./result/ROC.out.

The file holds one "<fpr> <tpr>" pair per line. A line starting with "_next_"
closes the current curve, and a line starting with "_exit_" ends the data.
"""


def read_roc_curves(path):
    """Parse an ROC file into a list of (fprs, tprs) curves.

    Blank lines are skipped. Points left after the last "_next_" marker form a
    final curve, so a file that ends without its markers still loads instead of
    crashing.

    Args:
        path: path of the ROC file to read.

    Returns:
        A list of (xs, ys) tuples of float lists, one per curve.

    Raises:
        ValueError: if a data line does not hold two numbers.
    """
    curves = []
    xs, ys = [], []
    with open(path, 'r') as fin:
        for lineno, line in enumerate(fin, 1):
            if line.startswith('_exit_'):
                break
            if line.startswith('_next_'):
                curves.append((xs, ys))
                xs, ys = [], []
                continue
            fields = line.split()
            if not fields:
                continue
            if len(fields) < 2:
                raise ValueError('%s:%d: expected "<fpr> <tpr>", got %r'
                                 % (path, lineno, line))
            # float() replaces string.atof, which no longer exists in Python 3.
            xs.append(float(fields[0]))
            ys.append(float(fields[1]))
    if xs:
        curves.append((xs, ys))
    return curves


def main(path='./result/ROC.out'):
    """Plot every curve in *path* on one set of axes and show the figure."""
    # Imported here so read_roc_curves works without a plotting backend.
    import matplotlib.pyplot as pl

    for xs, ys in read_roc_curves(path):
        pl.plot(xs, ys)
    pl.title('ROC Curves')
    pl.xlabel('False Positive Rate')
    pl.ylabel('True Positive Rate')
    pl.xlim(-0.1, 1.1)
    pl.ylim(0.0, 1.1)
    pl.show()


if __name__ == '__main__':
    main()
The ROC curves are plotted only when an argument is passed to the shell script, e.g. "$ bash task1 1":
#! /bin/bash
# Task 1 driver: build, run the single-model classifier over a threshold sweep and,
# if any argument is given (e.g. "bash task1 1"), plot the ROC curves.
reset
# Stop if the build fails rather than running stale classes.
make || exit 1
# arguments: <task # = 1> <min threshold> <max threshold> <threshold step>
# Stop if the JVM fails (e.g. class not found) rather than plotting a stale ROC.out.
java -classpath ./bin/:./bin/liblinear-java-1.95.jar Basic 1 -0.2 0.2 0.1 || exit 1
if [ "$#" -gt 0 ]; then
    python ./src/ROC_plot
fi