ID3决策树算法
基本概念:
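信息熵是对信息不确定程度的一种度量。假定一个系统 $S$ 具有概率分布 $P=\{p_i\}$($0\le p_i\le 1$,$\sum_{i=1}^{n}p_i=1$,$i=1,2,\dots,n$),则系统 $S$ 的信息熵定义为 $H(S)=-\sum_{i=1}^{n}p_i\log_2 p_i$(约定 $0\log_2 0=0$)。假设 $X$ 是一个集合,如果存在一组集合 $A_1,A_2,\dots,A_n$,满足下列条件:(1) 当 $i\ne j$ 时 $A_i\cap A_j=\varnothing$;(2) $A_1\cup A_2\cup\dots\cup A_n=X$,则称 $A_1,\dots,A_n$ 是集合 $X$ 的一个划分。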
ID3算法使用信息熵作为度量标准,选择信息熵最小(等价于信息增益最大)的属性作为分类属性,完成决策树的构造,其中属性的熵定义为该属性各个属性值的加权熵之和。在生成树的过程中,每个节点只对应一个属性值(权熵相同的属性值看成一个属性值)。
树的递归结束条件是:划分得到的集合中的元组都属于同一类,或者树达到了所要求的深度,或者某个类的个数达到了一定的阈值。
这是我写的一个ID3算法的例子:
#include <stdio.h>
#include <math.h>
#include <string.h>
#include <stdlib.h>

/* Class labels stored in column KIND of a tuple. */
#define SHORT 0
#define MEDIUM 1
#define TALL 2

/* Gender codes stored in column GENDER (identifier spelling kept as-is). */
#define MAIL 0
#define FEMAIL 1

/* Column indices inside a tuple. */
#define GENDER 0
#define HEIGHT 1
#define KIND 2

/* Value of Node.attribute that marks a leaf. */
#define LEAF (-1)

#define MAX_ROWS 100      /* capacity of trainingData / testData */
#define MAX_CHILDREN 50   /* capacity of Node.child */
#define MAX_ATTRIBUTES 10 /* capacity of attriCnt */
#define MAX_CLASSES 10    /* capacity of the per-class counters in Entropy */
#define NUM_CLASSES 3     /* SHORT, MEDIUM, TALL */
#define TOKEN_LEN 64      /* buffer size for one word read from a data file */

/*
 * One node of the decision tree.
 *   attribute      - attribute tested at this node, or LEAF
 *   arriv_value    - value of the parent's attribute that leads to this node
 *                    (-1 for the root)
 *   child[v]       - subtree for value v of `attribute`; NULL when no training
 *                    tuple had that value
 *   childCount     - number of possible values of `attribute` (0 for a leaf)
 *   classification - class label for a leaf, -1 for an internal node
 */
typedef struct TNODE {
    int attribute;
    int arriv_value;
    struct TNODE *child[MAX_CHILDREN];
    int childCount;
    int classification;
} Node;

/* Number of distinct values per attribute: GENDER has 2, HEIGHT has 6 bins. */
int attriCnt[MAX_ATTRIBUTES] = {2, 6};
/* Number of class labels. Was 2, which made Entropy ignore every TALL tuple. */
int classCnt = NUM_CLASSES;
int trainingData[MAX_ROWS][30];
int testData[MAX_ROWS][3];

/*
 * Shannon entropy (base 2) of the class labels of the selected training rows.
 *
 * indexArray - row indices into trainingData
 * len        - number of entries in indexArray
 *
 * Returns 0 for an empty selection. Labels outside [0, classCnt) are ignored.
 */
double Entropy(int *indexArray, int len)
{
    int cnt[MAX_CLASSES];
    double sum = 0;
    int i;

    if (len <= 0) {
        return 0;
    }

    memset(cnt, 0, sizeof(cnt));
    for (i = 0; i < len; i++) {
        int label = trainingData[indexArray[i]][KIND];
        if (label >= 0 && label < classCnt && label < MAX_CLASSES) {
            cnt[label]++;
        }
    }

    for (i = 0; i < classCnt && i < MAX_CLASSES; i++) {
        if (cnt[i] == 0) {
            continue; /* 0 * log(0) is taken as 0 */
        }
        double p = cnt[i] * 1.0 / len;
        sum -= p * (log(p) / log(2));
    }
    return sum;
}

/*
 * Information gain obtained by splitting the selected rows on attribute
 * `attri`: entropy of the selection minus the weighted entropy of the subsets.
 *
 * indexArray - row indices into trainingData (at most MAX_ROWS entries)
 * attri      - attribute (column) to split on
 * len        - number of entries in indexArray
 *
 * Returns 0 for an empty selection.
 */
double Grain(int *indexArray, int attri, int len)
{
    /* Sized for the whole table: the old 10-entry buffer overflowed as soon
       as more than 10 rows shared one attribute value. */
    int subIndexArray[MAX_ROWS];
    double weighted = 0;
    double hd;
    int value, j;

    if (len <= 0) {
        return 0; /* avoids 0/0 below */
    }

    hd = Entropy(indexArray, len);
    for (value = 0; value < attriCnt[attri]; value++) {
        int sublen = 0;
        for (j = 0; j < len; j++) {
            if (trainingData[indexArray[j]][attri] == value) {
                subIndexArray[sublen++] = indexArray[j];
            }
        }
        weighted += sublen * 1.0 / len * Entropy(subIndexArray, sublen);
    }
    return hd - weighted;
}

/*
 * Majority class among the selected rows. Ties go to the smallest label.
 *
 * chooseIndex - row indices into trainingData
 * lines       - number of entries in chooseIndex
 */
int toClass(int *chooseIndex, int lines)
{
    int cnt[NUM_CLASSES] = {0};
    int maxv = -1;
    int flag = 0;
    int i;

    for (i = 0; i < lines; i++) {
        int label = trainingData[chooseIndex[i]][KIND];
        if (label >= 0 && label < NUM_CLASSES) {
            cnt[label]++;
        }
    }
    for (i = 0; i < NUM_CLASSES; i++) {
        if (maxv < cnt[i]) {
            maxv = cnt[i];
            flag = i;
        }
    }
    return flag;
}

/*
 * Returns 1 when all selected rows carry the same class label (also for an
 * empty or single-row selection), 0 otherwise.
 */
int check_attribute(int *chooseIndex, int len)
{
    int i;

    for (i = 1; i < len; i++) {
        if (trainingData[chooseIndex[i]][KIND] != trainingData[chooseIndex[i - 1]][KIND]) {
            return 0;
        }
    }
    return 1;
}

/* Allocates a node with all children NULL; terminates the program on OOM. */
static Node *newNode(int attribute, int childCount, int arriv_value, int classification)
{
    Node *no = malloc(sizeof *no);
    int i;

    if (no == NULL) {
        fprintf(stderr, "out of memory\n");
        exit(EXIT_FAILURE);
    }
    no->attribute = attribute;
    no->childCount = childCount;
    no->arriv_value = arriv_value;
    no->classification = classification;
    for (i = 0; i < MAX_CHILDREN; i++) {
        no->child[i] = NULL;
    }
    return no;
}

/* Releases a tree built by buildTree; accepts NULL. */
static void freeTree(Node *root)
{
    int i;

    if (root == NULL) {
        return;
    }
    for (i = 0; i < root->childCount; i++) {
        freeTree(root->child[i]);
    }
    free(root);
}

/*
 * Recursively builds the ID3 tree for the selected training rows.
 *
 * chooseIndex      - row indices into trainingData
 * lines            - number of entries in chooseIndex
 * remain_attribute - attributes that have not been used on this path yet
 * attriNumber      - number of entries in remain_attribute
 * arriv_value      - value of the parent's attribute leading here (-1 for root)
 *
 * Returns NULL for an empty selection. Returns a leaf when all rows share one
 * class, or when no attribute is left to split on (majority class is used).
 * The caller owns the returned tree.
 */
Node *buildTree(int *chooseIndex, int lines, int *remain_attribute, int attriNumber, int arriv_value)
{
    int subRemain_attribute[MAX_ATTRIBUTES];
    int choose_attribute;
    double maxgrain = -1;
    Node *no;
    int i, j;
    int k = 0;

    if (lines == 0) {
        return NULL;
    }

    /* Stop when the rows are pure, or when the attributes are exhausted.
       The second case used to fall through with choose_attribute unset. */
    if (check_attribute(chooseIndex, lines) == 1 || attriNumber <= 0) {
        return newNode(LEAF, 0, arriv_value, toClass(chooseIndex, lines));
    }

    /* Pick the attribute with the largest gain; the first one wins ties. */
    choose_attribute = remain_attribute[0];
    for (i = 0; i < attriNumber; i++) {
        double gain = Grain(chooseIndex, remain_attribute[i], lines);
        if (gain > maxgrain) {
            maxgrain = gain;
            choose_attribute = remain_attribute[i];
        }
    }

    /* Attributes still available to the subtrees. */
    for (i = 0; i < attriNumber && k < MAX_ATTRIBUTES; i++) {
        if (remain_attribute[i] != choose_attribute) {
            subRemain_attribute[k++] = remain_attribute[i];
        }
    }

    no = newNode(choose_attribute, attriCnt[choose_attribute], arriv_value, -1);
    for (i = 0; i < attriCnt[choose_attribute]; i++) {
        int subChooseIndex[MAX_ROWS];
        int subLines = 0;

        for (j = 0; j < lines; j++) {
            if (trainingData[chooseIndex[j]][choose_attribute] == i) {
                subChooseIndex[subLines++] = chooseIndex[j];
            }
        }
        no->child[i] = buildTree(subChooseIndex, subLines, subRemain_attribute, k, i);
    }
    return no;
}

/* Prints the indentation for tree depth `deep` (two tabs per level). */
void blank(int deep)
{
    int i;

    for (i = 0; i < deep; i++) {
        printf("\t\t");
    }
}

/* Prints the tree in pre-order, one indented record per node. */
void Triverse(Node *root, int deep)
{
    int i;

    if (root == NULL) {
        return;
    }

    blank(deep);
    switch (root->attribute) {
    case GENDER:
        printf(" classification:gender\n");
        break;
    case HEIGHT:
        printf("calssification:height\n");
        break;
    case LEAF:
        printf("leaf arrived\n");
        break;
    default:
        printf("%d\n", root->attribute);
        break;
    }
    blank(deep);
    printf("arriv_value: %d\n", root->arriv_value);
    blank(deep);
    printf("childCount: %d\n", root->childCount);
    blank(deep);
    printf("classification: %d\n", root->classification);
    blank(deep);
    printf("------------------------------------------\n");

    for (i = 0; i < root->childCount; i++) {
        Triverse(root->child[i], deep + 1);
    }
}

/*
 * Classifies row `lineNumber` of testData with the tree rooted at `root` and
 * prints the result. Prints "classify failed!" when the walk reaches a missing
 * subtree, an attribute value outside the node's range, or an invalid row.
 */
void Classify(int lineNumber, Node *root)
{
    int childIndex;

    if (root == NULL || lineNumber < 0 || lineNumber >= MAX_ROWS) {
        printf("classify failed!\n");
        return;
    }

    /* Test the node type itself: an internal node may legitimately have
       child[0] == NULL when no training row had value 0. */
    if (root->attribute == LEAF) {
        switch (root->classification) {
        case 0:
            printf("the training data belongs to Short\n");
            break;
        case 1:
            printf("the training data belongs to Medium\n");
            break;
        case 2:
            printf("the training data belongs to Tall\n");
            break;
        default:
            printf("classify failed!\n");
            break;
        }
        return;
    }

    childIndex = testData[lineNumber][root->attribute];
    if (childIndex < 0 || childIndex >= root->childCount) {
        printf("classify failed!\n");
        return;
    }
    Classify(lineNumber, root->child[childIndex]);
}

/*
 * Maps a height in metres to one of the six 0.1 m bins starting at 1.6.
 * Heights below 1.6 go to bin 0 and heights above 2.2 to bin 5, so every row
 * gets a defined value.
 */
static int heightToBin(double height)
{
    if (height < 1.7) {
        return 0;
    }
    if (height < 1.8) {
        return 1;
    }
    if (height < 1.9) {
        return 2;
    }
    if (height < 2.0) {
        return 3;
    }
    if (height < 2.1) {
        return 4;
    }
    return 5;
}

/* "F" is female; every other token is treated as male. */
static int genderCode(const char *gender)
{
    return strcmp(gender, "F") == 0 ? FEMAIL : MAIL;
}

/* "Short" and "Medium" map to their labels; every other token is TALL. */
static int kindCode(const char *kind)
{
    if (strcmp(kind, "Short") == 0) {
        return SHORT;
    }
    if (strcmp(kind, "Medium") == 0) {
        return MEDIUM;
    }
    return TALL;
}

/*
 * Reads training rows "name gender height kind" from ./data.txt, builds and
 * prints the decision tree, then reads rows "name gender height" from
 * ./testData.txt, echoes them and classifies the first one.
 */
int main(void)
{
    char name[TOKEN_LEN], gender[TOKEN_LEN], kind[TOKEN_LEN];
    int rowIndex[MAX_ROWS];
    int remain_attribute[MAX_ATTRIBUTES];
    double height;
    int lines = 0;
    int testlines = 0;
    Node *root;
    FILE *fp;
    int i;

    fp = fopen("./data.txt", "r");
    if (fp == NULL) {
        printf("Can not open file\n");
        return EXIT_FAILURE;
    }
    /* Field width 63 matches TOKEN_LEN; stop at the table capacity or at the
       first incomplete row. */
    while (lines < MAX_ROWS &&
           fscanf(fp, "%63s %63s %lf %63s", name, gender, &height, kind) == 4) {
        trainingData[lines][GENDER] = genderCode(gender);
        trainingData[lines][HEIGHT] = heightToBin(height);
        trainingData[lines][KIND] = kindCode(kind);
        lines++;
    }
    fclose(fp);

    for (i = 0; i < lines; i++) {
        rowIndex[i] = i;
    }
    for (i = 0; i < 2; i++) {
        remain_attribute[i] = i;
    }

    printf("print the decision tree:\n");
    root = buildTree(rowIndex, lines, remain_attribute, 2, -1);
    Triverse(root, 0);

    printf("The training data is:\n");
    fp = fopen("./testData.txt", "r");
    if (fp == NULL) {
        printf("Can not open the file!\n");
        freeTree(root);
        return EXIT_FAILURE;
    }
    while (testlines < MAX_ROWS &&
           fscanf(fp, "%63s %63s %lf", name, gender, &height) == 3) {
        printf("%s\t%s\t%lf\n", name, gender, height);
        testData[testlines][GENDER] = genderCode(gender);
        testData[testlines][HEIGHT] = heightToBin(height);
        testlines++;
    }
    fclose(fp);

    /* Only the first test row is classified, and only if one was read. */
    if (testlines > 0) {
        Classify(0, root);
    }

    freeTree(root);
    return 0;
}