机器学习:weka中Evaluation类源码解析及输出AUC及交叉验证介绍

  在机器学习分类结果的评估中,ROC曲线下的面积AOC是一个非常重要的指标。下面是调用weka类,输出AOC的源码:

try {
// 1.读入数据集

                Instances data = new Instances(
                                      new BufferedReader(
                                        new FileReader("E:\\Develop/Weka-3-6/data/contact-lenses.arff")));

                data.setClassIndex(data.numAttributes() - 1);

// 2.训练分类器并用十字交叉验证法来获得Evaluation对象
// 注意这里的方法与我们在上几节中使用的验证法是不同。
                Classifier cl = new NaiveBayes();
                Evaluation eval = new Evaluation(data);
                eval.crossValidateModel(cl, data, 10, new Random(1));

         
// 3.生成用于得到ROC曲面和AUC值的Instances对象
       System.out.println(eval.toClassDetailsString());
            System.out.println(eval.toSummaryString());
            System.out.println(eval.toMatrixString()); }
catch (Exception e) { e.printStackTrace(); }

 

  接着说一下交叉验证;

  如果没有分开训练集和测试集,可以使用Cross Validation方法,Evaluation中crossValidateModel方法的四个参数分别为,第一个是分类器,第二个是在某个数据集上评价的数据集,第三个参数是交叉检验的次数(10是比较常见的),第四个是一个随机数对象。

  注意:使用crossValidateModel时,分类器不需要先训练,否则buildClassifier方法会初始化分类器,交叉验证的配置结果就没有用了。

  类crossValidateModel的源码如下:

 public void crossValidateModel(Classifier classifier, Instances data,
    int numFolds, Random random, Object... forPredictionsPrinting)
    throws Exception {

    // Make a copy of the data we can reorder
    data = new Instances(data);
    data.randomize(random);
    if (data.classAttribute().isNominal()) {
      data.stratify(numFolds);
    }

    // We assume that the first element is a StringBuffer, the second a Range
    // (attributes
    // to output) and the third a Boolean (whether or not to output a
    // distribution instead
    // of just a classification)
    if (forPredictionsPrinting.length > 0) {
      // print the header first
      StringBuffer buff = (StringBuffer) forPredictionsPrinting[0];
      Range attsToOutput = (Range) forPredictionsPrinting[1];
      boolean printDist = ((Boolean) forPredictionsPrinting[2]).booleanValue();
      printClassificationsHeader(data, attsToOutput, printDist, buff);
    }

    // Do the folds
    for (int i = 0; i < numFolds; i++) {
      Instances train = data.trainCV(numFolds, i, random);
      setPriors(train);
      Classifier copiedClassifier = Classifier.makeCopy(classifier);
      copiedClassifier.buildClassifier(train);
      Instances test = data.testCV(numFolds, i);
      evaluateModel(copiedClassifier, test, forPredictionsPrinting);
    }
    m_NumFolds = numFolds;
  }

 

输出结果截图:

更新中。。。

 

 

 

libsvm 下载地址 https://github.com/cjlin1/libsvm

    github地址   https://github.com/cjlin1/libsvm

 

posted @ 2016-04-13 10:38  rongyux  阅读(2468)  评论(0编辑  收藏  举报