CART分类器

cart 算法采用二分递归回归技术,将当前的样本集分为两个子样本集,使得生成得每个非叶子节点都有两个分支。所以,算法生成得决策树是简洁得二叉树。

分类树得两个基本思想:第一个是将训练样本进行递归地划分自变量空间进行建树得想法,第二个想法是用验证数据进行剪枝。

cart进行属性分类得是用gini指标

如果我们用k,k=1,2,3……C表示类,其中C是类别集Result的因变量数目,一个节点A的GINI不纯度定义为:

其中,Pk表示观测点中属于k类得概率,当Gini(A)=0时所有样本属于同一类,当所有类在节点中以相同的概率出现时,Gini(A)最大化,此时值为(C-1)C/2。

对于分类回归树,A如果它不满足“T都属于同一类别or T中只剩下一个样本”,则此节点为非叶节点,所以尝试根据样本的每一个属性及可能的属性值,对样本的进行二元划分,假设分类后A分为B和C,其中B占A中样本的比例为p,C为q(显然p+q=1)。则杂质改变量:Gini(A) -p*Gini(B)-q*Gini(C),每次划分该值应为非负,只有这样划分才有意义,对每个属性值尝试划分的目的就是找到杂质gai变量最大的一个划分,该属性值划分子树即为最优分支。

作业源码:

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#define EPS 0.000001
/*
 1.为了节约时间,直接在c4.5的基础上将其转化为cart算法。
 2.使用gini指标评判杂质量。
 3.非连续变量选择分割使用异化程度最小的对应属性值划分,即x和非x
 4.连续变量使用和c4.5划分方法一样得划分,但是是使用gini指标进行划分
 */
typedef struct Tuple
{
	int i;
    int g;
    double h;
    int c;
}tuple;
typedef struct TNode{
    double  gap;
    int attri;
    int reachValue;
    struct TNode *child[50];
    int kind;
}node;
tuple trainData[100];
double cal_entropy(tuple *data,int len);
double choose_best_gap(tuple *data,int len);
double cal_grainRatio(tuple *data,int len);
double cal_grainRatio2(tuple *data,int len,double gap);
double cal_splitInfo(tuple *data,int len);
int check_attribute(tuple *data,int len);
int choose_attribute(tuple *data,int len);
node *build_tree(tuple *data,int len,double reachValue,double gap);
void print_blank(int depth);
void traverse(node *no,int depth);
void test_data(node *root,tuple *data);
int cmp(const void *a, const void *b)
{
    tuple *a1=(tuple *)a;
    tuple *b1=(tuple *)b;
    return a1->h-b1->h>0?1:-1;
}
void copy_tuple(tuple *source,tuple *destination)
{
    destination->c=source->c;
    destination->g=source->g;
    destination->h=source->h;
	destination->i=source->i;
}
double cal_gini(tuple *data,int len)
{
    int i,j;
    double result=0.0;
    int cnt;
    for(i=0;i<3;i++)//有三类
    {
        cnt=0;
        for(j=0;j<len;j++)
        {
            if(data[j].c==i)
            {
                cnt++;
            }
        }
        result+=(cnt*1.0/len)*(cnt*1.0/len);
    }
	//printf("in cal_gini: %lf\n",result);
    return 1-result;
}
double cal_gender(tuple *data,int len)//计算性别分类的gini差值
{
    int i,j;
    double preGini=cal_gini(data,len);//计算分类前得gini数
    tuple subData[100];
    int subLen;
    double result=0.0;
    for(i=0;i<2;i++)
    {
        subLen=0;//统计某个性别得个数
        for(j=0;j<len;j++)
        {
            if(data[j].g==i)//属于某个性别
            {
                copy_tuple(&data[j],&subData[subLen++]);//存入数组当中
            }
        }
        result=result+subLen*1.0/len*cal_gini(subData,subLen);
    }
    return preGini-result;
}
double cal_height(tuple *data, int len,int *at)//计算性别分类的gini差值
{
    int i,j;
    double preGini=cal_gini(data,len);
	//printf("preGini: %lf\n",preGini);
	//getchar();
    tuple small[100],big[100];
    int smallLen,bigLen;
    double maxv=-1;
    for(i=0;i<len;i++)//寻找最大得gini差值得各个测试单元
    {
		smallLen=0;bigLen=0;
        for(j=0;j<len;j++)
        {
            if(data[j].h<=data[i].h)
            {
                copy_tuple(&data[j],&small[smallLen++]);
            }
            else
            {
                copy_tuple(&data[j],&big[bigLen++]);
            }
        }
		//printf("i: %d\n",i);
		//printf("smallLen: %d\n",smallLen);
		//printf("bigLen: %d\n",bigLen);
        double smallGini=cal_gini(small,smallLen);
		//printf("smallGini: %lf\n",smallGini);
        double bigGini=cal_gini(big,bigLen);
		//printf("bigGini: %lf\n",bigGini);
        double temp=preGini-(smallLen*1.0/len*smallGini+bigLen*1.0/len*bigGini);
		//printf("temp: %lf\n",temp);
        if(temp>maxv)
        {
            maxv=temp;
            *at=i;
			//printf("at: %d data[at]: %lf\n",*at,data[*at].h);
        }
    }
	//printf("maxv: %lf\n",maxv);
    return maxv;
}
int main()
{
    FILE *fp;
    fp=fopen("./data.txt", "r");
    if(fp==NULL)
    {
        printf("can not open the file: data.txt\n");
        return 0;
    }
    char name[50];
    double height;
    char gender[10];
    char kind[10];
    int i=0;
    while(fscanf(fp, "%s",name)!=EOF)
    {
		trainData[i].i=i;
        fscanf(fp,"%s",gender);
        if(!strcmp(gender, "M"))
        {
            trainData[i].g=0;
        }
        else trainData[i].g=1;
        fscanf(fp,"%lf",&height);
        trainData[i].h=height;
        fscanf(fp,"%s",kind);
        if(!strcmp(kind, "Short"))
        {
            trainData[i].c=0;
        }
        else if(!strcmp(kind,"Medium"))
        {
            trainData[i].c=1;
        }
        else{
            trainData[i].c=2;
        }
        i++;
    }
    int rows=i;
	node *root=build_tree(trainData,rows,-1,-1);
	 traverse(root,0);printf("\n");
	 fp=fopen("./testData.txt", "r");
	     if(fp==NULL)
	     {
	         printf("can not open the file!\n");
	         return 0;
	     }
	     tuple testData;
	     fscanf(fp, "%s",name);
	     fscanf(fp,"%s",gender);
	     if(!strcmp(gender, "M"))
	     {
	         testData.g=0;
	     }
	     else  testData.g=1;
	     fscanf(fp,"%lf",&height);
	      testData.h=height;
	   //  printf("testData: gender: %d\theight: %lf\n",testData.g,testData.h);
		 fclose(fp);
		 fp=NULL;
		 test_data(root,&testData);
}
void test_data(node *root,tuple *data)
{
	/*
     1.检查节点得属性值
     2.如果是身高则检查gap得值如果<=就往左,否则就往右
     3.如果是性别就判断reachValue的值
     */
    if(root->attri==-1)
    {
        printf("the test data belongs to:");
        switch (root->kind) {
            case 0: printf("Short\n");break;
            case 1: printf("Medium\n");break;
            case 2: printf("Tall\n");break;
            default:break;
        }
		return;
    }
	if(root->attri==0)
    {
        if(data->g==0)
        {
            test_data(root->child[0],data);
        }
        else
        {
            test_data(root->child[1], data);
        }
    }
    else
    {
		//printf("gap: %lf\n",root->gap);
        if(data->h<=root->gap)
        {
            test_data(root->child[0], data);
        }
        else{
            test_data(root->child[1], data);
        }
    }
}

void print_blank(int depth)
{
    int i;
    for(i=0;i<depth;i++)
    {
        printf("\t");
    }
}
void traverse(node *no,int depth)
{
    if(no==NULL)return;
    int i;
	printf("-------------------\n");
	print_blank(depth);
    printf("attri: %d\n",no->attri);print_blank(depth);
    printf("gap: %lf\n",no->gap);print_blank(depth);
    printf("kind: %d\n",no->kind);print_blank(depth);
    printf("reachValue: %d\n",no->reachValue);print_blank(depth);
	printf("-------------------\n");print_blank(depth);
    for(i=0;no->child[i]!=NULL;i++)
    {
        traverse(no->child[i], depth+1);
    }
}
int choose_attribute(tuple *data,int len)//选择属性函数,返回代表属性的代号
{
    int i;
    /*
     1.如果是性别,就直接计算增益
     2.如果是身高就计算最高得增益值得gap
     3.性别和身高得增益进行比较的到最佳得分类属性
     */
    double genderGini=cal_gender(data, len);
    int heightChoice;
    double heightGini=cal_height(data,len,&heightChoice);
    if(genderGini<heightGini)
    {
        return 1;
    }
    else
    {
        return 0;
    }
	//printf("gGrainRatio: %lf\n",gGrainRatio);
    /*计算连续属性值的增益
     1.排序确定gap
     2.计算各个gap的信息增益率
     3.选定最大得信息增益率确定该属性的最大信息增益率
     */
}
node *build_tree(tuple *data,int len,double reachValue,double gap)
{
	//getchar();getchar();
    int i,j;
	/*for(i=0;i<len;i++)
	{
		printf("data i: %d g:%d h:%lf c:%d\n",data[i].i,data[i].g,data[i].h,data[i].c);
	}*/
    int kind=check_attribute(data, len);//检查所有得元组是否属于同一个类
	//printf("kind: %d\n",kind);
    if(kind!=0)//如果所有得元组都属于同一类则作为叶子节点返回
    {
	//	printf("leaves constructed completed!\n");
        node *newNode=(node *)malloc(sizeof(node));
        newNode->gap=-1;//如果是按照身高分类就用得到gap;
        newNode->attri=-1;
        newNode->reachValue=reachValue;
        newNode->kind=kind-1;
        for(i=0;i<50;i++)newNode->child[i]=NULL;//初始化所有的孩子节点
        return newNode;
    }
    //从元组中选择最优属性值进行分类
    int attribute=choose_attribute(data, len);
	//printf("choose: %d\n",attribute);
    //执行分类 深度优先构建树结构
    node *newNode=(node *)malloc(sizeof(node));
    newNode->reachValue=reachValue;
    newNode->attri=attribute;
    newNode->kind=-1;
	newNode->gap=gap;
    for(i=0;i<50;i++)newNode->child[i]=NULL;
    if(attribute==0)//选择性别进行构建
    {
        for(i=0;i<2;i++)
        {
            tuple subData[100];
            int sublen=0;
            for(j=0;j<len;j++)
            {
                if(data[j].g==i/*是男的或者女的*/)
                {
                    copy_tuple(&data[j],&subData[sublen++]);
                }
            }
			if(sublen==0)continue;
            newNode->child[i]=build_tree(subData,sublen,i,-1);//因为是用性别构建得,所以不用gap分区间取值
        }
    }
    else
    {
        //选择高度构建
        /*
         1.选择最优得分割值
         2.将元组分割成left和right两个部分
         */
        int index=0;
        double heightGini=cal_height(data,len,&index);
        double gap=data[index].h;//选择分割连续变量得值
		newNode->gap=gap;
		//printf("best gap: %lf\n",gap);
        tuple leftData[100],rightData[100];//分割完成后,放入左右两个数组里面
        int leftlen=0;//左右数组的长度
        int rightlen=0;
        for(i=0;i<len;i++)
        {
            if(data[i].h<=gap)
            {
                copy_tuple(&data[i],&leftData[leftlen++]);
            }
            else{
                copy_tuple(&data[i],&rightData[rightlen++]);
            }
        }
		if(leftlen!=0)
        newNode->child[0]=build_tree(leftData,leftlen,-1,gap);//使用身高构建子树,因此必须分区间进行
		if(rightlen!=0)
        newNode->child[1]=build_tree(rightData,rightlen,-1,gap);
    }
    return newNode;
}
int check_attribute(tuple *data,int len)//检查所有得元组是否都是一类
{
    /*
     1.扫描所有得元组,如果出现不适同一类得元组,则返回
     */
    int i;
    for(i=1;i<len;i++)
    {
        if(data[i].c!=data[i-1].c)return 0;
    }
    return data[0].c+1;
}

 

posted @ 2013-12-31 01:24  湖心北斗  阅读(811)  评论(0编辑  收藏  举报