决策树算法笔记及ID3算法的Python实现
决策树是多叉树,决策树是彼此互斥且完备的一系列的if-then规则。
决策树还可以看成给定条件下类的条件概率分布,每条路径对应于划分中的一个单元。
决策树的损失函数通常是正则化的极大似然函数。
决策树的核心算法就是对可能的决策树空间进行自上而下的贪心搜索。
特征选择,怎么选特征:
熵的解释:信息论中对熵的解释就是熵是对S任一成员的分类信息进行编码所需的最少的bit数量。如果\(p\)是1,那么接收器知道是正类,不需要额外信息,因此熵是0;如果\(p\)是0.5,那么需要用1个bit来编码是正类还是负类;如果\(p\)是0.8,那么一串s的平均所需编码是少于1个bit的。
因为熵是对信息编码所需的bit的长度的期望值,因此对数的底是2。
信息增益就是知道属性A后,熵减小的值;也可以理解为知道属性A的值后,对信息编码所需的bit的长度的减小的值。
熵\(H(Y)\)与熵\(H(Y|X)\)的之差又称为互信息。即决策树中的信息增益等价于训练数据集中类与特征的互信息。
为什么叫互信息?
根据熵的定义
同理,\(H(X,Y) = H(Y) + H(X|Y)\),因此
\(I(X;Y)\)即是\(X\)和\(Y\)的互信息,即知道\(Y\)之后\(X\)的熵减和知道\(X\)之后\(Y\)的熵减是相同的,彼此相互提供的信息量是相同的。
互信息是对称的,非负的。用于表示信息之间的关系, 是两个随机变量统计相关性的测度。
信息增益的问题:
信息增益倾向于选择取值数量多的特征。举个例子,如果把id作为一个特征加入到训练中,那么根据信息增益公司,id这个特征的信息增益最大,因为每个id的熵都是0。然而,id这个特征对于预测基本是毫无用处的。改进方案是信息增益比。
信息增益比:
\(SplitInformation(S, A)\)是训练集关于特征A的熵,作为对id这种特征的惩罚项。
然而,信息增益比也有问题,那就是当\(|S_i| \approx |S|\)时,分母接近于0,使得信息增益比过大。
ID3算法
图为(Machine Learning, Tom Mitchell, McGraw Hill, 1997.)中的图
CART树
CART树是二叉树。
归纳偏置(inductive bias):
维基百科定义:当学习器去预测其未遇到过的输入的结果时,会做一些假设(Mitchell, 1980)。而学习算法中的归纳偏置则是这些假设的集合。
ID3决策树的归纳偏置:给定一些样本,通常有很多决策树符合这些样本。那么ID3的归纳偏置如何选择其中的决策树呢?ID3选择遇到的第一个可接受的树。ID3倾向于选择短树,倾向于选择信息增益最大的属性最靠近根节点的树。
决策树优缺点总结:
- 优点:
- 可解释性,可视化
- 需要的数据预处理较少。即,不需要归一化、dummy变量、缺失值处理等
- 预测的开销是训练的对数倍
- 可以处理数值数据和类型数据
- 缺点:
- 可能产生过拟合。需要剪枝,或者调节树的深度、叶子结点的最小样本个数等超参数
- 生成的决策树不稳定,数据的微小变动可能生成完全不同的树,可通过集成(ensemble)来缓解。
- 由于在模型空间中采用的是贪心搜索,因此可能最终的树是局部最优,同样可通过集成缓解(随机森林)。
- 如果类别很不平衡,会生成偏置的树。因此建议对数据进行类别平衡。
ID3算法的python简单实现
仅实现了ID3算法及剪枝,仅支持离散特征值
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from collections import Counter
import math
class Node:
def __init__(self, attr=None, cols=None):
self.attr = None
self.children = {}
self.label = None
self.data_len = None
self.entro = None
self.cols = cols
def __repr__(self):
if not self.children:
return self.label
else:
if not self.cols:
attr = self.attr
else:
attr = self.cols[self.attr]
return '{}: {}'.format(attr, self.children)
class DecisionTree:
def __init__(self, alpha=0.1, cols=None):
self.root = None
self.alpha = alpha
self.cols = cols
def fit(self, X, y):
attrs = list(range(len(X[0])))
self.root = self._fit(X, y, attrs)
return self
def _fit(self, X, y, attrs):
counter = Counter(y)
root = Node(cols=self.cols)
root.data_len = len(X)
root.entro = self._entropy(y)
label = max(counter, key=lambda x: counter[x])
root.label = label
if len(counter) == 1 or len(attrs) == 0:
return root
else:
attr = self.get_best_split(X, y, attrs)
val_dict = {}
# split train data according to attr
for x, y_ in zip(X, y):
val = x[attr]
if val not in val_dict: val_dict[val] = [[], []]
val_dict[val][0].append(x)
val_dict[val][1].append(y_)
attrs.remove(attr)
for k, (x, y_) in val_dict.items():
root.children[k] = self._fit(x, y_, attrs)
root.attr = attr
return root
def predict(self, X):
root = self.root
while root.attr:
attr = X[root.attr]
root = root.children[attr]
return root.label
def _gini(self, x):
pass
def _entropy(self, x):
counter = Counter(x)
length = len(x)
entro_list = [-(count/length) * math.log((count/length), 2) for count in counter.values()]
return sum(entro_list)
def _information(self, x):
pass
def get_best_split(self, X, y, attrs):
length = len(y)
ent_min = float('inf')
best_attr = attrs[0]
for attr in attrs:
ent_tmp = 0
y_dict = {}
for x, y_ in zip(X, y):
if x[attr] not in y_dict:
y_dict[x[attr]] = [y_]
else:
y_dict[x[attr]].append(y_)
for y_split in y_dict.values():
ent_tmp += len(y_split) / length * self._entropy(y_split)
if ent_tmp < ent_min:
ent_min = ent_tmp
best_attr = attr
return best_attr
def prune(self):
root = self.root
self.recur_entro(root)
def recur_entro(self, root):
if not root:
return None
# root cost
cost_root = root.entro * root.data_len + self.alpha
# children cost
cost, entro, leaf_num = 0, 0, 0
if root.children:
for child in root.children.values():
# cost(t) = sum(node_num_t * entropy_t for t in root.children) + alhpa * leaf_num
entro_child, leaf_child = self.recur_entro(child)
entro += entro_child
leaf_num += leaf_child
cost = entro + self.alpha * leaf_num
if cost < cost_root:
return entro, leaf_num
else:
root.children = None
root.attr = None
return root.entro, 1
else:
return root.entro * root.data_len, 1
def __repr__(self):
return '{}'.format(self.root)
X = [['青年', '否', '否', '一般', ],
['青年', '否', '否', '好', ],
['青年', '是', '否', '好', ],
['青年', '是', '是', '一般', ],
['青年', '否', '否', '一般', ],
['中年', '否', '否', '一般', ],
['中年', '否', '否', '好', ],
['中年', '是', '是', '好', ],
['中年', '否', '是', '非常好', ],
['中年', '否', '是', '非常好', ],
['老年', '否', '是', '非常好', ],
['老年', '否', '是', '好', ],
['老年', '是', '否', '好', ],
['老年', '是', '否', '非常好', ],
['老年', '否', '否', '一般', ],
]
y = ['否', '否', '是', '是', '否', '否', '否', '是', '是', '是', '是', '是', '是', '是', '否']
cols = [u'年龄', u'有工作', u'有自己的房子', u'信贷情况']
clf = DecisionTree(alpha=10, cols=cols)
clf = clf.fit(X, y)
print(clf)
test = ['青年', '是', '否', '一般']
print(clf.predict(test))
clf.prune()
print(clf)
print(clf.predict(test))
运行结果:
剪枝后,决策树发生了变化,对测试样本的预测结果也发生了变化。
参考:
李航. (2012). 统计学习方法. 清华大学出版社. 北京
sklearn关于tree的文档
Machine Learning, Tom Mitchell, McGraw Hill, 1997.
互信息——百度百科