关于回归树的创建和剪枝
之前对于树剪枝一直感到很神奇;最近参考介绍手工写了一下剪枝代码,才算理解到底什么是剪枝。
首先要明白回归树作为预测的模式(剪枝是针对回归树而言),其实是叶子节点进行预测;所以在使用回归树进行预测的时候,本质都是在通过每层(每个层代表一个属性)的值的大于和小于来作为分值,进行二叉树的遍历。最后预测值其实叶子节点中左值或者右值;注意这里的叶子结点也是一个结构体,对于非叶子节点而言,他的左右值是一棵树,但是对于叶子结点而言,左右值则是一个单一的数值。
那么剪枝的原始就是找到叶子节点,如上图所示的特征C和特征E,然后取左右值的均值,合并(merge)为一个节点。比如低于特征C,就是取值5.5,作为B树的左节点,这样特征C这个节点就被减掉了。
但是在剪枝的时候注意一定要和原始场景进行比较,未剪枝前的偏差和剪枝后偏差做一个比较,看看到底哪个更优秀;如果剪枝后MSE值反而更加大了,就不要价值了。这里偏差的计算值是sum(power(yHat- y, 2))来进行比较即可。
下面的就是剪枝的python实现:
1 # 所谓剪枝即使遍历到叶子结点,然后看一下作为预测值的叶子结点,合并左右节点(即取左右子树平均数)为一个点 2 # 但是需要比较一下合并之后的偏差和合并前的偏差,如果合并之后的方差变小了,则剪枝(取合并值),反之则保持原状 3 def prune(tree, testData): 4 m, n = shape(testData) 5 # 如果测试在分类(分割)过程,某一类数据为0 6 if(m == 0): return getMean(tree) 7 # 下面一大段其实都是在做这一件事情:深入都叶子结点 8 # 1. 只要左右子树中有一颗不是叶子结点,那么就以当前节点的spIndex以及spValue为分割(分类)点,对testData进行二元分类 9 # 获得的是二元分类的数据集left set和right set 10 if(isTree(tree["left"]) or isTree(tree["right"])): 11 lset, rset = binSplitDataset(testData, tree["spIndex"], tree["spValue"]) 12 # 2. 继续处理不是叶子结点左右子树,对其进行递归prune(本质就是要深入到叶子结点为止) 13 if(isTree(tree["left"])): 14 tree["left"] = prune(tree["left"], lset) 15 if(isTree(tree["right"])): 16 tree["right"] = prune(tree["right"], rset) 17 18 # 左右子树都是叶子节点了 19 if(not isTree(tree["left"]) and not isTree(tree["right"])): 20 # 那么就以当前叶子节点的spIndex以及spValue为分割(分类)点,对testData进行二元分类 21 lset, rset = binSplitDataset(testData, tree["spIndex"], tree["spValue"]) 22 # 计算测试数据集和预测值(叶子结点)之间的方差,剪枝前的偏差 23 errorNotMerge = sum(power(lset[:, -1] - tree["left"], 2)) + sum(power(rset[:, -1] - tree["right"],2)) 24 treeMean = (tree["left"] + tree["right"]) / 2.0 25 # 测试数据全集和树均值(预测值)之间的方差,剪枝后偏差 26 errorMerge = sum(power(testData[:, -1] - treeMean, 2)) 27 # 看看谁的方差小,如果测试数据全集和树均值的方差小,返回的是树均值(叶子结点)的均值 28 if(errorMerge < errorNotMerge): 29 #print("errorMerge < errorNotMerge, treeMean is: ") 30 #print(treeMean) 31 return treeMean 32 # 如果叶子节点(预测值)的和真实值之间的方差比较小,则返回的树,不需要剪枝 33 else: 34 #print("errorMerge > errorNotMerge, [tree] is: ") 35 #print(tree) 36 return tree 37 # 说明叶子结点剪枝效果不明显,不需要剪枝 38 else: 39 return tree 40
那么再汇过来,如何构建一个回归树呢?
构建回归树有几个条件,首先要有样本数据,叶子节点的计算方式(regLeaf),以及计算一个数据集的偏差的公式(regErr);
1 from numpy import mean 2 3 # 数据集中y值的均值 4 def regLeaf(dataset): 5 return mean(dataset[:, -1]) 6 7 # 数据集中y值的方差和 8 def regErr(dataset): 9 return var(dataset[:, -1]) * shape(dataset)[0]
有了这三者之后,就可以进行构建树了。构建树的时候,首先将会选择一个区分度最好的特征以及特征值,做样本分割,然后基于分割后的样本分别构建左子树和右子树,这是一个递归的过程,发生变化的样本,以及基于变化的样本产生的新的分割特征以及特征值,这个递归过程一直到样本数据不再可分为止,此时获得就是一个value,这个就是叶子结点的left/right值(非叶子节点left/right仍然是一棵树)。
1 # 获取最好的分割信息,这里包括分割的特征以及特征值,然后对数据进行分割,在以分割后数据为基础继续进行继续创建树,一直到数据无法再分割 2 # (feature)为none为止。 3 def createTree(dataset, leafType=regLeaf, errorType=regErr, ops=(1, 4)): 4 feature, value = chooseBsetSplit(dataset, leafType, errorType, ops) 5 # left/right值直接就是数字(不再是树了) 6 if(feature == None): 7 return value 8 retTree = {} 9 retTree["spIndex"] = feature 10 retTree["spValue"] = value 11 # chooseBsetSplit其实应该一并把mat0和mat1返回,这样这里就不需要再计算了。 12 # 但是后来看了一下代码,返现该函数里面有的返回分支里面是没有mat0和mat1,所以这里再计算一下也是说的通的。 13 lset, rset = bindSplitDataset(dataset, feature, value) 14 retTree["left"] = createTree(lset, leafType, errorType, ops) 15 retTree["right"] = createTree(rset, leafType, errorType, ops) 16 17 return retTree
下面的代码就是获取最佳区分特征和特征值的实现
1 # 寻找最好的区分特征;为了能够找到需要遍历所有的特征,以及所有的特征值,然后以该特征值做分割,获取两个矩阵 2 # 计算两个矩阵的方差,不断选出方差小的作为bestIndex以及bestValue;最后将bestIndex对应的方差和原始矩阵 3 # 方差进行比较,如果发现最佳区分特征对应的两分割矩阵方差明显小,并且两个矩阵的样本数量都不是十分小; 4 # 则说明该特征是OK的 5 6 # 返回的feature信息可能是None,代表该节点就是叶子结点中left/right的值,该函数 7 def chooseBsetSplit(dataset, leafType=regLeaf, errorType=regErr, ops=(1, 4)): 8 # 可容忍的偏差,在程序开始的时候,通过errorType来计算一下dataset的y值的方差和;然后用dataset的方差 9 # 和最好区分度的方差和做减法,如果发现差值比这个tolS要小,那么说明这次指定特征是失败的;理想的差值是要大于tols 10 # 方差一定要比原始数据小到一定程度,这次属性指定才有意义。 11 tolS = ops[0] 12 tolN = ops[1] # 特征划分的样本的阈值,如果一分为二后,任何一个分类样本数少于这个阈值,这次划分就取消 13 # 为什么==1就要退出? 14 if(len(set(dataset[:, -1].T.tolist()[0])) == 1): 15 #print("len(set(dataset[:, -1].T.tolist()[0])) == 1, return None feature") 16 return None, leafType(dataset) 17 m, n = shape(dataset) 18 # 注意这里errorType其实就是参数,这里参数就是一个函数,默认是regErr 19 S =errorType(dataset) 20 # 初始化best* 21 bestS = inf 22 bestIndex = 0 23 bestValue = 0 24 iterate_num = n-1 25 #print("iterate_num: " + str(iterate_num)) 26 # 遍历所有的特征(最后一列是结果,跳过) 27 for featureIndex in range(iterate_num): 28 #print("++++++++++++++++++++++ %d turns +++++++++++++++++++++++" % (featureIndex)) 29 # 遍历该特征的所有特征值 30 for splitValue in set(dataset[:, featureIndex].A.flatten().tolist()): 31 # 在所有训练样本上面(dataset)对于该特征,大于该特征值,小于特征值分别做数据分割,获得两个矩阵 32 mat0, mat1 = bindSplitDataset(dataset, featureIndex, splitValue) 33 # 如果分割的特征矩阵任意一个的样本数<tolN,那么将会跳过该特征的处理,经过分割一定要达到一定的样本数才有意义 34 # 任意一个矩阵的样本数少说明该特征的区分度不高 35 if(shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): 36 #print("shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN; splitValue: %f, shape(mat0)[0]: %d, (shape(mat1)[0]: %d, tolN: %d" % (splitValue, shape(mat0)[0], shape(mat1)[0], tolN)) 37 continue 38 #print("*************** one ok **********************") 39 # 和leafType一样,都是参数类型为函数,计算方差和 40 newS = errorType(mat0) + errorType(mat1) 41 # 如果方差小于bestS,则用当前的方差以及特征信息做替换;到此可以看到目标就是找到区分度高并且方差小的特征,作为最好 42 # 区分特征 43 if(newS < bestS): 44 bestIndex = featureIndex 45 bestS = newS 46 bestValue = splitValue 47 # 如果S值和bestS值之差小于tolS;参见tolS的注释。 48 if(S -bestS) < tolS: 49 #print("(S -bestS) < tolS, return feature NULL, S: %s, bestS: %s, tolS: %s" % ( str(S), str(bestS), str(tolS))) 50 return None, leafType(dataset) 51 mat0, mat1 = bindSplitDataset(dataset, bestIndex, bestValue) 52 # 这里的判断有意义吗?在循环体中其实已经做了这个判断,如果不满足也不会成为bestIndex,bestvalue; 53 if(shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): 54 print("shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN") 55 return None, leafType(dataset) 56 57 return bestIndex, bestValue
posted on 2019-03-10 20:56 张叫兽的技术研究院 阅读(1427) 评论(0) 编辑 收藏 举报