数据挖掘-决策树ID3分类算法的C++实现

转自:http://blog.csdn.net/yangliuy/article/details/7322015

 

作者: yangliuy

决策树算法是非常常用的分类算法,是逼近离散目标函数的方法,学习得到的函数以决策树的形式表示。其基本思路是不断选取产生信息增益最大的属性来划分样例 集和,构造决策树。信息增益定义为结点与其子结点的信息熵之差。信息熵是香农提出的,用于描述信息不纯度(不稳定性),其计算公式是


Pi为子集合中不同性(而二元分类即正样例和负样例)的样例的比例。这样信息收益可以定义为样本按照某属性划分时造成熵减少的期望,可以区分训练样本中正负样本的能力,其计算公司是


我实现该算法针对的样例集合如下


该表记录了在不同气候条件下是否去打球的情况,要求根据该表用程序输出决策树

C++代码如下,程序中有详细注释

 

[cpp] view plaincopy
 
  1. #include <iostream>  
  2. #include <string>  
  3. #include <vector>  
  4. #include <map>  
  5. #include <algorithm>  
  6. #include <cmath>  
  7. using namespace std;  
  8. #define MAXLEN 6//输入每行的数据个数  
  9.   
  10. //多叉树的实现   
  11. //1 广义表  
  12. //2 父指针表示法,适于经常找父结点的应用  
  13. //3 子女链表示法,适于经常找子结点的应用  
  14. //4 左长子,右兄弟表示法,实现比较麻烦  
  15. //5 每个结点的所有孩子用vector保存  
  16. //教训:数据结构的设计很重要,本算法采用5比较合适,同时  
  17. //注意维护剩余样例和剩余属性信息,建树时横向遍历考循环属性的值,  
  18. //纵向遍历靠递归调用  
  19.   
  20. vector <vector <string> > state;//实例集  
  21. vector <string> item(MAXLEN);//对应一行实例集  
  22. vector <string> attribute_row;//保存首行即属性行数据  
  23. string end("end");//输入结束  
  24. string yes("yes");  
  25. string no("no");  
  26. string blank("");  
  27. map<string,vector < string > > map_attribute_values;//存储属性对应的所有的值  
  28. int tree_size = 0;  
  29. struct Node{//决策树节点  
  30.     string attribute;//属性值  
  31.     string arrived_value;//到达的属性值  
  32.     vector<Node *> childs;//所有的孩子  
  33.     Node(){  
  34.         attribute = blank;  
  35.         arrived_value = blank;  
  36.     }  
  37. };  
  38. Node * root;  
  39.   
  40. //根据数据实例计算属性与值组成的map  
  41. void ComputeMapFrom2DVector(){  
  42.     unsigned int i,j,k;  
  43.     bool exited = false;  
  44.     vector<string> values;  
  45.     for(i = 1; i < MAXLEN-1; i++){//按照列遍历  
  46.         for (j = 1; j < state.size(); j++){  
  47.             for (k = 0; k < values.size(); k++){  
  48.                 if(!values[k].compare(state[j][i])) exited = true;  
  49.             }  
  50.             if(!exited){  
  51.                 values.push_back(state[j][i]);//注意Vector的插入都是从前面插入的,注意更新it,始终指向vector头  
  52.             }  
  53.             exited = false;  
  54.         }  
  55.         map_attribute_values[state[0][i]] = values;  
  56.         values.erase(values.begin(), values.end());  
  57.     }     
  58. }  
  59.   
  60. //根据具体属性和值来计算熵  
  61. double ComputeEntropy(vector <vector <string> > remain_state, string attribute, string value,bool ifparent){  
  62.     vector<int> count (2,0);  
  63.     unsigned int i,j;  
  64.     bool done_flag = false;//哨兵值  
  65.     for(j = 1; j < MAXLEN; j++){  
  66.         if(done_flag) break;  
  67.         if(!attribute_row[j].compare(attribute)){  
  68.             for(i = 1; i < remain_state.size(); i++){  
  69.                 if((!ifparent&&!remain_state[i][j].compare(value)) || ifparent){//ifparent记录是否算父节点  
  70.                     if(!remain_state[i][MAXLEN - 1].compare(yes)){  
  71.                         count[0]++;  
  72.                     }  
  73.                     else count[1]++;  
  74.                 }  
  75.             }  
  76.             done_flag = true;  
  77.         }  
  78.     }  
  79.     if(count[0] == 0 || count[1] == 0 ) return 0;//全部是正实例或者负实例  
  80.     //具体计算熵 根据[+count[0],-count[1]],log2为底通过换底公式换成自然数底数  
  81.     double sum = count[0] + count[1];  
  82.     double entropy = -count[0]/sum*log(count[0]/sum)/log(2.0) - count[1]/sum*log(count[1]/sum)/log(2.0);  
  83.     return entropy;  
  84. }  
  85.       
  86. //计算按照属性attribute划分当前剩余实例的信息增益  
  87. double ComputeGain(vector <vector <string> > remain_state, string attribute){  
  88.     unsigned int j,k,m;  
  89.     //首先求不做划分时的熵  
  90.     double parent_entropy = ComputeEntropy(remain_state, attribute, blank, true);  
  91.     double children_entropy = 0;  
  92.     //然后求做划分后各个值的熵  
  93.     vector<string> values = map_attribute_values[attribute];  
  94.     vector<double> ratio;  
  95.     vector<int> count_values;  
  96.     int tempint;  
  97.     for(m = 0; m < values.size(); m++){  
  98.         tempint = 0;  
  99.         for(k = 1; k < MAXLEN - 1; k++){  
  100.             if(!attribute_row[k].compare(attribute)){  
  101.                 for(j = 1; j < remain_state.size(); j++){  
  102.                     if(!remain_state[j][k].compare(values[m])){  
  103.                         tempint++;  
  104.                     }  
  105.                 }  
  106.             }  
  107.         }  
  108.         count_values.push_back(tempint);  
  109.     }  
  110.       
  111.     for(j = 0; j < values.size(); j++){  
  112.         ratio.push_back((double)count_values[j] / (double)(remain_state.size()-1));  
  113.     }  
  114.     double temp_entropy;  
  115.     for(j = 0; j < values.size(); j++){  
  116.         temp_entropy = ComputeEntropy(remain_state, attribute, values[j], false);  
  117.         children_entropy += ratio[j] * temp_entropy;  
  118.     }  
  119.     return (parent_entropy - children_entropy);   
  120. }  
  121.   
  122. int FindAttriNumByName(string attri){  
  123.     for(int i = 0; i < MAXLEN; i++){  
  124.         if(!state[0][i].compare(attri)) return i;  
  125.     }  
  126.     cerr<<"can't find the numth of attribute"<<endl;   
  127.     return 0;  
  128. }  
  129.   
  130. //找出样例中占多数的正/负性  
  131. string MostCommonLabel(vector <vector <string> > remain_state){  
  132.     int p = 0, n = 0;  
  133.     for(unsigned i = 0; i < remain_state.size(); i++){  
  134.         if(!remain_state[i][MAXLEN-1].compare(yes)) p++;  
  135.         else n++;  
  136.     }  
  137.     if(p >= n) return yes;  
  138.     else return no;  
  139. }  
  140.   
  141. //判断样例是否正负性都为label  
  142. bool AllTheSameLabel(vector <vector <string> > remain_state, string label){  
  143.     int count = 0;  
  144.     for(unsigned int i = 0; i < remain_state.size(); i++){  
  145.         if(!remain_state[i][MAXLEN-1].compare(label)) count++;  
  146.     }  
  147.     if(count == remain_state.size()-1) return true;  
  148.     else return false;  
  149. }  
  150.   
  151. //计算信息增益,DFS构建决策树  
  152. //current_node为当前的节点  
  153. //remain_state为剩余待分类的样例  
  154. //remian_attribute为剩余还没有考虑的属性  
  155. //返回根结点指针  
  156. Node * BulidDecisionTreeDFS(Node * p, vector <vector <string> > remain_state, vector <string> remain_attribute){  
  157.     //if(remain_state.size() > 0){  
  158.         //printv(remain_state);  
  159.     //}  
  160.     if (p == NULL)  
  161.         p = new Node();  
  162.     //先看搜索到树叶的情况  
  163.     if (AllTheSameLabel(remain_state, yes)){  
  164.         p->attribute = yes;  
  165.         return p;  
  166.     }  
  167.     if (AllTheSameLabel(remain_state, no)){  
  168.         p->attribute = no;  
  169.         return p;  
  170.     }  
  171.     if(remain_attribute.size() == 0){//所有的属性均已经考虑完了,还没有分尽  
  172.         string label = MostCommonLabel(remain_state);  
  173.         p->attribute = label;  
  174.         return p;  
  175.     }  
  176.   
  177.     double max_gain = 0, temp_gain;  
  178.     vector <string>::iterator max_it = remain_attribute.begin();  
  179.     vector <string>::iterator it1;  
  180.     for(it1 = remain_attribute.begin(); it1 < remain_attribute.end(); it1++){  
  181.         temp_gain = ComputeGain(remain_state, (*it1));  
  182.         if(temp_gain > max_gain) {  
  183.             max_gain = temp_gain;  
  184.             max_it = it1;  
  185.         }  
  186.     }  
  187.     //下面根据max_it指向的属性来划分当前样例,更新样例集和属性集  
  188.     vector <string> new_attribute;  
  189.     vector <vector <string> > new_state;  
  190.     for(vector <string>::iterator it2 = remain_attribute.begin(); it2 < remain_attribute.end(); it2++){  
  191.         if((*it2).compare(*max_it)) new_attribute.push_back(*it2);  
  192.     }  
  193.     //确定了最佳划分属性,注意保存  
  194.     p->attribute = *max_it;  
  195.     vector <string> values = map_attribute_values[*max_it];  
  196.     int attribue_num = FindAttriNumByName(*max_it);  
  197.     new_state.push_back(attribute_row);  
  198.     for(vector <string>::iterator it3 = values.begin(); it3 < values.end(); it3++){  
  199.         for(unsigned int i = 1; i < remain_state.size(); i++){  
  200.             if(!remain_state[i][attribue_num].compare(*it3)){  
  201.                 new_state.push_back(remain_state[i]);  
  202.             }  
  203.         }  
  204.         Node * new_node = new Node();  
  205.         new_node->arrived_value = *it3;  
  206.         if(new_state.size() == 0){//表示当前没有这个分支的样例,当前的new_node为叶子节点  
  207.             new_node->attribute = MostCommonLabel(remain_state);  
  208.         }  
  209.         else   
  210.             BulidDecisionTreeDFS(new_node, new_state, new_attribute);  
  211.         //递归函数返回时即回溯时需要1 将新结点加入父节点孩子容器 2清除new_state容器  
  212.         p->childs.push_back(new_node);  
  213.         new_state.erase(new_state.begin()+1,new_state.end());//注意先清空new_state中的前一个取值的样例,准备遍历下一个取值样例  
  214.     }  
  215.     return p;  
  216. }  
  217.   
  218. void Input(){  
  219.     string s;  
  220.     while(cin>>s,s.compare(end) != 0){//-1为输入结束  
  221.         item[0] = s;  
  222.         for(int i = 1;i < MAXLEN; i++){  
  223.             cin>>item[i];  
  224.         }  
  225.         state.push_back(item);//注意首行信息也输入进去,即属性  
  226.     }  
  227.     for(int j = 0; j < MAXLEN; j++){  
  228.         attribute_row.push_back(state[0][j]);  
  229.     }  
  230. }  
  231.   
  232. void PrintTree(Node *p, int depth){  
  233.     for (int i = 0; i < depth; i++) cout << '\t';//按照树的深度先输出tab  
  234.     if(!p->arrived_value.empty()){  
  235.         cout<<p->arrived_value<<endl;  
  236.         for (int i = 0; i < depth+1; i++) cout << '\t';//按照树的深度先输出tab  
  237.     }  
  238.     cout<<p->attribute<<endl;  
  239.     for (vector<Node*>::iterator it = p->childs.begin(); it != p->childs.end(); it++){  
  240.         PrintTree(*it, depth + 1);  
  241.     }  
  242. }  
  243.   
  244. void FreeTree(Node *p){  
  245.     if (p == NULL)  
  246.         return;  
  247.     for (vector<Node*>::iterator it = p->childs.begin(); it != p->childs.end(); it++){  
  248.         FreeTree(*it);  
  249.     }  
  250.     delete p;  
  251.     tree_size++;  
  252. }  
  253.   
  254. int main(){  
  255.     Input();  
  256.     vector <string> remain_attribute;  
  257.       
  258.     string outlook("Outlook");  
  259.     string Temperature("Temperature");  
  260.     string Humidity("Humidity");  
  261.     string Wind("Wind");  
  262.     remain_attribute.push_back(outlook);  
  263.     remain_attribute.push_back(Temperature);  
  264.     remain_attribute.push_back(Humidity);  
  265.     remain_attribute.push_back(Wind);  
  266.     vector <vector <string> > remain_state;  
  267.     for(unsigned int i = 0; i < state.size(); i++){  
  268.         remain_state.push_back(state[i]);   
  269.     }  
  270.     ComputeMapFrom2DVector();  
  271.     root = BulidDecisionTreeDFS(root,remain_state,remain_attribute);  
  272.     cout<<"the decision tree is :"<<endl;  
  273.     PrintTree(root,0);  
  274.     FreeTree(root);  
  275.     cout<<endl;  
  276.     cout<<"tree_size:"<<tree_size<<endl;  
  277.     return 0;  
  278. }  

输入的训练数据如下

 

 

[plain] view plaincopy
 
  1. Day Outlook Temperature Humidity Wind PlayTennis  
  2. 1 Sunny Hot High Weak no  
  3. 2 Sunny Hot High Strong no  
  4. 3 Overcast Hot High Weak yes  
  5. 4 Rainy Mild High Weak yes  
  6. 5 Rainy Cool Normal Weak yes  
  7. 6 Rainy Cool Normal Strong no  
  8. 7 Overcast Cool Normal Strong yes  
  9. 8 Sunny Mild High Weak no  
  10. 9 Sunny Cool Normal Weak yes  
  11. 10 Rainy Mild Normal Weak yes  
  12. 11 Sunny Mild Normal Strong yes  
  13. 12 Overcast Mild High Strong yes  
  14. 13 Overcast Hot Normal Weak yes  
  15. 14 Rainy Mild High Strong no  
  16. end  


程序输出决策树如下

 

可以用图形表示为



有了决策树后,就可以根据气候条件做预测了

例如如果气候数据是{Sunny,Cool,Normal,Strong} ,根据决策树到左侧的yes叶节点,可以判定会去游泳。

另外在编写这个程序时在数据结构的设计上面走了弯路,多叉树的实现有很多方法,本算法采用每个结点的所有孩子用vector保存比较合适,同时注意维护剩 余样例和剩余属性信息,建树时横向遍历靠循环属性的值,纵向遍历靠递归调用 ,总体是DFS,树和图的遍历在编程时经常遇到,得熟练掌握。程序有些地方的效率还得优化,有不足的点地方还望大家拍砖。

posted @ 2014-01-27 11:35  xx ee  阅读(7550)  评论(3编辑  收藏  举报