机器学习:weka中Evaluation类源码解析及输出AUC及交叉验证介绍
在机器学习分类结果的评估中,ROC曲线下的面积AOC是一个非常重要的指标。下面是调用weka类,输出AOC的源码:
try { // 1.读入数据集 Instances data = new Instances( new BufferedReader( new FileReader("E:\\Develop/Weka-3-6/data/contact-lenses.arff"))); data.setClassIndex(data.numAttributes() - 1); // 2.训练分类器并用十字交叉验证法来获得Evaluation对象 // 注意这里的方法与我们在上几节中使用的验证法是不同。 Classifier cl = new NaiveBayes(); Evaluation eval = new Evaluation(data); eval.crossValidateModel(cl, data, 10, new Random(1)); // 3.生成用于得到ROC曲面和AUC值的Instances对象
System.out.println(eval.toClassDetailsString());
System.out.println(eval.toSummaryString());
System.out.println(eval.toMatrixString()); } catch (Exception e) { e.printStackTrace(); }
接着说一下交叉验证;
如果没有分开训练集和测试集,可以使用Cross Validation方法,Evaluation中crossValidateModel方法的四个参数分别为,第一个是分类器,第二个是在某个数据集上评价的数据集,第三个参数是交叉检验的次数(10是比较常见的),第四个是一个随机数对象。
注意:使用crossValidateModel时,分类器不需要先训练,否则buildClassifier方法会初始化分类器,交叉验证的配置结果就没有用了。
类crossValidateModel的源码如下:
public void crossValidateModel(Classifier classifier, Instances data, int numFolds, Random random, Object... forPredictionsPrinting) throws Exception { // Make a copy of the data we can reorder data = new Instances(data); data.randomize(random); if (data.classAttribute().isNominal()) { data.stratify(numFolds); } // We assume that the first element is a StringBuffer, the second a Range // (attributes // to output) and the third a Boolean (whether or not to output a // distribution instead // of just a classification) if (forPredictionsPrinting.length > 0) { // print the header first StringBuffer buff = (StringBuffer) forPredictionsPrinting[0]; Range attsToOutput = (Range) forPredictionsPrinting[1]; boolean printDist = ((Boolean) forPredictionsPrinting[2]).booleanValue(); printClassificationsHeader(data, attsToOutput, printDist, buff); } // Do the folds for (int i = 0; i < numFolds; i++) { Instances train = data.trainCV(numFolds, i, random); setPriors(train); Classifier copiedClassifier = Classifier.makeCopy(classifier); copiedClassifier.buildClassifier(train); Instances test = data.testCV(numFolds, i); evaluateModel(copiedClassifier, test, forPredictionsPrinting); } m_NumFolds = numFolds; }
输出结果截图:
更新中。。。
libsvm 下载地址 https://github.com/cjlin1/libsvm
github地址 https://github.com/cjlin1/libsvm