树回归

树回归

当回归的数据呈现非线性时,就需要使用树回归。

树回归的基本逻辑

获得最好的切分特征和切分特征值

  遍历所有特征

    针对某一特征,遍历该特征的所有值

    针对某一特征值,进行划分数据,计算出划分数据之后的总方差,

    若总方差最小,记下特征和特征值

    当遍历完所有特征后,就能够获得最小方差的特征和特征值,并以此作为树的结点,划分左右子树,

若没有特征,就返回特征值

左子树为大于等于特征值的样本集合,

右子树为小于特征值的样本集合

构建字典,保存特征、特征值、左子树和右子树

 

针对各个子树,再重复上面的操作。

终止的条件为:

样本数小于某个数值(如4)和总方差下降的最低值(如1)

叶子结点取目标值的平均值

 

最后构建出一颗倒着长的二叉树,内部结点为特征和特征值,叶子结点为某一数值,表示在特定路径下的预测目标值。

剪枝

思路:

1.回归树到达左右叶节点

2.该树的特征和特征值二元划分测试数据

3.回归树的左右叶节点分别测试树的左右目标值作差值的平方,然后求和

4.比较二元分裂前后的平方差大小。即误差大小

5.若合并后误差减少,则返回合并的值。否则,返回当前回归树

详说:

到达左右叶节点,说明这里是一个剪枝的机会点,至于是否剪枝,需要根据测试数据来判断。

这个时候回归树仍然有特征index和特征值,但没有叶节点。使用回归树的特征和特征值切分测

是数据,用左叶节点的值和切分后的左子集的目标值依次作差值并平方,右节点也右子集作同样

操作,比较大小,若切分前的集合目标值与左右叶节点平均值之间的差值并平方所得到误差要小,

说明归回树的合并是成功的,就可以返回该合并值。否则的话,返回该回归树,为什么呢?还没

有想明白。

现附上《机器学习实践》中的源代码:

    def prune(self, tree, testData):
        if shape(testData)[0] == 0: return self.getMean(tree) #if we have no test data collapse the tree
        if (self.isTree(tree['right']) or self.isTree(tree['left'])):#if the branches are not trees try to prune them
            lSet, rSet = self.binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        if self.isTree(tree['left']): tree['left'] = self.prune(tree['left'], lSet)
        if self.isTree(tree['right']): tree['right'] =  self.prune(tree['right'], rSet)
        #if they are now both leafs, see if we can merge them
        if not self.isTree(tree['left']) and not self.isTree(tree['right']):
            lSet, rSet = self.binSplitDataSet(testData, tree['spInd'], tree['spVal'])
            errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\
                sum(power(rSet[:,-1] - tree['right'],2))
            treeMean = (tree['left']+tree['right'])/2.0
            errorMerge = sum(power(testData[:,-1] - treeMean,2))
            if errorMerge < errorNoMerge: 
                print ("merging")
                return treeMean
            else: return tree
        else: return tree

 

 

测试数据可以使用index和特质值进行二元分裂,从而产生两个子样本集

代码:

posted @ 2020-08-28 11:33  绍荣  阅读(241)  评论(0编辑  收藏  举报