【算法】决策树算法:ID3
决策树算法是一种用于分类和回归问题的机器学习算法。它通过构建一个树状结构来表示不同的决策路径,每个节点代表一个特征属性的判断,而每个分支代表一个可能的结果。在分类问题中,决策树算法可以根据输入的特征属性将实例分到不同的类别中;在回归问题中,决策树算法可以根据输入的特征属性预测一个连续值。
决策树算法的主要优点包括易于理解和解释、能够处理非线性关系和多重输出等特点。它的主要缺点包括对噪声和缺失数据敏感、容易过拟合等。常见的决策树算法包括ID3
、C4.5
、CART
等。
决策树算法的基本步骤包括选择最佳的划分属性
、划分数据集
、递归构建子树
和剪枝
等。在实际应用中,可以根据具体的问题和数据集选择适合的决策树算法,并通过调参和集成方法等手段提升模型性能。
ID3(Iterative Dichotomiser 3)是一种经典的决策树学习算法,它使用信息增益来选择最佳的划分属性。以下是ID3算法的优缺点:
-
优点:
- 相对简单:ID3算法易于实现和理解,因为它使用信息增益来进行属性选择,而信息增益的计算也相对直观。
- 可解释性强:生成的决策树易于解释和理解,可以清晰地展示不同特征属性对结果的影响。
- 能够处理离散型数据:ID3算法适用于离散型数据的分类问题,并且可以处理多类别的分类任务。
-
缺点:
- 对连续型数据不友好:ID3算法不能直接处理连续型数据,需要将连续型数据进行离散化处理。
- 对噪声和缺失数据敏感:ID3算法对噪声和缺失数据较为敏感,可能导致过拟合。
- 容易过拟合:ID3算法会倾向于创建复杂的树结构,容易过拟合训练数据,因此需要进行剪枝等操作来避免过拟合。
总的来说,ID3算法在处理离散型数据且数据质量较好的情况下表现良好,但在处理连续型数据和噪声较多的情况下可能表现不佳。
import math
from collections import Counter
# 创建数据集
def create_dataset():
dataset = [
# 年龄, 工作, 房子,信用,标签
['青年', 0, 0, '一般', '0'],
['青年', 0, 0, '好', '0'],
['青年', 1, 0, '好', '1'],
['青年', 1, 1, '一般', '1'],
['青年', 0, 0, '一般', '0'],
['中年', 0, 0, '一般', '0'],
['中年', 0, 0, '好', '0'],
['中年', 1, 1, '好', '1'],
['中年', 0, 1, '很好', '1'],
['中年', 0, 1, '很好', '1'],
['老年', 0, 1, '很好', '1'],
['老年', 0, 1, '好', '1'],
['老年', 1, 0, '好', '1'],
['老年', 1, 0, '很好', '1'],
['老年', 0, 0, '一般', '0']
]
return dataset
# 计算熵
def cal_entropy(dataset):
label_count = {}
# 统计样本标签
for item in dataset:
# 样本标签
label = item[-1]
# 不在字典中
if label not in label_count:
label_count[label] = 0
# 计数+1
label_count[label] += 1
# 计算熵
entropy = 0.0
for label in label_count:
# 概率 = 样本数 / 样本总数
p = label_count[label] / len(dataset)
# 计算熵
if p == 0:
continue
entropy -= p * math.log(p, 2)
return entropy
# 计算条件熵
def cal_cond_entropy(dataset, feature, value):
ret_dataset = []
for item in dataset:
if item[feature] == value:
# 抽取当前特征左侧的数据
except_item = item[:feature]
# 抽取当前特征右侧的数据
except_item.extend(item[feature + 1:])
ret_dataset.append(except_item)
return ret_dataset
# 计算信息增益
def cal_info_gain(dataset):
# 样本数
num_feature = len(dataset[0]) - 1
# 计算基本熵
base_entropy = cal_entropy(dataset)
# 最优的信息增益
best_info_gain = 0.0
# 最优的信息增益的索引
best_info_gain_feature = 0
for i in range(num_feature):
feature_list = [example[i] for example in dataset]
feature_set = set(feature_list)
conditional_entropy = 0.0
for value in feature_set:
# 计算条件熵
sub_dataset = cal_cond_entropy(dataset, i, value)
p = float(len(sub_dataset)) / len(dataset)
conditional_entropy += p * cal_entropy(sub_dataset)
info_gain = base_entropy - conditional_entropy
# 选取最大的信息索引
if info_gain > best_info_gain:
best_info_gain = info_gain
best_info_gain_feature = i
return best_info_gain_feature, best_info_gain
# 多数表决法决定叶子节点分类
def majority_cnt(class_list):
class_count = Counter(class_list)
sorted_class_count = sorted(class_count.items(), key=lambda x: x[1], reverse=True)
return sorted_class_count[0][0]
# 构建决策树
def build_decision_tree(dataset, labels):
class_list = [data[-1] for data in dataset]
if class_list.count(class_list[0]) == len(class_list): # 类别完全相同则停止继续划分
return class_list[0]
if len(dataset[0]) == 1: # 遍历完所有特征时返回出现次数最多的类别
return majority_cnt(class_list)
best_feat, best_info_gain = cal_info_gain(dataset)
best_feat_label = labels[best_feat]
my_tree = {best_feat_label: {}}
new_labels = labels[:]
del(new_labels[best_feat])
feat_values = [data[best_feat] for data in dataset]
unique_vals = set(feat_values)
for value in unique_vals:
sub_labels = new_labels[:]
my_tree[best_feat_label][value] = build_decision_tree(cal_cond_entropy(dataset, best_feat, value), sub_labels)
return my_tree
# 使用决策树进行分类
def classify(input_tree, feat_labels, test_data):
first_str = list(input_tree.keys())[0]
second_dict = input_tree[first_str]
feat_index = feat_labels.index(first_str)
key = test_data[feat_index]
value_of_feat = second_dict[key]
if isinstance(value_of_feat, dict):
class_label = classify(value_of_feat, feat_labels, test_data)
else:
class_label = value_of_feat
return class_label
# ID3 算法举例
if __name__ == '__main__':
dataset = create_dataset()
labels = ['年龄', '工作', '房子', '信用']
print("熵:", cal_entropy(dataset))
best_info_gain_feature, best_info_gain = cal_info_gain(dataset)
print("信息增益:", best_info_gain_feature, best_info_gain)
tree = build_decision_tree(dataset, labels)
print("决策树:", tree)
print("测试数据:", dataset[0])
result = classify(tree, labels, ['老年', 1, 0, '一般'])
print("预测结果:", result)
运行效果: