Machine Learning - Decision Tree Series - Decision Trees: Pruning and the CART Algorithm - 27

1. Pruning

Collapsing a subtree back into a single leaf node is an effective way to combat overfitting. When a tree is allowed to grow too luxuriantly, it performs noticeably worse on the test set than on the training set; that gap is overfitting. Two pruning strategies are available.

Pre-pruning: set hyperparameters that suppress the tree's growth during training, e.g. max_depth, max_leaf_nodes.

Post-pruning: grow the full tree first, then cut it back afterwards; the CCP method described next is an example.
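As a minimal sketch of pre-pruning (the dataset choice and the specific cap values here are mine, for illustration), these hyperparameters are simply passed to the scikit-learn constructor:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)

# An unconstrained tree keeps splitting until every leaf is pure.
full = DecisionTreeClassifier(random_state=0).fit(X, y)

# Pre-pruning: cap the depth and the number of leaves before training.
pruned = DecisionTreeClassifier(max_depth=3, max_leaf_nodes=8, random_state=0).fit(X, y)

print("full tree:  ", full.get_depth(), "deep,", full.get_n_leaves(), "leaves")
print("pruned tree:", pruned.get_depth(), "deep,", pruned.get_n_leaves(), "leaves")
```

The caps are hard constraints: the pruned tree can never exceed them, no matter how the data looks.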

2. CCP: Cost-Complexity Pruning (CART)

The tree is first built to completion and only then cut back, which is why this is called post-pruning. Replacing a subtree with a single leaf node lowers accuracy (raises the error rate); that is the cost. In exchange, the tree's complexity drops.
See the scikit-learn docs: https://scikit-learn.org/stable/auto_examples/tree/plot_cost_complexity_pruning.html#sphx-glr-auto-examples-tree-plot-cost-complexity-pruning-py
As those docs put it: "Greater values of ccp_alpha increase the number of nodes pruned."
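A quick check of that claim (the alpha values below are chosen arbitrarily for illustration):

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)

# Larger ccp_alpha prunes more aggressively: the node count shrinks.
counts = {}
for alpha in [0.0, 0.005, 0.02]:
    clf = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha).fit(X, y)
    counts[alpha] = clf.tree_.node_count
print(counts)
```

With ccp_alpha=0.0 (the default) no cost-complexity pruning happens at all; the tree only gets smaller as alpha grows.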

The algorithm defines a cost and a complexity for each subtree T_t, together with a user-settable parameter α that trades one off against the other:

|N_t|: the number of leaf nodes in subtree T_t; collapsing T_t into a single leaf therefore removes |N_t| - 1 leaves.
R(t): the error cost of node t, computed as R(t) = r(t) * p(t), where r(t) is the misclassification rate at node t and p(t) is the fraction of all samples that reach node t.
R(T_t): the error cost of subtree T_t, computed as R(T_t) = Σ R(i), summed over the leaves i of T_t.

The effective alpha of node t is the value at which pruning T_t breaks even: α_eff(t) = (R(t) - R(T_t)) / (|N_t| - 1). Nodes whose effective alpha falls below the chosen ccp_alpha are pruned.

3. An example
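The original example figure is not reproduced here; the following is a hypothetical worked computation with invented numbers, assuming a node t whose subtree T_t has 3 leaves:

```python
# Hypothetical numbers, for illustration only.
# Node t: 100 of the 1000 total samples reach it, and a leaf placed
# at t would misclassify 10 of those 100.
R_t = (10 / 100) * (100 / 1000)   # R(t) = r(t) * p(t) = 0.01

# Its subtree T_t has 3 leaves that together misclassify only 4 of
# those 100 samples, so R(T_t) = 4/1000.
R_Tt = 4 / 1000                   # 0.004
n_leaves = 3                      # |N_t|

# Effective alpha: the alpha at which pruning T_t breaks even.
alpha_eff = (R_t - R_Tt) / (n_leaves - 1)
print(round(alpha_eff, 6))        # 0.003
```

If ccp_alpha is set above 0.003, this subtree gets collapsed into a single leaf; below it, the subtree survives.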

4. Determining the value of α

#!/usr/bin/env python
# coding: utf-8

import matplotlib.pyplot as plt

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)


# Fit an unpruned tree and compute its pruning path: the sequence of
# effective alphas and the corresponding total leaf impurities.
clf = DecisionTreeClassifier(random_state=0)
path = clf.cost_complexity_pruning_path(X_train, y_train)

ccp_alphas, impurities = path.ccp_alphas, path.impurities
print(ccp_alphas, impurities)

fig, ax = plt.subplots()
ax.plot(ccp_alphas[:-1], impurities[:-1], marker="o", drawstyle="steps-post")
ax.set_xlabel("effective alpha")
ax.set_ylabel("total impurity of leaves")
ax.set_title("Total Impurity vs effective alpha for training set")

# Train one tree per candidate alpha.
clfs = []
for ccp_alpha in ccp_alphas:
    clf = DecisionTreeClassifier(random_state=0, ccp_alpha=ccp_alpha)
    clf.fit(X_train, y_train)
    clfs.append(clf)
print(clfs[-1].tree_.node_count, ccp_alphas[-1])

# The largest alpha prunes the tree down to a single root node;
# drop it before comparing scores.
ccp_alphas = ccp_alphas[:-1]
clfs = clfs[:-1]
train_scores = [clf.score(X_train, y_train) for clf in clfs]
test_scores = [clf.score(X_test, y_test) for clf in clfs]


fig, ax = plt.subplots()
ax.set_xlabel("alpha")
ax.set_ylabel("accuracy")
ax.set_title("Accuracy vs alpha for training and test sets")
ax.plot(ccp_alphas, train_scores, marker="o", label="train", drawstyle="steps-post")
ax.plot(ccp_alphas, test_scores, marker="o", label="test", drawstyle="steps-post")
ax.legend()
plt.show()



Reading the plot, ccp_alpha ≈ 0.015 is the ideal value here: test accuracy peaks around it.
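Assuming that peak, the final model would then be refit with the chosen alpha (0.015 comes from reading the plot above; a different split or sklearn version may shift it slightly):

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Refit with the alpha chosen from the accuracy-vs-alpha plot.
best = DecisionTreeClassifier(random_state=0, ccp_alpha=0.015)
best.fit(X_train, y_train)
print("train:", best.score(X_train, y_train), "test:", best.score(X_test, y_test))
```

Note the train score drops below 1.0 while the test score improves relative to the unpruned tree; that narrowing gap is exactly what pruning is for.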

posted @ 2024-01-04 00:14  jack-chen666