树回归 CART算法
线性回归创建的预测模型需要拟合所有的样本点,在数据拥有众多特征并且特征之间关系十分复杂时,构建全局模型太难,而且,生活中很多问题是非线性的,不可能使用全局线性模型来拟合任何数据。
一种可行的方法是把数据集切分成很多分易建模的数据,然后利用线性回归技术来建模。如果首次切分后仍然难以拟合线性模型就继续切分。这种切分方式下,树结构和回归法就相当有用。
CART算法:分类回归树,既可用于分类也可用于回归。
第三章使用的决策树构建算法是ID3,每次选取当前最佳的特征来分割数据。属于贪心算法,不考虑能否达到全局最优。而且容易造成过拟合、不能直接处理连续型特征,只有事先将连续型特征转换成离散型,才能使用ID3算法。
而使用二元切分法则易于对树构建过程进行调整以处理连续型特征。如果特征值大于给定值就走左子树,小于给定值就走右子树。
CART算法的实现代码:
from numpy import * def loadDataSet(filename): dataMat=[] f=open(filename) for line in f.readlines(): curLine=line.strip().split('\t') floatLine=list(map(float,curLine)) dataMat.append(floatLine) return dataMat def binSplitDataSet(dataSet,feature,value): mat0=dataSet[nonzero(dataSet[:,feature]>value)[0],:] mat1=dataSet[nonzero(dataSet[:,feature]<=value)[0],:] return mat0,mat1 def createTree(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)): feat, val = chooseBestSplit(dataSet, leafType, errType, ops) if feat == None: return val retTree = {} retTree['spInd'] = feat retTree['spVal'] = val lSet, rSet = binSplitDataSet(dataSet, feat, val) retTree['left'] = createTree(lSet, leafType, errType, ops) retTree['right'] = createTree(rSet, leafType, errType, ops) return retTree
chooseBestSplit()函数暂未实现。
将CART算法用于回归:
回归树假设叶子节点是常数值。用平方误差的总值(总方差)来计算连续型数值的混乱程度。总方差等于均方差乘以数据集中样本点的个数。
chooseBestSplit():给定某个误差计算方法,该函数会找到数据集上最佳的二元切分方式。还要确定什么时候停止切分,一旦停止切分就会生成一个叶子节点。所以:用最佳方式切分数据集和生成相应的叶节点。
伪代码:
对每个特征: 对每个特征值: 将数据集切分成两份 计算切分后的误差 如果当前误差小于当前最小误差,将当前切分设定为最佳切分并更新最小误差 返回最佳切分的特征和阈值
切分函数的实现:
def regLeaf(dataSet): #负责生成叶节点,当chooseBestSplit函数确定不再对数据进行切分时,将调用regLeaf函数得到叶节点的模型 return mean(dataSet[:,-1]) #在回归树中,此模型就是目标变量的均值 def regErr(dataSet): # 误差估计函数,计算目标变量的平方误差,需要返回总误差,即为均方误差乘以数据集中样本个数 return var(dataSet[:, -1]) * shape(dataSet)[0] def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)): #ops为用户指定的参数,用于控制函数的停止时机 tolS = ops[0] # 容许的误差下降值 tolN = ops[1] # 切分的最少样本数 if len(set(dataSet[:, -1].T.tolist()[0])) == 1: # 统计不同剩余特征值得数目,如果数目为一,就不需要再切分而直接返回 return None, leafType(dataSet) else: m, n = shape(dataSet) S = errType(dataSet) #误差 bestS = inf #最小误差 bestIndex = 0 bestValue = 0 for featIndex in range(n - 1): # 对所有特征进行遍历,找到最佳切分方式。最佳切分就是使得切分后能达到最低误差的切分 # for splitVal in set(dataSet[:, featIndex]): # 遍历某个特征的所有特征值 for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]): mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal) # 按照某个特征的某个值将数据切分成两个数据子集 if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): # 如果某个子集行数不大于tolN,也不应该切分 continue newS = errType(mat0) + errType(mat1) # 新误差由切分后的两个数据子集组成的误差 if newS < bestS: # 判断新切分能否降低误差 bestIndex = featIndex bestValue = splitVal bestS = newS if (S - bestS) < tolS: # 如果误差降低不大则退出 return None, leafType(dataSet) mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue) if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): # 如果切分出的数据集很小则退出 return None, leafType(dataSet) return bestIndex, bestValue
regLeaf():负责生成叶节点,即求当前数据集目标值的平均值作为回归预测值。当chooseBestSplit()确定不再对数据进行切分时,将调用regLeaf()函数来得到叶节点的模型。回归树中,该模型是目标变量的均值。
regErr():误差估计函数。在给定数据集上计算目标变量的平方误差。
chooseBestSplit():构建回归树的核心函数。目的是找到数据的最佳二元切分方式。如果找不到好的二元切分,就返回None并同时调用regLeaf()方法来产生叶节点。
运行代码:
if __name__=='__main__': myMat=loadDataSet('ex00.txt') myMat=mat(myMat) result=createTree(myMat) print(result)
输出为:
{'spInd': 0, 'spVal': 0.48813, 'right': -0.04465028571428572, 'left': 1.0180967672413792}
只有两个叶节点,对照下面的散点图可以看出,在数据0.48813左侧的数据,回归预测值为-0.04465,右侧预测值为1.018。
数据集散点图:
因为数据集简单,所以得到的回归树也简单。
更换数据集测试:
if __name__=='__main__': myMat2=loadDataSet('ex2.txt') myMat2=mat(myMat2) myTree = createTree(myMat2, ops=(0, 1)) print(myTree)
输出:
{'spInd': 0, 'spVal': 0.499171, 'right': {'spInd': 0, 'spVal': 0.457563, 'right': {'spInd': 0, 'spVal': 0.455761, 'right': {'spInd': 0, 'spVal': 0.126833, 'right': {'spInd': 0, 'spVal': 0.124723, 'right': {'spInd': 0, 'spVal': 0.085111, 'right': {'spInd': 0, 'spVal': 0.084661, 'right': {'spInd': 0, 'spVal': 0.080061, 'right': {'spInd': 0, 'spVal': 0.068373, 'right': {'spInd': 0, 'spVal': 0.061219, 'right': {'spInd': 0, 'spVal': 0.044737, 'right': {'spInd': 0, 'spVal': 0.028546, 'right': {'spInd': 0, 'spVal': 0.000256, 'right': 9.668106, 'left': -8.377094}, 'left': {'spInd': 0, 'spVal': 0.039914, 'right': 11.220099, 'left': 3.855393}}, 'left': {'spInd': 0, 'spVal': 0.053764, 'right': -13.731698, 'left': {'spInd': 0, 'spVal': 0.055862, 'right': -3.131497, 'left': 6.695567}}}, 'left': -15.160836}, 'left': {'spInd': 0, 'spVal': 0.079632, 'right': 29.420068, 'left': 2.229873}}, 'left': -24.132226}, 'left': 37.820659}, 'left': {'spInd': 0, 'spVal': 0.108801, 'right': {'spInd': 0, 'spVal': 0.10796, 'right': {'spInd': 0, 'spVal': 0.085873, 'right': -10.137104, 'left': -1.293195}, 'left': -16.106164}, 'left': {'spInd': 0, 'spVal': 0.11515, 'right': 13.795828, 'left': -1.402796}}}, 'left': 22.891675}, 'left': {'spInd': 0, 'spVal': 0.130626, 'right': -39.524461, 'left': {'spInd': 0, 'spVal': 0.382037, 'right': {'spInd': 0, 'spVal': 0.335182, 'right': {'spInd': 0, 'spVal': 0.324274, 'right': {'spInd': 0, 'spVal': 0.309133, 'right': {'spInd': 0, 'spVal': 0.131833, 'right': 22.478291, 'left': {'spInd': 0, 'spVal': 0.138619, 'right': -29.087463, 'left': {'spInd': 0, 'spVal': 0.156067, 'right': {'spInd': 0, 'spVal': 0.13988, 'right': 7.336784, 'left': 7.557349}, 'left': {'spInd': 0, 'spVal': 0.166765, 'right': {'spInd': 0, 'spVal': 0.156273, 'right': 0.225886, 'left': {'spInd': 0, 'spVal': 0.164134, 'right': -27.405211, 'left': {'spInd': 0, 'spVal': 0.166431, 'right': -6.512506, 'left': -14.740059}}}, 'left': {'spInd': 0, 'spVal': 0.193282, 'right': {'spInd': 0, 'spVal': 0.176523, 'right': 0.946348, 'left': 18.208423}, 'left': {'spInd': 0, 'spVal': 0.211633, 'right': {'spInd': 0, 'spVal': 0.202161, 'right': {'spInd': 0, 'spVal': 0.199903, 'right': -3.372472, 'left': -1.983889}, 'left': {'spInd': 0, 'spVal': 0.203993, 'right': -22.379119, 'left': {'spInd': 0, 'spVal': 0.206207, 'right': -12.619036, 'left': -8.332207}}}, 'left': {'spInd': 0, 'spVal': 0.228473, 'right': {'spInd': 0, 'spVal': 0.222271, 'right': {'spInd': 0, 'spVal': 0.218321, 'right': {'spInd': 0, 'spVal': 0.217214, 'right': -3.958752, 'left': 1.410768}, 'left': -9.255852}, 'left': {'spInd': 0, 'spVal': 0.2232, 'right': 15.501642, 'left': 19.425158}}, 'left': {'spInd': 0, 'spVal': 0.25807, 'right': {'spInd': 0, 'spVal': 0.228628, 'right': -2.266273, 'left': {'spInd': 0, 'spVal': 0.228751, 'right': -30.812912, 'left': {'spInd': 0, 'spVal': 0.232802, 'right': 1.222318, 'left': -20.425137}}}, 'left': {'spInd': 0, 'spVal': 0.284794, 'right': {'spInd': 0, 'spVal': 0.273863, 'right': {'spInd': 0, 'spVal': 0.264926, 'right': {'spInd': 0, 'spVal': 0.264639, 'right': 2.557923, 'left': 5.280579}, 'left': -9.457556}, 'left': 35.623746}, 'left': {'spInd': 0, 'spVal': 0.300318, 'right': {'spInd': 0, 'spVal': 0.297107, 'right': {'spInd': 0, 'spVal': 0.295993, 'right': {'spInd': 0, 'spVal': 0.290749, 'right': -14.391613, 'left': -14.988279}, 'left': -1.798377}, 'left': -18.051318}, 'left': 8.814725}}}}}}}}}}, 'left': {'spInd': 0, 'spVal': 0.310956, 'right': -49.939516, 'left': {'spInd': 0, 'spVal': 0.318309, 'right': -27.605424, 'left': -13.189243}}}, 'left': {'spInd': 0, 'spVal': 0.32889, 'right': 39.783113, 'left': {'spInd': 0, 'spVal': 0.331364, 'right': -1.290825, 'left': {'spInd': 0, 'spVal': 0.3349, 'right': 18.97665, 'left': 2.768225}}}}, 'left': {'spInd': 0, 'spVal': 0.370042, 'right': {'spInd': 0, 'spVal': 0.35679, 'right': {'spInd': 0, 'spVal': 0.350725, 'right': {'spInd': 0, 'spVal': 0.350065, 'right': {'spInd': 0, 'spVal': 0.342761, 'right': {'spInd': 0, 'spVal': 0.342155, 'right': {'spInd': 0, 'spVal': 0.3417, 'right': -23.547711, 'left': -16.930416}, 'left': -31.584855}, 'left': -1.319852}, 'left': -40.086564}, 'left': {'spInd': 0, 'spVal': 0.351478, 'right': -0.461116, 'left': -19.526539}}, 'left': -32.124495}, 'left': {'spInd': 0, 'spVal': 0.378965, 'right': {'spInd': 0, 'spVal': 0.373501, 'right': -8.228297, 'left': {'spInd': 0, 'spVal': 0.377383, 'right': 5.241196, 'left': 13.583555}}, 'left': -29.007783}}}, 'left': {'spInd': 0, 'spVal': 0.388789, 'right': {'spInd': 0, 'spVal': 0.385021, 'right': 24.816941, 'left': 21.578007}, 'left': {'spInd': 0, 'spVal': 0.437652, 'right': {'spInd': 0, 'spVal': 0.412516, 'right': {'spInd': 0, 'spVal': 0.403228, 'right': {'spInd': 0, 'spVal': 0.391609, 'right': 3.001104, 'left': -1.729244}, 'left': -26.419289}, 'left': {'spInd': 0, 'spVal': 0.418943, 'right': 44.161493, 'left': {'spInd': 0, 'spVal': 0.426711, 'right': -21.594268, 'left': {'spInd': 0, 'spVal': 0.428582, 'right': 15.224266, 'left': 19.745224}}}}, 'left': {'spInd': 0, 'spVal': 0.454312, 'right': {'spInd': 0, 'spVal': 0.446196, 'right': -5.108172, 'left': {'spInd': 0, 'spVal': 0.451087, 'right': -28.724685, 'left': -20.360067}}, 'left': {'spInd': 0, 'spVal': 0.454375, 'right': 3.043912, 'left': 9.841938}}}}}}}, 'left': -34.044555}, 'left': {'spInd': 0, 'spVal': 0.465561, 'right': {'spInd': 0, 'spVal': 0.463241, 'right': 17.171057, 'left': 30.051931}, 'left': {'spInd': 0, 'spVal': 0.467383, 'right': {'spInd': 0, 'spVal': 0.46568, 'right': -23.777531, 'left': -9.712925}, 'left': {'spInd': 0, 'spVal': 0.483803, 'right': 5.224234, 'left': {'spInd': 0, 'spVal': 0.487381, 'right': 27.729263, 'left': {'spInd': 0, 'spVal': 0.487537, 'right': 5.149336, 'left': 11.924204}}}}}}, 'left': {'spInd': 0, 'spVal': 0.729397, 'right': {'spInd': 0, 'spVal': 0.640515, 'right': {'spInd': 0, 'spVal': 0.613004, 'right': {'spInd': 0, 'spVal': 0.606417, 'right': {'spInd': 0, 'spVal': 0.513332, 'right': {'spInd': 0, 'spVal': 0.508548, 'right': {'spInd': 0, 'spVal': 0.508542, 'right': 96.403373, 'left': 93.292829}, 'left': 101.075609}, 'left': {'spInd': 0, 'spVal': 0.533511, 'right': {'spInd': 0, 'spVal': 0.51915, 'right': 116.176162, 'left': {'spInd': 0, 'spVal': 0.531944, 'right': 124.795495, 'left': 129.766743}}, 'left': {'spInd': 0, 'spVal': 0.548539, 'right': {'spInd': 0, 'spVal': 0.546601, 'right': {'spInd': 0, 'spVal': 0.537834, 'right': 90.995536, 'left': {'spInd': 0, 'spVal': 0.543843, 'right': 98.36201, 'left': 96.319043}}, 'left': 83.114502}, 'left': {'spInd': 0, 'spVal': 0.553797, 'right': {'spInd': 0, 'spVal': 0.549814, 'right': 137.267576, 'left': 120.857321}, 'left': {'spInd': 0, 'spVal': 0.560301, 'right': 82.903945, 'left': {'spInd': 0, 'spVal': 0.599142, 'right': {'spInd': 0, 'spVal': 0.589806, 'right': {'spInd': 0, 'spVal': 0.582311, 'right': {'spInd': 0, 'spVal': 0.571214, 'right': {'spInd': 0, 'spVal': 0.569327, 'right': 108.435392, 'left': 114.872056}, 'left': 82.589328}, 'left': {'spInd': 0, 'spVal': 0.585413, 'right': 125.295113, 'left': 98.674874}}, 'left': 130.378529}, 'left': 93.521396}}}}}}, 'left': 168.180746}, 'left': {'spInd': 0, 'spVal': 0.623909, 'right': {'spInd': 0, 'spVal': 0.618868, 'right': 76.917665, 'left': 87.181863}, 'left': {'spInd': 0, 'spVal': 0.628061, 'right': {'spInd': 0, 'spVal': 0.624827, 'right': 105.970743, 'left': 117.628346}, 'left': {'spInd': 0, 'spVal': 0.637999, 'right': {'spInd': 0, 'spVal': 0.632691, 'right': 93.645293, 'left': 91.656617}, 'left': 82.713621}}}}, 'left': {'spInd': 0, 'spVal': 0.642373, 'right': 140.613941, 'left': {'spInd': 0, 'spVal': 0.642707, 'right': 82.500766, 'left': {'spInd': 0, 'spVal': 0.665329, 'right': {'spInd': 0, 'spVal': 0.661073, 'right': {'spInd': 0, 'spVal': 0.652462, 'right': 112.715799, 'left': 115.687524}, 'left': 121.980607}, 'left': {'spInd': 0, 'spVal': 0.706961, 'right': {'spInd': 0, 'spVal': 0.698472, 'right': {'spInd': 0, 'spVal': 0.689099, 'right': {'spInd': 0, 'spVal': 0.666452, 'right': {'spInd': 0, 'spVal': 0.665652, 'right': 105.547997, 'left': 120.014736}, 'left': {'spInd': 0, 'spVal': 0.667851, 'right': 92.449664, 'left': {'spInd': 0, 'spVal': 0.680486, 'right': 110.367074, 'left': 112.378209}}}, 'left': 120.521925}, 'left': {'spInd': 0, 'spVal': 0.69892, 'right': 92.470636, 'left': {'spInd': 0, 'spVal': 0.699873, 'right': 115.586605, 'left': {'spInd': 0, 'spVal': 0.70639, 'right': 105.062147, 'left': 106.180427}}}}, 'left': {'spInd': 0, 'spVal': 0.70889, 'right': 135.416767, 'left': {'spInd': 0, 'spVal': 0.716211, 'right': {'spInd': 0, 'spVal': 0.710234, 'right': 108.553919, 'left': 103.345308}, 'left': 110.90283}}}}}}}, 'left': {'spInd': 0, 'spVal': 0.952833, 'right': {'spInd': 0, 'spVal': 0.759504, 'right': {'spInd': 0, 'spVal': 0.740859, 'right': {'spInd': 0, 'spVal': 0.731636, 'right': 73.912028, 'left': 93.773929}, 'left': {'spInd': 0, 'spVal': 0.757527, 'right': 63.549854, 'left': 81.106762}}, 'left': {'spInd': 0, 'spVal': 0.763328, 'right': 115.199195, 'left': {'spInd': 0, 'spVal': 0.769043, 'right': 64.041941, 'left': {'spInd': 0, 'spVal': 0.790312, 'right': {'spInd': 0, 'spVal': 0.786865, 'right': {'spInd': 0, 'spVal': 0.785574, 'right': {'spInd': 0, 'spVal': 0.777582, 'right': 100.838446, 'left': 107.024467}, 'left': 100.598825}, 'left': {'spInd': 0, 'spVal': 0.787755, 'right': 118.642009, 'left': 110.15973}}, 'left': {'spInd': 0, 'spVal': 0.806158, 'right': {'spInd': 0, 'spVal': 0.799873, 'right': {'spInd': 0, 'spVal': 0.798198, 'right': 76.853728, 'left': 91.368473}, 'left': 62.877698}, 'left': {'spInd': 0, 'spVal': 0.815215, 'right': {'spInd': 0, 'spVal': 0.811602, 'right': {'spInd': 0, 'spVal': 0.811363, 'right': 112.981216, 'left': 99.841379}, 'left': 118.319942}, 'left': {'spInd': 0, 'spVal': 0.833026, 'right': {'spInd': 0, 'spVal': 0.823848, 'right': {'spInd': 0, 'spVal': 0.819722, 'right': 70.054508, 'left': 59.342323}, 'left': 76.723835}, 'left': {'spInd': 0, 'spVal': 0.841547, 'right': {'spInd': 0, 'spVal': 0.838587, 'right': 134.089674, 'left': 115.669032}, 'left': {'spInd': 0, 'spVal': 0.841625, 'right': 60.552308, 'left': {'spInd': 0, 'spVal': 0.944221, 'right': {'spInd': 0, 'spVal': 0.85497, 'right': {'spInd': 0, 'spVal': 0.84294, 'right': 95.893131, 'left': {'spInd': 0, 'spVal': 0.847219, 'right': 76.240984, 'left': 89.20993}}, 'left': {'spInd': 0, 'spVal': 0.936524, 'right': {'spInd': 0, 'spVal': 0.934853, 'right': {'spInd': 0, 'spVal': 0.925782, 'right': {'spInd': 0, 'spVal': 0.910975, 'right': {'spInd': 0, 'spVal': 0.901444, 'right': {'spInd': 0, 'spVal': 0.901421, 'right': {'spInd': 0, 'spVal': 0.892999, 'right': {'spInd': 0, 'spVal': 0.888426, 'right': {'spInd': 0, 'spVal': 0.872199, 'right': {'spInd': 0, 'spVal': 0.866451, 'right': {'spInd': 0, 'spVal': 0.856421, 'right': 107.166848, 'left': 94.402102}, 'left': 111.552716}, 'left': {'spInd': 0, 'spVal': 0.883615, 'right': {'spInd': 0, 'spVal': 0.872883, 'right': 95.887712, 'left': 95.348184}, 'left': {'spInd': 0, 'spVal': 0.885676, 'right': 108.045948, 'left': 94.896354}}}, 'left': 82.436686}, 'left': {'spInd': 0, 'spVal': 0.900699, 'right': {'spInd': 0, 'spVal': 0.896683, 'right': 107.00162, 'left': 109.188248}, 'left': 100.133819}}, 'left': 87.300625}, 'left': {'spInd': 0, 'spVal': 0.908629, 'right': 118.513475, 'left': 106.814667}}, 'left': {'spInd': 0, 'spVal': 0.912161, 'right': 85.005351, 'left': {'spInd': 0, 'spVal': 0.915263, 'right': 96.71761, 'left': 92.074619}}}, 'left': 115.753994}, 'left': 65.548418}, 'left': {'spInd': 0, 'spVal': 0.937766, 'right': 119.949824, 'left': 100.120253}}}, 'left': {'spInd': 0, 'spVal': 0.948822, 'right': 69.318649, 'left': {'spInd': 0, 'spVal': 0.949198, 'right': 105.752508, 'left': {'spInd': 0, 'spVal': 0.952377, 'right': 73.520802, 'left': 100.649591}}}}}}}}}}}}}, 'left': {'spInd': 0, 'spVal': 0.965969, 'right': {'spInd': 0, 'spVal': 0.956951, 'right': {'spInd': 0, 'spVal': 0.953902, 'right': 130.92648, 'left': {'spInd': 0, 'spVal': 0.954711, 'right': 100.935789, 'left': 82.016541}}, 'left': {'spInd': 0, 'spVal': 0.958512, 'right': 135.837013, 'left': {'spInd': 0, 'spVal': 0.960398, 'right': 123.559747, 'left': 112.386764}}}, 'left': {'spInd': 0, 'spVal': 0.968621, 'right': 98.648346, 'left': 86.399637}}}}}
散点图:
得到的树很复杂,改变ops元组的值:
if __name__=='__main__': myMat2 = loadDataSet('ex2.txt') myMat2 = mat(myMat2) myTree = createTree(myMat2, ops=(10000, 4)) print(myTree)
输出:
{'spInd': 0, 'spVal': 0.499171, 'left': 101.35815937735848, 'right': -2.637719329787234}
也可以得到仅有两个叶节点的树。
树剪枝:
一棵树如果节点过多,表明该模型可能对数据进行了过拟合。通过降低决策树的复杂度来避免过拟合的过程称为“剪枝”。
在函数chooseBestSplit()中的提前终止条件,实际上是“预剪枝”操作,预剪枝操作对于参数ops元组非常敏感,难以获得有效的回归树。
后剪枝:利用测试集对数进行剪枝。由于不需要用户指定参数,后剪枝是一种更理想化的剪枝方法。
首先将数据集划分为训练集和测试集。先使用训练集构建出一棵足够复杂的树便于剪枝。然后从上到下找到叶节点,用测试集来判断这些叶节点合并能不能降低测试误差,如果可以的话就合并。
伪代码如下:
基于已有的树切分测试数据: 如果存在任一子集是一棵树,则在该子集递归剪枝过程 计算将当前两个叶子节点合并后的误差 计算不合并的误差 如果合并会降低误差则合并
回归树剪枝函数prune():
def isTree(obj): # 测试输入变量是否是一棵树,返回布尔型的结果,用于判断当前处理的节点是否是叶节点 return (type(obj).__name__ == "dict") def getMean(tree): # 递归函数,从上到下遍历树直到叶节点为止。如果找到两个叶节点则计算它们的平均值。该函数对树进行塌陷处理 if isTree(tree["right"]): tree["right"] = getMean(tree["right"]) if isTree(tree["left"]): tree["left"] = getMean(tree["left"]) return (tree["left"] + tree["right"]) / 2.0 def prune(tree, testData): #参数:待剪枝的树与剪枝所需的测试数据 if shape(testData)[0] == 0: #没有测试数据则对树进行塌陷处理 return getMean(tree) if (isTree(tree['right']) or isTree(tree['left'])): # lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal']) if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet) if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet) if not isTree(tree['left']) and not isTree(tree['right']): lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal']) errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + sum(power(rSet[:, -1] - tree['right'], 2)) treeMean = (tree['left'] + tree['right']) / 2.0 errorMerge = sum(power(testData[:, -1] - treeMean, 2)) if errorMerge < errorNoMerge: print("融合") return treeMean else: return tree else: return tree
isTree():测试输入变量是否是一棵树,返回布尔值的结果。用于判断当前处理的节点是不是叶子节点。
getMean():递归函数,从上到下遍历树直到叶节点。如果找到两个叶节点就返回其平均值。该函数对树进行塌陷处理。
prune():参数为待剪枝的树和剪枝所需的测试数据集。
测试:
if __name__=='__main__': myMat2=loadDataSet('ex2.txt') myMat2=mat(myMat2) myTree = createTree(myMat2, ops=(0, 1)) myDat2Test = loadDataSet("ex2test.txt") myMat2Test = mat(myDat2Test) result=prune(myTree, myMat2Test) print(result)
输出:
融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 融合 {'left': {'left': {'left': {'left': 92.5239915, 'spInd': 0, 'spVal': 0.965969, 'right': {'left': {'left': {'left': 112.386764, 'spInd': 0, 'spVal': 0.960398, 'right': 123.559747}, 'spInd': 0, 'spVal': 0.958512, 'right': 135.837013}, 'spInd': 0, 'spVal': 0.956951, 'right': 111.2013225}}, 'spInd': 0, 'spVal': 0.952833, 'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 96.41885225, 'spInd': 0, 'spVal': 0.948822, 'right': 69.318649}, 'spInd': 0, 'spVal': 0.944221, 'right': {'left': {'left': 110.03503850000001, 'spInd': 0, 'spVal': 0.936524, 'right': {'left': 65.548418, 'spInd': 0, 'spVal': 0.934853, 'right': {'left': 115.753994, 'spInd': 0, 'spVal': 0.925782, 'right': {'left': {'left': 94.3961145, 'spInd': 0, 'spVal': 0.912161, 'right': 85.005351}, 'spInd': 0, 'spVal': 0.910975, 'right': {'left': {'left': 106.814667, 'spInd': 0, 'spVal': 0.908629, 'right': 118.513475}, 'spInd': 0, 'spVal': 0.901444, 'right': {'left': 87.300625, 'spInd': 0, 'spVal': 0.901421, 'right': {'left': {'left': 100.133819, 'spInd': 0, 'spVal': 0.900699, 'right': 108.094934}, 'spInd': 0, 'spVal': 0.892999, 'right': {'left': 82.436686, 'spInd': 0, 'spVal': 0.888426, 'right': {'left': 98.54454949999999, 'spInd': 0, 'spVal': 0.872199, 'right': 106.16859550000001}}}}}}}}}, 'spInd': 0, 'spVal': 0.85497, 'right': {'left': {'left': 89.20993, 'spInd': 0, 'spVal': 0.847219, 'right': 76.240984}, 'spInd': 0, 'spVal': 0.84294, 'right': 95.893131}}}, 'spInd': 0, 'spVal': 0.841625, 'right': 60.552308}, 'spInd': 0, 'spVal': 0.841547, 'right': 124.87935300000001}, 'spInd': 0, 'spVal': 0.833026, 'right': {'left': 76.723835, 'spInd': 0, 'spVal': 0.823848, 'right': {'left': 59.342323, 'spInd': 0, 'spVal': 0.819722, 'right': 70.054508}}}, 'spInd': 0, 'spVal': 0.815215, 'right': {'left': 118.319942, 'spInd': 0, 'spVal': 0.811602, 'right': {'left': 99.841379, 'spInd': 0, 'spVal': 0.811363, 'right': 112.981216}}}, 'spInd': 0, 'spVal': 0.806158, 'right': 73.49439925}, 'spInd': 0, 'spVal': 0.790312, 'right': {'left': 114.4008695, 'spInd': 0, 'spVal': 0.786865, 'right': 102.26514075}}, 'spInd': 0, 'spVal': 0.769043, 'right': 64.041941}, 'spInd': 0, 'spVal': 0.763328, 'right': 115.199195}, 'spInd': 0, 'spVal': 0.759504, 'right': 78.08564325}}, 'spInd': 0, 'spVal': 0.729397, 'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 110.90283, 'spInd': 0, 'spVal': 0.716211, 'right': {'left': 103.345308, 'spInd': 0, 'spVal': 0.710234, 'right': 108.553919}}, 'spInd': 0, 'spVal': 0.70889, 'right': 135.416767}, 'spInd': 0, 'spVal': 0.706961, 'right': {'left': {'left': {'left': {'left': 106.180427, 'spInd': 0, 'spVal': 0.70639, 'right': 105.062147}, 'spInd': 0, 'spVal': 0.699873, 'right': 115.586605}, 'spInd': 0, 'spVal': 0.69892, 'right': 92.470636}, 'spInd': 0, 'spVal': 0.698472, 'right': {'left': 120.521925, 'spInd': 0, 'spVal': 0.689099, 'right': {'left': 101.91115275, 'spInd': 0, 'spVal': 0.666452, 'right': 112.78136649999999}}}}, 'spInd': 0, 'spVal': 0.665329, 'right': {'left': 121.980607, 'spInd': 0, 'spVal': 0.661073, 'right': {'left': 115.687524, 'spInd': 0, 'spVal': 0.652462, 'right': 112.715799}}}, 'spInd': 0, 'spVal': 0.642707, 'right': 82.500766}, 'spInd': 0, 'spVal': 0.642373, 'right': 140.613941}, 'spInd': 0, 'spVal': 0.640515, 'right': {'left': {'left': {'left': {'left': 82.713621, 'spInd': 0, 'spVal': 0.637999, 'right': {'left': 91.656617, 'spInd': 0, 'spVal': 0.632691, 'right': 93.645293}}, 'spInd': 0, 'spVal': 0.628061, 'right': {'left': 117.628346, 'spInd': 0, 'spVal': 0.624827, 'right': 105.970743}}, 'spInd': 0, 'spVal': 0.623909, 'right': 82.04976400000001}, 'spInd': 0, 'spVal': 0.613004, 'right': {'left': 168.180746, 'spInd': 0, 'spVal': 0.606417, 'right': {'left': {'left': {'left': {'left': {'left': {'left': 93.521396, 'spInd': 0, 'spVal': 0.599142, 'right': {'left': 130.378529, 'spInd': 0, 'spVal': 0.589806, 'right': {'left': 111.9849935, 'spInd': 0, 'spVal': 0.582311, 'right': {'left': 82.589328, 'spInd': 0, 'spVal': 0.571214, 'right': {'left': 114.872056, 'spInd': 0, 'spVal': 0.569327, 'right': 108.435392}}}}}, 'spInd': 0, 'spVal': 0.560301, 'right': 82.903945}, 'spInd': 0, 'spVal': 0.553797, 'right': 129.0624485}, 'spInd': 0, 'spVal': 0.548539, 'right': {'left': 83.114502, 'spInd': 0, 'spVal': 0.546601, 'right': {'left': 97.3405265, 'spInd': 0, 'spVal': 0.537834, 'right': 90.995536}}}, 'spInd': 0, 'spVal': 0.533511, 'right': {'left': {'left': 129.766743, 'spInd': 0, 'spVal': 0.531944, 'right': 124.795495}, 'spInd': 0, 'spVal': 0.51915, 'right': 116.176162}}, 'spInd': 0, 'spVal': 0.513332, 'right': {'left': 101.075609, 'spInd': 0, 'spVal': 0.508548, 'right': {'left': 93.292829, 'spInd': 0, 'spVal': 0.508542, 'right': 96.403373}}}}}}}, 'spInd': 0, 'spVal': 0.499171, 'right': {'left': {'left': {'left': {'left': {'left': 8.53677, 'spInd': 0, 'spVal': 0.487381, 'right': 27.729263}, 'spInd': 0, 'spVal': 0.483803, 'right': 5.224234}, 'spInd': 0, 'spVal': 0.467383, 'right': {'left': -9.712925, 'spInd': 0, 'spVal': 0.46568, 'right': -23.777531}}, 'spInd': 0, 'spVal': 0.465561, 'right': {'left': 30.051931, 'spInd': 0, 'spVal': 0.463241, 'right': 17.171057}}, 'spInd': 0, 'spVal': 0.457563, 'right': {'left': -34.044555, 'spInd': 0, 'spVal': 0.455761, 'right': {'left': {'left': {'left': {'left': {'left': -4.1911745, 'spInd': 0, 'spVal': 0.437652, 'right': {'left': {'left': {'left': {'left': 19.745224, 'spInd': 0, 'spVal': 0.428582, 'right': 15.224266}, 'spInd': 0, 'spVal': 0.426711, 'right': -21.594268}, 'spInd': 0, 'spVal': 0.418943, 'right': 44.161493}, 'spInd': 0, 'spVal': 0.412516, 'right': {'left': -26.419289, 'spInd': 0, 'spVal': 0.403228, 'right': 0.6359300000000001}}}, 'spInd': 0, 'spVal': 0.388789, 'right': 23.197474}, 'spInd': 0, 'spVal': 0.382037, 'right': {'left': {'left': {'left': -29.007783, 'spInd': 0, 'spVal': 0.378965, 'right': {'left': {'left': 13.583555, 'spInd': 0, 'spVal': 0.377383, 'right': 5.241196}, 'spInd': 0, 'spVal': 0.373501, 'right': -8.228297}}, 'spInd': 0, 'spVal': 0.370042, 'right': {'left': -32.124495, 'spInd': 0, 'spVal': 0.35679, 'right': {'left': -9.9938275, 'spInd': 0, 'spVal': 0.350725, 'right': -26.851234812500003}}}, 'spInd': 0, 'spVal': 0.335182, 'right': {'left': 22.286959625, 'spInd': 0, 'spVal': 0.324274, 'right': {'left': {'left': -20.3973335, 'spInd': 0, 'spVal': 0.310956, 'right': -49.939516}, 'spInd': 0, 'spVal': 0.309133, 'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 8.814725, 'spInd': 0, 'spVal': 0.300318, 'right': {'left': -18.051318, 'spInd': 0, 'spVal': 0.297107, 'right': {'left': -1.798377, 'spInd': 0, 'spVal': 0.295993, 'right': {'left': -14.988279, 'spInd': 0, 'spVal': 0.290749, 'right': -14.391613}}}}, 'spInd': 0, 'spVal': 0.284794, 'right': {'left': 35.623746, 'spInd': 0, 'spVal': 0.273863, 'right': {'left': -9.457556, 'spInd': 0, 'spVal': 0.264926, 'right': {'left': 5.280579, 'spInd': 0, 'spVal': 0.264639, 'right': 2.557923}}}}, 'spInd': 0, 'spVal': 0.25807, 'right': {'left': {'left': -9.601409499999999, 'spInd': 0, 'spVal': 0.228751, 'right': -30.812912}, 'spInd': 0, 'spVal': 0.228628, 'right': -2.266273}}, 'spInd': 0, 'spVal': 0.228473, 'right': 6.099239}, 'spInd': 0, 'spVal': 0.211633, 'right': {'left': -16.42737025, 'spInd': 0, 'spVal': 0.202161, 'right': -2.6781805}}, 'spInd': 0, 'spVal': 0.193282, 'right': 9.5773855}, 'spInd': 0, 'spVal': 0.166765, 'right': {'left': {'left': {'left': -14.740059, 'spInd': 0, 'spVal': 0.166431, 'right': -6.512506}, 'spInd': 0, 'spVal': 0.164134, 'right': -27.405211}, 'spInd': 0, 'spVal': 0.156273, 'right': 0.225886}}, 'spInd': 0, 'spVal': 0.156067, 'right': {'left': 7.557349, 'spInd': 0, 'spVal': 0.13988, 'right': 7.336784}}, 'spInd': 0, 'spVal': 0.138619, 'right': -29.087463}, 'spInd': 0, 'spVal': 0.131833, 'right': 22.478291}}}}}, 'spInd': 0, 'spVal': 0.130626, 'right': -39.524461}, 'spInd': 0, 'spVal': 0.126833, 'right': {'left': 22.891675, 'spInd': 0, 'spVal': 0.124723, 'right': {'left': {'left': 6.196516, 'spInd': 0, 'spVal': 0.108801, 'right': {'left': -16.106164, 'spInd': 0, 'spVal': 0.10796, 'right': {'left': -1.293195, 'spInd': 0, 'spVal': 0.085873, 'right': -10.137104}}}, 'spInd': 0, 'spVal': 0.085111, 'right': {'left': 37.820659, 'spInd': 0, 'spVal': 0.084661, 'right': {'left': -24.132226, 'spInd': 0, 'spVal': 0.080061, 'right': {'left': 15.824970500000001, 'spInd': 0, 'spVal': 0.068373, 'right': {'left': -15.160836, 'spInd': 0, 'spVal': 0.061219, 'right': {'left': {'left': {'left': 6.695567, 'spInd': 0, 'spVal': 0.055862, 'right': -3.131497}, 'spInd': 0, 'spVal': 0.053764, 'right': -13.731698}, 'spInd': 0, 'spVal': 0.044737, 'right': 4.091626}}}}}}}}}}}
虽然合并了很多叶节点,但剪枝后的树没有像预期的那样剪枝成两部分。说明后剪枝可能不如预剪枝有效。可以同时使用两种剪枝方式。
模型树:把叶子节点设定为分段线性函数。利用数生成算法对数据切分,且每份切分数据容易被线性模型表示。该算法的关键在于误差的计算。
对于给定的数据集,应该先用线性的模型对它拟合,然后计算真是的目标值与模型预测值之间的差值,再将这些差值的平方求和就得到了所需要的误差。
模型树的叶节点生成函数:
def linearSolve(dataSet): m, n = shape(dataSet) X = mat(ones((m, n))) #第一列仍为1 Y = mat(ones((m, 1))) X[:, 1:n] = dataSet[:, 0:n - 1] # print('X:',X) Y = dataSet[:, -1] # 将X,Y中的数据格式化 # print('Y:',Y) xTx = X.T * X if linalg.det(xTx) == 0.0: raise NameError("此矩阵不可逆。") # ws = linalg.pinv(xTx) * (X.T * Y) ws = xTx.I * (X.T * Y) return ws, X, Y def modelLeaf(dataSet): # 当数据不再需要切分的时候它负责生成叶节点模型 ws, X, Y = linearSolve(dataSet) return ws def modelErr(dataSet): ws, X, Y = linearSolve(dataSet) yHat = X * ws return sum(power(Y - yHat, 2))
数据集散点图如下:
测试:
myMat=mat(loadDataSet('exp2.txt')) plotPoint(myMat) myTree=createTree(myMat,modelLeaf,modelErr,(1,10)) print(myTree)
输出结果:
{'spInd': 0, 'spVal': 0.285477, 'right': matrix([[3.46877936], [1.18521743]]), 'left': matrix([[1.69855694e-03], [1.19647739e+01]])}
将数据集从x=0.285477分开,分别用两段线性模型来拟合。
树回归与标准回归的比较:相关系数
用树回归进行预测的代码:包括回归树和模型树两种树
def regTreeEval(model, inDat): #回归树效果评估 return float(model) def modelTreeEval(model, inDat): #模型树效果评估 n = shape(inDat)[1] X = mat(ones((1, n + 1))) X[:, 1:n + 1] = inDat return float(X * model) def treeForeCast(tree, inData, modelEval=regTreeEval): if not isTree(tree): return modelEval(tree, inData) # 如果输入单个数据或行向量,返回一个浮点值 else: if inData[tree["spInd"]] > tree["spVal"]: if isTree(tree["left"]): return treeForeCast(tree["left"], inData, modelEval) else: return modelEval(tree["left"], inData) else: if isTree(tree["right"]): return treeForeCast(tree["right"], inData, modelEval) else: return modelEval(tree["right"], inData) def createForeCast(tree, testData, modelEval=regTreeEval): #测试不同回归树的效果 m = len(testData) yHat = mat(zeros((m, 1))) for i in range(m): yHat[i, 0] = treeForeCast(tree, mat(testData[i]), modelEval) # 多次调用treeForeCast函数,将结果以列的形式放到yHat变量中 return yHat
因为代码中已经含有标准线性回归函数(linearSolve),所以不必重新写其生成代码。
测试:
if __name__=='__main__': trainMat = mat(loadDataSet("bikeSpeedVsIq_train.txt")) testMat = mat(loadDataSet("bikeSpeedVsIq_test.txt")) myTree = createTree(trainMat, ops=(1, 20)) yHat = createForeCast(myTree, testMat[:, 0]) print("回归树的相关系数:", corrcoef(yHat, testMat[:, -1], rowvar=0)[0, 1]) myTree = createTree(trainMat, modelLeaf, modelErr, (1, 20)) yHat = createForeCast(myTree, testMat[:, 0], modelTreeEval) print("模型树的相关系数:", corrcoef(yHat, testMat[:, -1], rowvar=0)[0, 1]) ws, X, Y = linearSolve(trainMat) print("线性回归系数:", ws) for i in range(shape(testMat)[0]): yHat[i] = testMat[i, 0] * ws[1, 0] + ws[0, 0] print("线性回归模型的相关系数:", corrcoef(yHat, testMat[:, -1], rowvar=0)[0, 1])
输出:
回归树的相关系数: 0.964085231822215 模型树的相关系数: 0.9760412191380629 线性回归系数: [[37.58916794] [ 6.18978355]] 线性回归模型的相关系数: 0.9434684235674766
相关系数越接近1越好,所以,模型树>回归树>标准线性回归。