《机器学习实战》笔记——树回归
线性回归的缺陷:
创建模型时需要拟合所有的样本(除了局部加权线性回归),当数据特征多且关系复杂时,显得太笨拙
树回归:
将数据集分成多份易建模的数据,然后在这些易于建模的小数据集上利用线性回归建模。树回归采用的是二元划分法,所以只可能产生二叉树。
CART算法
全称:classification and regression trees,分类回归树。在每个叶节点上使用各自的均值做预测,每个叶节点上包含单个值。建树期间,为了防止树的过拟合,需要使用到树剪枝技术。
模型树算法:
在每个叶节点上都构建出一个线性模型,即每个叶节点上包含一个线性方程。建树期间,需要调参,所以还会介绍Python中的Tkinter模块建立GUI。
# _*_ coding:utf-8 _*_

# 9-1 Implementation of the CART (Classification And Regression Trees) algorithm
from numpy import *


def loadDataSet(fileName):
    """Load a tab-delimited data file into a list of float rows.

    fileName: path to a text file, one sample per line, fields tab-separated
    (the last column is the target value).
    Returns a list of lists of floats.
    """
    dataMat = []
    # 'with' guarantees the file handle is closed (the original leaked it)
    with open(fileName) as fr:
        for line in fr.readlines():
            curLine = line.strip().split('\t')
            # Python 3 fix: map() returns a lazy iterator; the original
            # (Python 2) code appended the map object itself, which breaks
            # every downstream mat(...) construction - materialize to a list.
            fltLine = list(map(float, curLine))
            dataMat.append(fltLine)
    return dataMat


def binSplitDataSet(dataSet, feature, value):
    """Binary-split a matrix on one feature.

    Rows with dataSet[:, feature] >  value go to mat0 (left subtree),
    rows with dataSet[:, feature] <= value go to mat1 (right subtree).
    """
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1


def regLeaf(dataSet):
    """Leaf generator for regression trees: mean of the target column."""
    return mean(dataSet[:, -1])


def regErr(dataSet):
    """Error measure: total squared error = variance * sample count."""
    return var(dataSet[:, -1]) * shape(dataSet)[0]


# 9-2 Regression-tree split selection
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Pick the best (feature, value) pair on which to split dataSet.

    leafType: function that builds a leaf node (used when no split is made)
    errType:  function that measures the error of a data subset
    ops:      (tolS, tolN) stopping controls - tolS is the minimum error
              reduction a split must achieve, tolN the minimum number of
              samples allowed in either branch.

    Returns (featureIndex, splitValue), or (None, leafValue) when splitting
    does not pay off.
    """
    tolS, tolN = ops
    # all target values identical -> nothing to gain from splitting
    # (errata: the book had an extra [0] here)
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m, n = shape(dataSet)
    S = errType(dataSet)
    bestS = inf
    bestIndex = 0
    bestValue = 0
    for featIndex in range(n - 1):  # last column is the target; skip it
        for splitVal in set(dataSet[:, featIndex].T.A.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            # reject splits leaving fewer than tolN samples on either side
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # error reduction too small -> make a leaf instead of splitting
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue  # best feature and the value to split on


# Build a CART tree
# step 1: chooseBestSplit finds the best split feature `feat` and value `val`
# step 2: feat is None -> we are at a leaf, return val; otherwise split the
#         data and recurse, returning a dict holding both subtrees.
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Recursively build a tree: {'spInd', 'spVal', 'left', 'right'} or a leaf."""
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None:  # 'is None', not '== None' (PEP 8; safe vs custom __eq__)
        return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree


def isTree(obj):
    """A node is an internal tree node iff it is stored as a dict."""
    return isinstance(obj, dict)


def getMean(tree):
    """Collapse a (sub)tree bottom-up into the average of its leaf values."""
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0


# Post-pruning (performed after the tree has been built).
# Preparation: build a deliberately large tree on the training set.
# If either child is a subtree, recurse into it with the matching test-data
# split; once both children are leaves, merge them whenever the merged error
# on the test data is lower than the unmerged error.
def prune(tree, testData):
    """Prune `tree` against `testData`; returns the (possibly collapsed) tree."""
    # no test data reaches this node -> collapse it to its mean
    if shape(testData)[0] == 0:
        return getMean(tree)
    if isTree(tree['right']) or isTree(tree['left']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        if isTree(tree['left']):
            tree['left'] = prune(tree['left'], lSet)
        if isTree(tree['right']):
            tree['right'] = prune(tree['right'], rSet)
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
                       sum(power(rSet[:, -1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = sum(power(testData[:, -1] - treeMean, 2))
        if errorMerge < errorNoMerge:
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree


# 9-4 Leaf-generation helper for model trees
def linearSolve(dataSet):
    """Fit an ordinary-least-squares line to dataSet.

    Returns (ws, X, Y) where X carries a leading column of ones (intercept
    term), Y is the target column, and ws solves X*ws ~= Y.
    Raises NameError when X^T*X is singular; once raise executes, nothing
    after it runs.
    """
    m, n = shape(dataSet)
    X = mat(ones((m, n)))
    X[:, 1:n] = dataSet[:, 0:n - 1]  # first column stays 1.0 for the bias
    Y = dataSet[:, -1]  # (original pre-initialized Y with ones - dead code)
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\ntry increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y
def modelLeaf(dataSet):
    """Leaf generator for model trees: the fitted linear-regression weights."""
    ws, X, Y = linearSolve(dataSet)
    return ws


def modelErr(dataSet):
    """Error measure for model trees: squared error of the linear fit."""
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat, 2))


# Demo: build a model tree from 'exp2.txt'.
# Guarded under __main__ so that importing this module no longer requires the
# data file to exist; 'print modelTrees' was Python 2 statement syntax (a
# SyntaxError under Python 3) and is now a print() call.
if __name__ == '__main__':
    myMat2 = mat(loadDataSet('exp2.txt'))
    modelTrees = createTree(myMat2, modelLeaf, modelErr, (1, 10))
    print(modelTrees)