Notes on *Machine Learning in Action*: Tree-Based Regression

Where linear regression falls short:

  Building the model requires fitting all of the samples (locally weighted linear regression being the exception), which becomes unwieldy when the data has many features with complex relationships among them.

Tree regression:

  Split the data set into many pieces that are each easy to model, then fit linear regressions on these small, easily modeled subsets. Tree regression uses binary splits, so it can only produce binary trees.

The CART algorithm

  Short for Classification And Regression Trees. A regression tree predicts with the mean of the samples at each leaf node, so every leaf holds a single value. During tree construction, tree-pruning techniques are needed to keep the tree from overfitting.

The model tree algorithm:

  Build a linear model at each leaf node, i.e., every leaf holds a linear equation. Growing the tree requires parameter tuning, so the chapter also introduces building a GUI with Python's Tkinter module.

# _*_ coding:utf-8 _*_

# Listing 9-1: implementation of the CART algorithm
from numpy import *

# Load a tab-separated data file; every field is parsed as a float, and the
# target value is assumed to be in the last column
def loadDataSet(fileName):
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float, curLine))  # map each field to a float (list() needed under Python 3)
        dataMat.append(fltLine)
    fr.close()
    return dataMat
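# Usage sketch (ex00.txt is one of the chapter's sample data files; any
# tab-separated file with the target in the last column works the same way):
# myDat = loadDataSet('ex00.txt')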

# Rows whose feature value is > value go to the left subtree; rows whose
# value is <= value go to the right subtree
def binSplitDataSet(dataSet, feature, value):
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1
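# Quick sanity check, following the book's own walkthrough: splitting an
# identity matrix on feature 1 at 0.5 sends the row with a 1 in column 1
# to mat0 and the remaining three rows to mat1.
# testMat = mat(eye(4))
# mat0, mat1 = binSplitDataSet(testMat, 1, 0.5)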


# Generate a leaf for a regression tree: the mean of the leaf's target values
def regLeaf(dataSet):
    return mean(dataSet[:, -1])


# Total squared error of the targets: variance times the number of samples,
# which equals sum((y - mean(y))**2)
def regErr(dataSet):
    return var(dataSet[:, -1]) * shape(dataSet)[0]


# Listing 9-2: the regression tree split-choosing function
# leafType is a reference to the function that creates leaf nodes
# errType is a reference to the function that computes the total squared error
# ops is a tuple of user-defined parameters that controls when splitting stops:
# the first value is the minimum error reduction required, the second is the
# minimum number of samples allowed in a split
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    tolS = ops[0]
    tolN = ops[1]
    # if all target values are equal, make a leaf; the [0] unwraps the nested
    # list that .tolist() returns for a 1 x m matrix
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m, n = shape(dataSet)
    S = errType(dataSet)
    bestS = inf
    bestIndex = 0
    bestValue = 0
    for featIndex in range(n - 1):
        for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # if the best split barely reduces the error, stop and make a leaf
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # likewise make a leaf if the best split leaves too few samples on a side
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue  # the best feature index and the value to split on

# Build the CART tree
# step 1: use chooseBestSplit to find the best split feature feat and split value val
# step 2: if feat is None, we have reached a leaf, so return val; otherwise,
# recursively build the left and right subtrees and return a dict holding them
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None: return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree
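# Example run from the book's walkthrough (assumes the chapter's ex00.txt file):
# myMat = mat(loadDataSet('ex00.txt'))
# createTree(myMat)   # returns a nested dict of 'spInd'/'spVal'/'left'/'right'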


# A subtree is stored as a dict; a leaf is a plain value
def isTree(obj):
    return (type(obj).__name__ == 'dict')


# Collapse a tree bottom-up, returning the mean value of its leaves
def getMean(tree):
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0

# Post-pruning (performed after the tree has been built)
# Preparation: build a tree on the training set that is large enough to be worth pruning
# Procedure: if either subtree is itself a tree, recurse into it; otherwise,
# compare the test-set error of the two leaves merged against the error with
# them kept separate, and merge them if that lowers the error
def prune(tree, testData):
    # no test data reaches this subtree: collapse it to the mean of its leaves
    if shape(testData)[0] == 0: return getMean(tree)
    if (isTree(tree['right']) or isTree(tree['left'])):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
                       sum(power(rSet[:, -1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = sum(power(testData[:, -1] - treeMean, 2))
        if errorMerge < errorNoMerge:
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree
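# Example pruning run (the book grows a deliberately large tree on ex2.txt,
# then prunes it against the held-out ex2test.txt):
# myTree = createTree(mat(loadDataSet('ex2.txt')), ops=(0, 1))
# myTestData = mat(loadDataSet('ex2test.txt'))
# prune(myTree, myTestData)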

# Listing 9-4: leaf-generation functions for model trees
# Solve ordinary least squares on this node's data
def linearSolve(dataSet):
    m, n = shape(dataSet)
    X = mat(ones((m, n)))   # the first column stays all ones as the bias term
    X[:, 1:n] = dataSet[:, 0:n-1]
    Y = dataSet[:, -1]
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        # raise signals the error explicitly; once raise executes, nothing
        # after it in this function runs
        raise NameError('This matrix is singular, cannot do inverse,\ntry increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y


# Model-tree leaf: the fitted regression coefficients for this node's data
def modelLeaf(dataSet):
    ws, X, Y = linearSolve(dataSet)
    return ws


# Model-tree error: sum of squared residuals of the node's linear fit
def modelErr(dataSet):
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat, 2))

# Build a model tree on exp2.txt (the chapter's piecewise-linear data set)
myMat2 = mat(loadDataSet('exp2.txt'))
modelTrees = createTree(myMat2, modelLeaf, modelErr, (1, 10))
print(modelTrees)
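If exp2.txt is the book's piecewise-linear data set, the printed tree should show a single split (near x = 0.285) with a two-coefficient weight vector ws in each leaf, i.e., one fitted line on each side of the split.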


posted @ 2017-06-30 01:27  DianeSoHungry