代码改变世界

OpenCV码源笔记——Decision Tree决策树

2016-04-17 19:41  GarfieldEr007  阅读(737)  评论(0编辑  收藏  举报

来自OpenCV2.3.1 sample/c/mushroom.cpp

 

1.首先读入agaricus-lepiota.data的训练样本。

   样本中第一项是e或p代表有毒或无毒的标志位;其他是特征,可以把每个样本看做一个特征向量;

   cvSeqPush( seq, el_ptr );读入序列seq中,每一项都存储一个样本即特征向量;

   之后,把特征向量与标志位分别读入CvMat* data与CvMat* reponses中

   还有一个CvMat* missing保留丢失位当前小于0位置;

 

2.训练样本

[cpp] view plain copy
 
 print?
  1. dtree = new CvDTree;  
  2. dtree->train( data, CV_ROW_SAMPLE, responses, 0, 0, var_type, missing,  
  3. CvDTreeParams( 8, // max depth  
  4. 10, // min sample count 样本数小于10时,停止分裂   
  5. 0, // regression accuracy: N/A here;回归树的限制精度  
  6. true, // compute surrogate split, as we have missing data;;为真时,计算missing data和变量的重要性  
  7. 15, // max number of categories (use sub-optimal algorithm for larger numbers)类型上限以保证计算速度。树会以次优分裂(suboptimal split)的形式生长。只对2种取值以上的树有意义  
  8. 10, // the number of cross-validation folds;If cv_folds > 1 then prune a tree with K-fold cross-validation where K is equal to cv_folds  
  9. true, // use 1SE rule => smaller tree;If true 修剪树. 这将使树更紧凑,更能抵抗训练数据噪声,但有点不太准确  
  10. true, // throw away the pruned tree branches  
  11. priors //错分类的代价我们判断的:有毒VS无毒 错误的代价比 the array of priors, the bigger p_weight, the more attention  
  12. // to the poisonous mushrooms  
  13. // (a mushroom will be judjed to be poisonous with bigger chance)  
  14. ));  


 

3.

[cpp] view plain copy
 
 print?
  1. double r = dtree->predict( &sample, &mask )->value;//使用predict来预测样本,结果为 CvDTreeNode结构,dtree->predict(sample,mask)->value是分类情况下的类别或回归情况下的函数估计值;  


4.interactive_classification通过人工输入特征来判断。

 

[cpp] view plain copy
 
 print?
    1. #include "opencv2/core/core_c.h"  
    2. #include "opencv2/ml/ml.hpp"  
    3. #include <stdio.h>  
    4.   
    5. void help()  
    6. {  
    7.     printf("\nThis program demonstrated the use of OpenCV's decision tree function for learning and predicting data\n"  
    8.         "Usage :\n"  
    9.         "./mushroom <path to agaricus-lepiota.data>\n"  
    10.         "\n"  
    11.         "The sample demonstrates how to build a decision tree for classifying mushrooms.\n"  
    12.         "It uses the sample base agaricus-lepiota.data from UCI Repository, here is the link:\n"  
    13.         "\n"  
    14.         "Newman, D.J. & Hettich, S. & Blake, C.L. & Merz, C.J. (1998).\n"  
    15.         "UCI Repository of machine learning databases\n"  
    16.         "[http://www.ics.uci.edu/~mlearn/MLRepository.html].\n"  
    17.         "Irvine, CA: University of California, Department of Information and Computer Science.\n"  
    18.         "\n"  
    19.         "// loads the mushroom database, which is a text file, containing\n"  
    20.         "// one training sample per row, all the input variables and the output variable are categorical,\n"  
    21.         "// the values are encoded by characters.\n\n");  
    22. }  
    23.   
    24. int mushroom_read_database( const char* filename, CvMat** data, CvMat** missing, CvMat** responses )  
    25. {  
    26.     const int M = 1024;  
    27.     FILE* f = fopen( filename, "rt" );  
    28.     CvMemStorage* storage;  
    29.     CvSeq* seq;  
    30.     char buf[M+2], *ptr;  
    31.     float* el_ptr;  
    32.     CvSeqReader reader;  
    33.     int i, j, var_count = 0;  
    34.   
    35.     if( !f )  
    36.         return 0;  
    37.   
    38.     // read the first line and determine the number of variables  
    39.     if( !fgets( buf, M, f ))  
    40.     {  
    41.         fclose(f);  
    42.         return 0;  
    43.     }  
    44.   
    45.     for( ptr = buf; *ptr != '\0'; ptr++ )  
    46.         var_count += *ptr == ',';//计算每个样本的数量,每个样本一个“,”,样本数量=var_count+1;  
    47.     assert( ptr - buf == (var_count+1)*2 );  
    48.   
    49.     // create temporary memory storage to store the whole database  
    50.     //把样本存入seq中,存储空间是storage;  
    51.     el_ptr = new float[var_count+1];  
    52.     storage = cvCreateMemStorage();  
    53.     seq = cvCreateSeq( 0, sizeof(*seq), (var_count+1)*sizeof(float), storage );//  
    54.   
    55.     for(;;)  
    56.     {  
    57.         for( i = 0; i <= var_count; i++ )  
    58.         {  
    59.             int c = buf[i*2];  
    60.             el_ptr[i] = c == '?' ? -1.f : (float)c;  
    61.         }  
    62.         if( i != var_count+1 )  
    63.             break;  
    64.         cvSeqPush( seq, el_ptr );  
    65.         if( !fgets( buf, M, f ) || !strchr( buf, ',' ) )  
    66.             break;  
    67.     }  
    68.     fclose(f);  
    69.   
    70.     // allocate the output matrices and copy the base there  
    71.     *data = cvCreateMat( seq->total, var_count, CV_32F );//行数:样本数量;列数:样本大小;  
    72.     *missing = cvCreateMat( seq->total, var_count, CV_8U );  
    73.     *responses = cvCreateMat( seq->total, 1, CV_32F );//样本标志;  
    74.   
    75.     cvStartReadSeq( seq, &reader );  
    76.   
    77.     for( i = 0; i < seq->total; i++ )  
    78.     {  
    79.         const float* sdata = (float*)reader.ptr + 1;  
    80.         float* ddata = data[0]->data.fl + var_count*i;  
    81.         float* dr = responses[0]->data.fl + i;  
    82.         uchar* dm = missing[0]->data.ptr + var_count*i;  
    83.   
    84.         for( j = 0; j < var_count; j++ )  
    85.         {  
    86.             ddata[j] = sdata[j];  
    87.             dm[j] = sdata[j] < 0;  
    88.         }  
    89.         *dr = sdata[-1];//样本的第一个位置是标志;  
    90.         CV_NEXT_SEQ_ELEM( seq->elem_size, reader );  
    91.     }  
    92.   
    93.     cvReleaseMemStorage( &storage );  
    94.     delete el_ptr;  
    95.     return 1;  
    96. }  
    97.   
    98.   
    99. CvDTree* mushroom_create_dtree( const CvMat* data, const CvMat* missing,  
    100.     const CvMat* responses, float p_weight )  
    101. {  
    102.     CvDTree* dtree;  
    103.     CvMat* var_type;  
    104.     int i, hr1 = 0, hr2 = 0, p_total = 0;  
    105.     float priors[] = { 1, p_weight };  
    106.   
    107.     var_type = cvCreateMat( data->cols + 1, 1, CV_8U );  
    108.     cvSet( var_type, cvScalarAll(CV_VAR_CATEGORICAL) ); // all the variables are categorical  
    109.   
    110.     dtree = new CvDTree;  
    111.   
    112.     dtree->train( data, CV_ROW_SAMPLE, responses, 0, 0, var_type, missing,  
    113.         CvDTreeParams( 8, // max depth  
    114.         10, // min sample count样本数小于10时,停止分裂  
    115.         0, // regression accuracy: N/A here;回归树的限制精度  
    116.         true, // compute surrogate split, as we have missing data;为真时,计算missing data和可变的重要性正确度  
    117.         15, // max number of categories (use sub-optimal algorithm for larger numbers)类型上限以保证计算速度。树会以次优分裂(suboptimal split)的形式生长。只对2种取值以上的树有意义  
    118.         10, // the number of cross-validation folds;If cv_folds > 1 then prune a tree with K-fold cross-validation   
    119.         true, // use 1SE rule => smaller treeIf true 修剪树. 这将使树更紧凑,更能抵抗训练数据噪声,但有点不太准确  
    120.         true, // throw away the pruned tree branches  
    121.         priors // the array of priors, the bigger p_weight, the more attention  
    122.         // to the poisonous mushrooms  
    123.         // (a mushroom will be judjed to be poisonous with bigger chance)  
    124.         ));  
    125.   
    126.     // compute hit-rate on the training database, demonstrates predict usage.  
    127.       
    128.     for( i = 0; i < data->rows; i++ )  
    129.     {  
    130.         CvMat sample, mask;  
    131.         cvGetRow( data, &sample, i );  
    132.         cvGetRow( missing, &mask, i );  
    133.         double r = dtree->predict( &sample, &mask )->value;//使用predict来预测样本,结果为 CvDTreeNode结构,dtree->predict(sample,mask)->value是分类情况下的类别或回归情况下的函数估计值;  
    134.         int d = fabs(r - responses->data.fl[i]) >= FLT_EPSILON;//大于阈值FLT_EPSILON被判断为误检  
    135.         if( d )  
    136.         {  
    137.             if( r != 'p' )  
    138.                 hr1++;  
    139.             else  
    140.                 hr2++;  
    141.         }  
    142.         p_total += responses->data.fl[i] == 'p';  
    143.     }  
    144.   
    145.     printf( "Results on the training database:\n"  
    146.         "\tPoisonous mushrooms mis-predicted: %d (%g%%)\n"  
    147.         "\tFalse-alarms: %d (%g%%)\n", hr1, (double)hr1*100/p_total,  
    148.         hr2, (double)hr2*100/(data->rows - p_total) );  
    149.   
    150.     cvReleaseMat( &var_type );  
    151.   
    152.     return dtree;  
    153. }  
    154.   
    155.   
    156. static const char* var_desc[] =  
    157. {  
    158.     "cap shape (bell=b,conical=c,convex=x,flat=f)",  
    159.     "cap surface (fibrous=f,grooves=g,scaly=y,smooth=s)",  
    160.     "cap color (brown=n,buff=b,cinnamon=c,gray=g,green=r,\n\tpink=p,purple=u,red=e,white=w,yellow=y)",  
    161.     "bruises? (bruises=t,no=f)",  
    162.     "odor (almond=a,anise=l,creosote=c,fishy=y,foul=f,\n\tmusty=m,none=n,pungent=p,spicy=s)",  
    163.     "gill attachment (attached=a,descending=d,free=f,notched=n)",  
    164.     "gill spacing (close=c,crowded=w,distant=d)",  
    165.     "gill size (broad=b,narrow=n)",  
    166.     "gill color (black=k,brown=n,buff=b,chocolate=h,gray=g,\n\tgreen=r,orange=o,pink=p,purple=u,red=e,white=w,yellow=y)",  
    167.     "stalk shape (enlarging=e,tapering=t)",  
    168.     "stalk root (bulbous=b,club=c,cup=u,equal=e,rhizomorphs=z,rooted=r)",  
    169.     "stalk surface above ring (ibrous=f,scaly=y,silky=k,smooth=s)",  
    170.     "stalk surface below ring (ibrous=f,scaly=y,silky=k,smooth=s)",  
    171.     "stalk color above ring (brown=n,buff=b,cinnamon=c,gray=g,orange=o,\n\tpink=p,red=e,white=w,yellow=y)",  
    172.     "stalk color below ring (brown=n,buff=b,cinnamon=c,gray=g,orange=o,\n\tpink=p,red=e,white=w,yellow=y)",  
    173.     "veil type (partial=p,universal=u)",  
    174.     "veil color (brown=n,orange=o,white=w,yellow=y)",  
    175.     "ring number (none=n,one=o,two=t)",  
    176.     "ring type (cobwebby=c,evanescent=e,flaring=f,large=l,\n\tnone=n,pendant=p,sheathing=s,zone=z)",  
    177.     "spore print color (black=k,brown=n,buff=b,chocolate=h,green=r,\n\torange=o,purple=u,white=w,yellow=y)",  
    178.     "population (abundant=a,clustered=c,numerous=n,\n\tscattered=s,several=v,solitary=y)",  
    179.     "habitat (grasses=g,leaves=l,meadows=m,paths=p\n\turban=u,waste=w,woods=d)",  
    180.     0  
    181. };  
    182.   
    183.   
    184. void print_variable_importance( CvDTree* dtree, const char** var_desc )  
    185. {  
    186.     const CvMat* var_importance = dtree->get_var_importance();  
    187.     int i;  
    188.     char input[1000];  
    189.   
    190.     if( !var_importance )  
    191.     {  
    192.         printf( "Error: Variable importance can not be retrieved\n" );  
    193.         return;  
    194.     }  
    195.   
    196.     printf( "Print variable importance information? (y/n) " );  
    197.     scanf( "%1s", input );  
    198.     if( input[0] != 'y' && input[0] != 'Y' )  
    199.         return;  
    200.   
    201.     for( i = 0; i < var_importance->cols*var_importance->rows; i++ )  
    202.     {  
    203.         double val = var_importance->data.db[i];  
    204.         if( var_desc )  
    205.         {  
    206.             char buf[100];  
    207.             int len = strchr( var_desc[i], '(' ) - var_desc[i] - 1;  
    208.             strncpy( buf, var_desc[i], len );  
    209.             buf[len] = '\0';  
    210.             printf( "%s", buf );  
    211.         }  
    212.         else  
    213.             printf( "var #%d", i );  
    214.         printf( ": %g%%\n", val*100. );  
    215.     }  
    216. }  
    217.   
    218. void interactive_classification( CvDTree* dtree, const char** var_desc )  
    219. {  
    220.     char input[1000];  
    221.     const CvDTreeNode* root;  
    222.     CvDTreeTrainData* data;  
    223.   
    224.     if( !dtree )  
    225.         return;  
    226.   
    227.     root = dtree->get_root();  
    228.     data = dtree->get_data();  
    229.   
    230.     for(;;)  
    231.     {  
    232.         const CvDTreeNode* node;  
    233.   
    234.         printf( "Start/Proceed with interactive mushroom classification (y/n): " );  
    235.         scanf( "%1s", input );  
    236.         if( input[0] != 'y' && input[0] != 'Y' )  
    237.             break;  
    238.         printf( "Enter 1-letter answers, '?' for missing/unknown value...\n" );   
    239.   
    240.         // custom version of predict  
    241.         //传统的预测方式;  
    242.         node = root;  
    243.         for(;;)  
    244.         {  
    245.             CvDTreeSplit* split = node->split;  
    246.             int dir = 0;  
    247.   
    248.             if( !node->left || node->Tn <= dtree->get_pruned_tree_idx() || !node->split )  
    249.                 break;  
    250.   
    251.             for( ; split != 0; )  
    252.             {  
    253.                 int vi = split->var_idx, j;  
    254.                 int count = data->cat_count->data.i[vi];  
    255.                 const int* map = data->cat_map->data.i + data->cat_ofs->data.i[vi];  
    256.   
    257.                 printf( "%s: ", var_desc[vi] );  
    258.                 scanf( "%1s", input );  
    259.   
    260.                 if( input[0] == '?' )  
    261.                 {  
    262.                     split = split->next;  
    263.                     continue;  
    264.                 }  
    265.   
    266.                 // convert the input character to the normalized value of the variable  
    267.                 for( j = 0; j < count; j++ )  
    268.                     if( map[j] == input[0] )  
    269.                         break;  
    270.                 if( j < count )  
    271.                 {  
    272.                     dir = (split->subset[j>>5] & (1 << (j&31))) ? -1 : 1;  
    273.                     if( split->inversed )  
    274.                         dir = -dir;  
    275.                     break;  
    276.                 }  
    277.                 else  
    278.                     printf( "Error: unrecognized value\n" );  
    279.             }  
    280.   
    281.             if( !dir )  
    282.             {  
    283.                 printf( "Impossible to classify the sample\n");  
    284.                 node = 0;  
    285.                 break;  
    286.             }  
    287.             node = dir < 0 ? node->left : node->right;  
    288.         }  
    289.   
    290.         if( node )  
    291.             printf( "Prediction result: the mushroom is %s\n",  
    292.             node->class_idx == 0 ? "EDIBLE" : "POISONOUS" );  
    293.         printf( "\n-----------------------------\n" );  
    294.     }  
    295. }  
    296.   
    297.   
    298. int main( int argc, char** argv )  
    299. {  
    300.     CvMat *data = 0, *missing = 0, *responses = 0;  
    301.     CvDTree* dtree;  
    302.     const char* base_path = argc >= 2 ? argv[1] : "agaricus-lepiota.data";  
    303.   
    304.     help();  
    305.   
    306.     if( !mushroom_read_database( base_path, &data, &missing, &responses ) )  
    307.     {  
    308.         printf( "\nUnable to load the training database\n\n");  
    309.         help();  
    310.         return -1;  
    311.     }  
    312.   
    313.     dtree = mushroom_create_dtree( data, missing, responses,  
    314.         10 // poisonous mushrooms will have 10x higher weight in the decision tree  
    315.         );  
    316.     cvReleaseMat( &data );  
    317.     cvReleaseMat( &missing );  
    318.     cvReleaseMat( &responses );  
    319.   
    320.     print_variable_importance( dtree, var_desc );  
    321.     interactive_classification( dtree, var_desc );  
    322.     delete dtree;  
    323.   
    324.     return 0;  
    325. }  
    326. //from: http://blog.csdn.net/yangtrees/article/details/7490852