CART:分类与回归树
起源:决策树切分数据集
决策树每次决策时,按照一定规则切分数据集,并将切分后的小数据集递归处理。这样的处理方式给了线性回归处理非线性数据一个启发。
能不能先将类似特征的数据切成一小部分,再将这一小部分放大处理,使用线性的方法增加准确率呢?
Part I: 树的枝与叶
枝:二叉 or 多叉?
在AdaBoost的单决策树中,对于连续型数据构建决策树,我们采取步进阈值切分2段的方法。还有一种简化处理,即选择子数据集中的当前维度所有不同的值作为阈值切分。
而在CART里,大于阈值归为左孩子,小于阈值的归为右孩子。若是离散型数据,则根据离散数据种类建立对应的多叉树即可。
叶:何时不再切分?
ID3决策树中,停止切分的条件有两个:
①DFS链路中全部切分方式被扫过一次,很明显,对于离散型特征,每次按照某一维度异同切分,再扫相同维度毫无意义。
对于连续型特征,则共有维度*(当前维度不同值数量,即阈值数量)种切分方式,同种方式也毫无意义。
②当前子数据集分类全部一致,已经是很完美的切分了,再切也没意思。
CART中,由于搜索深度只有1,重复选取也不会卡死。所以直接遵循②。
追加③条件:手动限制切分子集数量下限tolN,误差变化下限tolS(目标函数收敛)。一旦达到这两个下限,就立刻停止。
枝:一个好枝?
ID3算法给出了一个评价离散Label的好枝的标准:分类的混乱度(香农熵)降低。
对于连续数据,好枝的参考标准则是类似最小二乘法的目标函数,即误差越小越好。
由于计算误差需要先进行线性回归,相当于树套回归,虽然效果很好,但是无疑带来计算压力。
在这点上, CART利用均值和方差的性质给出了一个简化的误差计算:即假设一团数据的回归结果是这团数据的均值,那么目标函数即可当成总方差。
使用均值替代回归结果的树称为回归树,使用实际回归结果的树成为模型树。
叶:数量越多越好?
叶结点数量越多,越容易过拟合。数量越少,则容易欠拟合。
而tolN和tolS在选择最好的切分方式时,控制着叶结点的数量,这两个值越小,叶子越多,且对tolS的值很敏感。
树的递归构建:
①对当前数据集做最好的切分。
②若不能切分,则将该结点设为叶结点。
否则,由于切分的性质,所以切出的两个子集必定不为空。对大于阈值的子集进行左孩子递归构建,小于阈值的子集进行右孩子递归构建。
Part II : 树的剪枝
叶结点数量决定着拟合情况。人工调整不是一件好事。
所以出现一种先强行过拟合(tolN=0,tolS=1)生成CART树,然后利用新的样本数据进行剪枝的方法,称为后剪枝。
后剪枝有两种方法:
①后剪枝会将大量的枝从树顶直接转化成叶子,相当于废掉原树中很多数据,所以需要引入新的数据。
而把一个大枝转为叶子的方法,则是利用均值的性质。新叶子的回归值=原枝上所有叶的均值。
②除了废枝为叶,还有利用均值的计算性质、借助新数据归并两叶。当然归并是有条件的。
新数据递归切分之后,必然会分到叶子上。如果恰好一个枝上是两片叶子,那么分别计算ErrNoMerga、ErrMerga的值,观察是否变小来决定是否归并。
$ErrNoMerga=\sum_{i=1}^{LSet}(Set[i].y-L.leaf)^{2}+\sum_{i=1}^{RSet}(Set[i].y-R.leaf)^{2}$
$NewLeaf=mergaMean=avg(L.leaf+R.leaf)$
$ErrMerga=\sum_{i=1}^{Set}{(Set[i].y-mergaMean)^{2}}$
Part III: 回归与模型树
对于每条测试数据,从树顶按照树中保存的切分规则左右递归直到叶结点,返回叶结点的值作为回归值。
实际测试结果下,效果并不好。所以应当每一个叶结点:保留数据、以及线性回归方程(w、b),从而建立起模型树。
线性模型树方法将取代回归树中的均值误差理论,主要修改地方在选择分支、后剪枝上。
$Err =\sum_{i=1}^{m} (data[i].y-Regression(y))^{2}$
这样,叶结点就变成了一个线性回归器,返回线性回归结果即可。
Part IV 代码
#include "cstdio" #include "iostream" #include "fstream" #include "math.h" #include "sstream" #include "string" #include "vector" #include "set" using namespace std; #define Dim dataSet[0].feature.size() #define TREE pair<vector<Data>,vector<Data> > #define NULL 0 struct Data { vector<double> feature; double y; Data(vector<double> feature,double y):feature(feature),y(y) {} }; struct RegTree { int dim;double value; RegTree *Left,*Right; RegTree():Left(NULL),Right(NULL) {} RegTree(int dim,double value):Left(NULL),Right(NULL),dim(dim),value(value) {} }; vector<Data> dataSet,addSet,testSet; pair<int,double> ops(0,1); void read() { ifstream fin("data1.txt"),fin2("data2.txt"),fin3("data3.txt"); string line;double tmp,y; while(getline(fin,line)) { stringstream sin(line); vector<double> feature; while(sin>>tmp) feature.push_back(tmp); y=feature.back();feature.pop_back(); dataSet.push_back(Data(feature,y)); } while(getline(fin2,line)) { stringstream sin(line); vector<double> feature; while(sin>>tmp) feature.push_back(tmp); y=feature.back();feature.pop_back(); addSet.push_back(Data(feature,y)); } while(getline(fin3,line)) { stringstream sin(line); vector<double> feature; while(sin>>tmp) feature.push_back(tmp); y=feature.back();feature.pop_back(); testSet.push_back(Data(feature,y)); } } pair<vector<Data>,vector<Data> > splitDataSet(vector<Data> dataSet,int dim,double value) { vector<Data> Left,Right; for(int i=0;i<dataSet.size();i++) { if(dataSet[i].feature[dim]>value) Left.push_back(dataSet[i]); else Right.push_back(dataSet[i]); } return make_pair(Left,Right); } double regLeaf(vector<Data> dataSet) { double ret=0.0; //printf("Leaf:\n"); for(int i=0;i<dataSet.size();i++) { ret+=dataSet[i].y; /* for(int j=0;j<dataSet[i].feature.size();j++) printf("%.2lf ",dataSet[i].feature[j]); printf("%lf\n",dataSet[i].y);*/ } //printf("\n"); return ret/dataSet.size(); } double calcErr(vector<Data> dataSet) { double avg=0.0,ret=0.0; for(int i=0;i<dataSet.size();i++) avg+=dataSet[i].y; avg/=dataSet.size(); for(int i=0;i<dataSet.size();i++) ret+=(dataSet[i].y-avg)*(dataSet[i].y-avg); return ret; } pair<int,double> chooseBestSplit(vector<Data> dataSet) { //tolN、tolS(较敏感)过小都会导致Leaf过多,过大则会导致Leaf过少 int tolN=ops.first;double tolS=ops.second,S,newS,bestS=1e10,bestValue,bestDim; set<double> y; for(int i=0;i<dataSet.size();i++) y.insert(dataSet[i].y); if(y.size()==1) return make_pair(-1,regLeaf(dataSet)); S=calcErr(dataSet); for(int i=0;i<Dim;i++) { set<double> splitValue; for(int j=0;j<dataSet.size();j++) splitValue.insert(dataSet[j].feature[i]); for(set<double>::iterator j=splitValue.begin();j!=splitValue.end();j++) { TREE tree=splitDataSet(dataSet,i,*j); if(tree.first.size()<tolN||tree.second.size()<tolN) continue; newS=calcErr(tree.first)+calcErr(tree.second); if(newS<bestS) {bestDim=i;bestValue=*j;bestS=newS;} } } if(S-bestS<tolS) return make_pair(-1,regLeaf(dataSet)); TREE tree=splitDataSet(dataSet,bestDim,bestValue); if(tree.first.size()<tolN||tree.second.size()<tolN) return make_pair(-1,regLeaf(dataSet)); return make_pair(bestDim,bestValue); } RegTree *buildTree(vector<Data> dataSet) { pair<int,double> info=chooseBestSplit(dataSet); if(info.first==-1) { RegTree *node=new RegTree(info.first,info.second); return node; } RegTree *node=new RegTree(info.first,info.second); TREE tree=splitDataSet(dataSet,info.first,info.second); //printf("Node: dim:%d %.2lf\n",info.first,info.second); node->Left=buildTree(tree.first); node->Right=buildTree(tree.second); return node; } double getMean(RegTree *root) { double ret=0.0; if(root->Left->dim!=-1) ret+=getMean(root->Left); else ret+=root->Left->value; if(root->Right->dim!=-1) ret+=getMean(root->Right); else ret+=root->Right->value; return ret/=2; } RegTree *prune(RegTree *&root,vector<Data> dataSet) { if(dataSet.size()==0) return new RegTree(-1,getMean(root)); double errNoMerga=0.0,errMerga=0.0; if(root->Left->dim!=-1||root->Right->dim!=-1) { TREE tree=splitDataSet(dataSet,root->dim,root->value); if(root->Left->dim!=-1) root->Left=prune(root->Left,tree.first); if(root->Right->dim!=-1) root->Right=prune(root->Right,tree.second); } if(root->Left->dim==-1&&root->Right->dim==-1) { TREE tree=splitDataSet(dataSet,root->dim,root->value); for(int i=0;i<tree.first.size();i++) errNoMerga+=(tree.first[i].y-root->Left->value)*(tree.first[i].y-root->Left->value); for(int i=0;i<tree.second.size();i++) errNoMerga+=(tree.second[i].y-root->Right->value)*(tree.second[i].y-root->Right->value); double mergaMean=(root->Left->value+root->Right->value)/2; for(int i=0;i<dataSet.size();i++) errMerga+=(dataSet[i].y-mergaMean)*(dataSet[i].y-mergaMean); if(errMerga<errNoMerga) {/*cout<<"Merga"<<endl;*/return new RegTree(-1,mergaMean);} else return root; } return root; } int ccnt=0; void displayTree(RegTree *root) { if(root->Left->dim!=-1) displayTree(root->Left); else {printf("Leaf:%.2lf\n",root->Left->value);ccnt++;} if(root->Right->dim!=-1) displayTree(root->Right); else {printf("Leaf:%.2lf\n",root->Right->value);ccnt++;} } double forcast(RegTree *root,Data data) { if(root->dim==-1) return root->value; //in case the super root is a leaf if(data.feature[root->dim]>root->value) { if(root->Left->dim!=-1) return forcast(root->Left,data); else return root->Left->value; } else { if(root->Right->dim!=-1) return forcast(root->Right,data); else return root->Right->value; } } void forcastAll(RegTree *root,vector<Data> dataSet) { for(int i=0;i<dataSet.size();i++) { double y=forcast(root,dataSet[i]); printf("origin:%.2lf forcast:%.2lf\n",dataSet[i].y,y); } } int main() { read(); RegTree *root=buildTree(dataSet); root=prune(root,addSet); forcastAll(root,testSet); }