实验一:决策树
【附录1】
年龄 | 有工作 | 有自己的房子 | 信贷情况 | 类别 | |
0 | 青年 | 否 | 否 | 一般 | 否 |
1 | 青年 | 否 | 否 | 好 | 否 |
2 | 青年 | 是 | 否 | 好 | 是 |
3 | 青年 | 是 | 是 | 一般 | 是 |
4 | 青年 | 否 | 否 | 一般 | 否 |
5 | 中年 | 否 | 否 | 一般 | 否 |
6 | 中年 | 否 | 否 | 好 | 否 |
7 | 中年 | 是 | 是 | 好 | 是 |
8 | 中年 | 否 | 是 | 非常好 | 是 |
9 | 中年 | 否 | 是 | 非常好 | 是 |
10 | 老年 | 否 | 是 | 非常好 | 是 |
11 | 老年 | 否 | 是 | 好 | 是 |
12 | 老年 | 是 | 否 | 好 | 是 |
13 | 老年 | 是 | 否 | 非常好 | 是 |
14 | 老年 | 否 | 否 | 一般 |
否 |
【实验目的】
- 理解决策树算法原理,掌握决策树算法框架;
- 理解决策树学习算法的特征选择、树的生成和树的剪枝;
- 能根据不同的数据类型,选择不同的决策树算法;
- 针对特定应用场景及数据,能应用决策树算法解决实际问题。
【实验内容】
1、设计算法实现熵、经验条件熵、信息增益等方法
a.导入数据
①导入使用的安装包
1 2 3 4 5 6 7 8 9 10 | import numpy as np import pandas as pd import matplotlib.pyplot as plt % matplotlib inline from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from collections import Counter import math from math import log import pprint |
②导入所给数据
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | def create_data(): datasets = [[ '青年' , '否' , '否' , '一般' , '否' ], [ '青年' , '否' , '否' , '好' , '否' ], [ '青年' , '是' , '否' , '好' , '是' ], [ '青年' , '是' , '是' , '一般' , '是' ], [ '青年' , '否' , '否' , '一般' , '否' ], [ '中年' , '否' , '否' , '一般' , '否' ], [ '中年' , '否' , '否' , '好' , '否' ], [ '中年' , '是' , '是' , '好' , '是' ], [ '中年' , '否' , '是' , '非常好' , '是' ], [ '中年' , '否' , '是' , '非常好' , '是' ], [ '老年' , '否' , '是' , '非常好' , '是' ], [ '老年' , '否' , '是' , '好' , '是' ], [ '老年' , '是' , '否' , '好' , '是' ], [ '老年' , '是' , '否' , '非常好' , '是' ], [ '老年' , '否' , '否' , '一般' , '否' ], ] labels = [u '年龄' , u '有工作' , u '有自己的房子' , u '信贷情况' , u '类别' ] return datasets, labels |
③输出结果
1 2 3 | datasets, labels = create_data() train_data = pd.DataFrame(datasets, columns = labels) train_data |
b.熵
1 2 3 4 5 6 7 8 9 10 | def calc_ent(datasets): data_length = len (datasets) label_count = {} for i in range (data_length): label = datasets[i][ - 1 ] if label not in label_count: label_count[label] = 0 label_count[label] + = 1 ent = - sum ([(p / data_length) * log(p / data_length, 2 ) for p in label_count.values()]) return ent |
c.经验条件熵
1 2 3 4 5 6 7 8 9 10 | def cond_ent(datasets, axis = 0 ): data_length = len (datasets) feature_sets = {} for i in range (data_length): feature = datasets[i][axis] if feature not in feature_sets: feature_sets[feature] = [] feature_sets[feature].append(datasets[i]) cond_ent = sum ([( len (p) / data_length) * calc_ent(p) for p in feature_sets.values()]) return cond_ent |
d.信息增益
1 2 3 4 5 6 7 8 9 10 11 12 13 | def tz(ent, cond_ent): return ent - cond_ent def tz_train(datasets): count = len (datasets[ 0 ]) - 1 ent = calc_ent(datasets) best_feature = [] for c in range (count): c_tz = tz(ent, cond_ent(datasets, axis = c)) best_feature.append((c, c_tz)) print ( '{}{:.3f}' . format (labels[c], c_tz)) best_ = max (best_feature, key = lambda x: x[ - 1 ]) return '{}为根节点特征' . format (labels[best_[ 0 ]]) |
1 | tz_train(np.array(datasets)) |
2、针对给定的房贷数据集(数据集表格见附录1)实现ID3算法
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 | # 定义节点类 二叉树 class Node: def __init__( self , root = True , label = None , feature_name = None , feature = None ): self .root = root self .label = label self .feature_name = feature_name self .feature = feature self .tree = {} self .result = { 'label:' : self .label, 'feature' : self .feature, 'tree' : self .tree} def __repr__( self ): return '{}' . format ( self .result) def add_node( self , val, node): self .tree[val] = node def predict( self , features): if self .root is True : return self .label return self .tree[features[ self .feature]].predict(features) class DTree: def __init__( self , epsilon = 0.1 ): self .epsilon = epsilon self ._tree = {} # 熵 @staticmethod def calc_ent(datasets): data_length = len (datasets) label_count = {} for i in range (data_length): label = datasets[i][ - 1 ] if label not in label_count: label_count[label] = 0 label_count[label] + = 1 ent = - sum ([(p / data_length) * log(p / data_length, 2 ) for p in label_count.values()]) return ent # 经验条件熵 def cond_ent( self , datasets, axis = 0 ): data_length = len (datasets) feature_sets = {} for i in range (data_length): feature = datasets[i][axis] if feature not in feature_sets: feature_sets[feature] = [] feature_sets[feature].append(datasets[i]) cond_ent = sum ([( len (p) / data_length) * self .calc_ent(p) for p in feature_sets.values()]) return cond_ent # 信息增益 @staticmethod def info_gain(ent, cond_ent): return ent - cond_ent def info_gain_train( self , datasets): count = len (datasets[ 0 ]) - 1 ent = self .calc_ent(datasets) best_feature = [] for c in range (count): c_info_gain = self .info_gain(ent, self .cond_ent(datasets, axis = c)) best_feature.append((c, c_info_gain)) # 比较大小 best_ = max (best_feature, key = lambda x: x[ - 1 ]) return best_ def train( self , train_data): """ input:数据集D(DataFrame格式),特征集A,阈值eta output:决策树T """ _, y_train, features = train_data.iloc[:, : - 1 ], train_data.iloc[:, - 1 ], train_data.columns[: - 1 ] # 1,若D中实例属于同一类Ck,则T为单节点树,并将类Ck作为结点的类标记,返回T if len (y_train.value_counts()) = = 1 : return Node(root = True , label = y_train.iloc[ 0 ]) # 2, 若A为空,则T为单节点树,将D中实例树最大的类Ck作为该节点的类标记,返回T if len (features) = = 0 : return Node(root = True , label = y_train.value_counts().sort_values(ascending = False ).index[ 0 ]) # 3,计算最大信息增益 同5.1,Ag为信息增益最大的特征 max_feature, max_info_gain = self .info_gain_train(np.array(train_data)) max_feature_name = features[max_feature] # 4,Ag的信息增益小于阈值eta,则置T为单节点树,并将D中是实例数最大的类Ck作为该节点的类标记,返回T if max_info_gain < self .epsilon: return Node(root = True , label = y_train.value_counts().sort_values(ascending = False ).index[ 0 ]) # 5,构建Ag子集 node_tree = Node(root = False , feature_name = max_feature_name, feature = max_feature) feature_list = train_data[max_feature_name].value_counts().index for f in feature_list: sub_train_df = train_data.loc[train_data[max_feature_name] = = f].drop([max_feature_name], axis = 1 ) # 6, 递归生成树 sub_tree = self .train(sub_train_df) node_tree.add_node(f, sub_tree) # pprint.pprint(node_tree.tree) return node_tree def fit( self , train_data): self ._tree = self .train(train_data) return self ._tree def predict( self , X_test): return self ._tree.predict(X_test) datasets, labels = create_data() data_df = pd.DataFrame(datasets, columns = labels) dt = DTree() tree = dt.fit(data_df) tree |
输出结果:
1 | dt.predict([ '中年' , '是' , '否' , '好' ]) |
3、针对iris数据集,应用sklearn的决策树算法进行类别预测
def create_data(): iris = load_iris() df = pd.DataFrame(iris.data, columns=iris.feature_names) df['label'] = iris.target df.columns = [ 'sepal length', 'sepal width', 'petal length', 'petal width', 'label' ] data = np.array(df.iloc[:100, [0, 1, -1]]) return data[:, :2], data[:, -1] X, y = create_data() X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3) from sklearn.tree import DecisionTreeClassifier from sklearn.tree import export_graphviz import graphviz clf = DecisionTreeClassifier() clf.fit(X_train, y_train,
clf.score(X_test, y_test)
1 2 3 4 | tree_pic = export_graphviz(clf, out_file = "mytree.pdf" ) with open ( 'mytree.pdf' ) as f: dot_graph = f.read() graphviz.Source(dot_graph) |
【实验报告总结】
1、查阅文献,讨论ID3、C4.5算法的应用场景
a.ID3
ID3算法在机器学习、知识发现和数据挖掘等领域有巨大作用。它的基础理论清晰,算法比较简单,学习能力较强,适于处理大规模的学习问题,是数据挖掘和知识发现领域中的一个很好的范例,为后来各学者提出优化算法奠定了理论基础。
b.C4.5
C4.5算法在机器学习、知识发现、金融分析、遥感影像分类、生产制造、分子生物学和数据挖掘等领域有广泛应用。它具有条理清晰,能处理连续型属性,防止过拟合,准确率较高和适用范围广等优点,是一个很有实用价值的决策树算法,可以用来分类,也可以用来回归。
2、分析决策树剪枝策略
a.预剪枝
预剪枝的方法主要包括以下几种:
1、树的高度限制:设定树的高度最大值,当达到限定值时,停止树的生长;
2、训练样本限制:对一个拥有较少训练样本的节点进行分裂时容易出现过拟合现象,因此设定样本量阀值,当样本量少于阀值时停止生长;
3、系统性能增益:当属性的信息增益小于某个指定的阀值时停止增长;
4、纯度限制:当该节点中某个类别的占比超过指定阀值时,停止生长。
b.后剪枝
在实际的应用中,后剪枝比预剪枝具有更广的应用。后剪枝算法之间的差异包括以下几个方面:自上而下还是自下而上、是否需要独立的剪枝数据集、节点的误差估计方法、剪枝条件。主要的剪枝方法有降低错误剪枝、悲观错误剪枝、基于错误剪枝、代价-复杂度剪枝、最小错误剪枝。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了