机器学习Sklearn系列:(三)决策树
决策树
熵的定义
如果一个随机变量X的可能取值为X={x1,x2,..,xk},其概率分布为P(X=x)=pi(i=1,2,...,n),则随机变量X的熵定义为\(H(x) = -\sum{p(x)logp(x)}=\sum{p(x)log{\frac{1}{p(x)}}}\)。需要注意的是,熵越大,随机变量的不确定性就越大。
当n = 2的时候,\(H(p)=-plogp-(1-p)log(1-p)\)也就是交叉熵的损失函数。
条件熵
条件熵主要是用来计算,在莫一列数据X选中的条件下,其标签Y的熵大小,这样可以帮助计算,那一列数据对应的标签更加简洁易分。 条件熵计算公式如下,其中,\(p_i =P(X=x_i)\)
具体来说,条件熵公式如何使用到结构化的数据中来的,这里的X表示的某一列的特征,Xi表示该特征的一个子类特征,这里\(H(Y_i)\)表示Xi这一类子特征对应的标签Y的熵。K表示标签的类别,下面公式中,\(Y_{ik}\),表示第Xi类特征对应的标签\(Y_i\)的种类。
举个具体的例子:
特征X | 标签Y |
---|---|
1 | 1 |
1 | 0 |
1 | 1 |
1 | 1 |
2 | 1 |
2 | 1 |
2 | 0 |
这对特征Xi = 1的条件熵的计算如下:
信息增益
信息增益的计算方式如下,其中,由于H(D)是个固定值,H(D|A)越小,信息增益就越大,这样这个特征就越简洁,也就是说这个特征能够最大化的去区分label , 这里 X代表的是莫一列特征,Y代表的是数据集的标签。
决策树算法
ID3
ID3 算法和原理就是,使用信息增益来挑选特征,优先挑选信息增益最大的特征。其具体决策树生成过程如下:
1. 首先计算所有特征的信息增益,挑选一个最大的特征,作为节点的特征
2. 对挑选出来的子节点递归调用方法 1
3. 当特征信息增益小于阈值,或者没有特征可以选择,或者可选特征小于阈值等,停止。
C4.5算法
上述算法有一个问题,假设特征X有两列特征,其信息增益差不多,但是某一列数据特别混乱,这个时候应该避免选择这一列作为根结点,而C4.5算法的核心就是通过给信息增益下面,除一个这一列特征的熵,从而减少这一列数据的信息增益。也就是说,如果某一列特征越混乱,那么其最终得到的信息增益就越小,从而避免了上述的问题。 具体公式为:
其中,n为
CART算法
CART算法的思路和上面两个算法是一样的,只不过这里用来评估特征混乱度的方法是用的基尼指数。其中,基尼指数越大,不确定性越大,和熵是类似的。
基尼指数的定义如下:其中,\(p_k\)为样本点属于第k类的概率。
如果将基尼指数用到结构化数据集中:
在特征为X标签为Y的条件下,其基尼指数为:其中,\(Y_1,Y_2\)表示,特征X下的子类别\(X_1,X_2\)对应的标签。
决策树剪肢
决策树减肢可以减轻决策树的复杂度,同时确保决策树能够保持一定的正确率,剪肢的方法,一般是从最深的一层开始,减去节点,然后看accuracy,如果accuracy提升了,就可以减去。也可以使用其他基于阈值的方法,例如下一层的不纯度低于某个阈值,就可以直接不分裂等等。
sklearn中决策树的使用
参数
class sklearn.tree.DecisionTreeClassifier(*, criterion='gini',
splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1,
min_weight_fraction_leaf=0.0, max_features=None, random_state=None,
max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None,
class_weight=None, ccp_alpha=0.0)
criterion{“gini”, “entropy”}, default=”gini” 确定决策树是基于基尼指数还是熵
max_depth 树的层数,限制树的最大深度,超过设定深度的树枝全部剪掉,这是用得最广泛的剪枝参数,在高维度低样本量时非常有效。决策树多生长一层,对样本量的需求会增加一倍,所以限制树深度能够有效地限制过拟合。在集成算法中也非常实用。实际使用时,建议从=3开始尝试,看看拟合的效果再决定是否增加设定深度。实际层数为 max_depth +1 考虑根
min_impurity_decrease 限制决策树的生长,如果节点的不纯度(GINI,GAIN)小于这个阈值,就不在生成子节点
min_impurity_split :不纯度必须大于这个值,不然不分裂
min_impurity_decrease限制信息增益的大小,信息增益小于设定数值的分枝不会发生。这是在0.19版本种更新的功能,在0.19版本之前时使用min_impurity_split。
random_state 随机数种子,固定种子之后,训练的模型是一样的
class_weight 可以用来定义某一个类别的权重,让这一个类比在计算的时候,信息增益变得稍微大一些
splitter也是用来控制决策树中的随机选项的,有两种输入值,输入”best",决策树在分枝时虽然随机,但是还是会优先选择更重要的特征进行分枝(重要性可以通过属性feature_importances_查看),输入“random",决策树在分枝时会更加随机,树会因为含有更多的不必要信息而更深更大,并因这些不必要信息而降低对训练集的拟合。这也是防止过拟合的一种方式。当你预测到你的模型会过拟合,用这两个参数来帮助你降低树建成之后过拟合的可能性。当然,树一旦建成,我们依然是使用剪枝参数来防止过拟合。所以要想泛化好,最好splitter设置成random。
和剪肢相关的参数:
min_samples_leaf ** ** 限定,一个节点在分枝后的每个子节点都必须包含至少min_samples_leaf个训练样本,否则分枝就不会发生,或者,分枝会朝着满足每个子节点都包含min_samples_leaf个样本的方向去发生
一般搭配max_depth使用,在回归树中有神奇的效果,可以让模型变得更加平滑。这个参数的数量设置得太小会引起过拟合,设置得太大就会阻止模型学习数据。一般来说,建议从=5开始使用。如果叶节点中含有的样本量变化很大,建议输入浮点数作为样本量的百分比来使用。同时,这个参数可以保证每个叶子的最小尺寸,可以在回归问题中避免低方差,过拟合的叶子节点出现。对于类别不多的分类问题,=1通常就是最佳选择。
min_samples_split限定,一个节点必须要包含至少min_samples_split个训练样本,这个节点才允许被分枝,否则分枝就不会发生。
max_features限制分枝时考虑的特征个数,比如一个样本特征为13个,限制之后只能使用有限个特征进行分类任务。超过限制个数的特征都会被舍弃。和max_depth异曲同工,max_features是用来限制高维度数据的过拟合的剪枝参数,但其方法比较暴力,是直接限制可以使用的特征数量而强行使决策树停下的参数,在不知道决策树中的各个特征的重要性的情况下,强行设定这个参数可能会导致模型学习不足。如果希望通过降维的方式防止过拟合,建议使用PCA,ICA或者特征选择模块中的降维算法。
class_weight 标签权重,给某一类的标签更大的权重,当样本不均衡的时候,可以考虑使用
min_weight_fraction_leaf有了权重之后,样本量就不再是单纯地记录数目,而是受输入的权重影响了,因此这时候剪枝,就需要搭配min_ weight_fraction_leaf这个基于权重的剪枝参数来使用。另请注意,基于权重的剪枝参数(例如min_weight_ fraction_leaf)将比不知道样本权重的标准(比如min_samples_leaf)更少偏向主导类。如果样本是加权的,则使用基于权重的预修剪标准来更容易优化树结构,这确保叶节点至少包含样本权重的总和的一小部分。
注意在sklearn中实现的决策树都是二叉树
使用模型训练数据:
model = tree.DecisionTreeClassifier()
model.fit(X,y)
model.predict(X_val)
sklearn中,可以输出决策树特征的重要性
clf.feature_importances_
回归树
在分类问题中决策树的每一片叶子都代表的是一个 class;在回归问题中,决策树的每一片叶子表示的是一个预测值,取值是连续的。
决策树还可以做回归任务,回归树种的参数和上面分类树的参数是一模一样的,唯一的区别是,回归树没有class_weight这个参数,因为没有类别不平衡这个说法
X = [[0, 0], [2, 2]]
y = [0.5, 2.5]
clf = tree.DecisionTreeRegressor()
clf = clf.fit(X, y)
clf.predict([[1, 1]])
array([0.5])
决策树可视化
可以使用graphviz安装包:pip install graphviz
一个例子:
这里的参数,class_names 表示类别的名称,filled表示填充颜色,rounded 表示框的形状
feature_name = ["A","B","C"]
import graphviz
dot_data = tree.export_graphviz(clf
,feature_names= feature_name
,class_names=["1","2","3"]
,filled=True
,rounded=True
,out_file=None
)
graph = graphviz.Source(dot_data)
graph