Fork me on GitHub

青禹小生

雁驰万里却作禽,鱼未得水空有鳞。 花开花落花不语,昨是昨非昨亦今。

导航

python实现简单决策树(信息增益)——基于周志华的西瓜书数据

数据集如下:

 1 色泽    根蒂    敲声    纹理    脐部    触感    好瓜
 2 青绿    蜷缩    浊响    清晰    凹陷    硬滑    是
 3 乌黑    蜷缩    沉闷    清晰    凹陷    硬滑    是
 4 乌黑    蜷缩    浊响    清晰    凹陷    硬滑    是
 5 青绿    蜷缩    沉闷    清晰    凹陷    硬滑    是
 6 浅白    蜷缩    浊响    清晰    凹陷    硬滑    是
 7 青绿    稍蜷    浊响    清晰    稍凹    软粘    是
 8 乌黑    稍蜷    浊响    稍糊    稍凹    软粘    是
 9 乌黑    稍蜷    浊响    清晰    稍凹    硬滑    是
10 乌黑    稍蜷    沉闷    稍糊    稍凹    硬滑    否
11 青绿    硬挺    清脆    清晰    平坦    软粘    否
12 浅白    硬挺    清脆    模糊    平坦    硬滑    否
13 浅白    蜷缩    浊响    模糊    平坦    软粘    否
14 青绿    稍蜷    浊响    稍糊    凹陷    硬滑    否
15 浅白    稍蜷    沉闷    稍糊    凹陷    硬滑    否
16 乌黑    稍蜷    浊响    清晰    稍凹    软粘    否
17 浅白    蜷缩    浊响    模糊    平坦    硬滑    否
18 青绿    蜷缩    沉闷    稍糊    稍凹    硬滑    否

基于信息增益的ID3决策树的原理这里不再赘述,读者如果不明白可参考西瓜书对这部分内容的讲解。

python实现代码如下:

  1 from math import log2
  2 import pandas as pd
  3 import matplotlib.pyplot as plt
  4 from matplotlib.font_manager import FontProperties
  5 
  6 # 统计label出现次数
  7 def get_counts(data):
  8     total = len(data)
  9     results = {}
 10     for d in data:
 11         results[d[-1]] = results.get(d[-1], 0) + 1
 12     return results, total
 13 
 14 # 计算信息熵
 15 def calcu_entropy(data):
 16     results, total = get_counts(data)
 17     ent = sum([-1.0*v/total*log2(v/total) for v in results.values()])
 18     return ent
 19 
 20 # 计算每个feature的信息增益
 21 def calcu_each_gain(column, update_data):
 22     total = len(column)
 23     grouped = update_data.iloc[:, -1].groupby(by=column)
 24     temp = sum([len(g[1])/total*calcu_entropy(g[1]) for g in list(grouped)])
 25     return calcu_entropy(update_data.iloc[:, -1]) - temp
 26 
 27 # 获取最大的信息增益的feature
 28 def get_max_gain(temp_data):
 29     columns_entropy = [(col, calcu_each_gain(temp_data[col], temp_data)) for col in temp_data.iloc[:, :-1]]
 30     columns_entropy = sorted(columns_entropy, key=lambda f: f[1], reverse=True)
 31     return columns_entropy[0]
 32 
 33 # 去掉数据中已存在的列属性内容
 34 def drop_exist_feature(data, best_feature):
 35     attr = pd.unique(data[best_feature])
 36     new_data = [(nd, data[data[best_feature] == nd]) for nd in attr]
 37     new_data = [(n[0], n[1].drop([best_feature], axis=1)) for n in new_data]
 38     return new_data
 39 
 40 # 获得出现最多的label
 41 def get_most_label(label_list):
 42     label_dict = {}
 43     for l in label_list:
 44         label_dict[l] = label_dict.get(l, 0) + 1
 45     sorted_label = sorted(label_dict.items(), key=lambda ll: ll[1], reverse=True)
 46     return sorted_label[0][0]
 47 
 48 # 创建决策树
 49 def create_tree(data_set, column_count):
 50     label_list = data_set.iloc[:, -1]
 51     if len(pd.unique(label_list)) == 1:
 52         return label_list.values[0]
 53     if all([len(pd.unique(data_set[i])) ==1 for i in data_set.iloc[:, :-1].columns]):
 54         return get_most_label(label_list)
 55     best_attr = get_max_gain(data_set)[0]
 56     tree = {best_attr: {}}
 57     exist_attr = pd.unique(data_set[best_attr])
 58     if len(exist_attr) != len(column_count[best_attr]):
 59         no_exist_attr = set(column_count[best_attr]) - set(exist_attr)
 60         for nea in no_exist_attr:
 61             tree[best_attr][nea] = get_most_label(label_list)
 62     for item in drop_exist_feature(data_set, best_attr):
 63         tree[best_attr][item[0]] = create_tree(item[1], column_count)
 64     return tree
 65 
 66 # 决策树绘制基本参考《机器学习实战》书内的代码以及博客:http://blog.csdn.net/c406495762/article/details/76262487
 67 # 获取树的叶子节点数目
 68 def get_num_leafs(decision_tree):
 69     num_leafs = 0
 70     first_str = next(iter(decision_tree))
 71     second_dict = decision_tree[first_str]
 72     for k in second_dict.keys():
 73         if isinstance(second_dict[k], dict):
 74             num_leafs += get_num_leafs(second_dict[k])
 75         else:
 76             num_leafs += 1
 77     return num_leafs
 78 
 79 # 获取树的深度
 80 def get_tree_depth(decision_tree):
 81     max_depth = 0
 82     first_str = next(iter(decision_tree))
 83     second_dict = decision_tree[first_str]
 84     for k in second_dict.keys():
 85         if isinstance(second_dict[k], dict):
 86             this_depth = 1 + get_tree_depth(second_dict[k])
 87         else:
 88             this_depth = 1
 89         if this_depth > max_depth:
 90             max_depth = this_depth
 91     return max_depth
 92 
 93 # 绘制节点
 94 def plot_node(node_txt, center_pt, parent_pt, node_type):
 95     arrow_args = dict(arrowstyle='<-')
 96     font = FontProperties(fname=r'C:\Windows\Fonts\STXINGKA.TTF', size=15)
 97     create_plot.ax1.annotate(node_txt, xy=parent_pt,  xycoords='axes fraction', xytext=center_pt,
 98                             textcoords='axes fraction', va="center", ha="center", bbox=node_type,
 99                             arrowprops=arrow_args, FontProperties=font)
100 
101 # 标注划分属性
102 def plot_mid_text(cntr_pt, parent_pt, txt_str):
103     font = FontProperties(fname=r'C:\Windows\Fonts\MSYH.TTC', size=10)
104     x_mid = (parent_pt[0] - cntr_pt[0]) / 2.0 + cntr_pt[0]
105     y_mid = (parent_pt[1] - cntr_pt[1]) / 2.0 + cntr_pt[1]
106     create_plot.ax1.text(x_mid, y_mid, txt_str, va="center", ha="center", color='red', FontProperties=font)
107 
108 # 绘制决策树
109 def plot_tree(decision_tree, parent_pt, node_txt):
110     d_node = dict(boxstyle="sawtooth", fc="0.8")
111     leaf_node = dict(boxstyle="round4", fc='0.8')
112     num_leafs = get_num_leafs(decision_tree)
113     first_str = next(iter(decision_tree))
114     cntr_pt = (plot_tree.xoff + (1.0 +float(num_leafs))/2.0/plot_tree.totalW, plot_tree.yoff)
115     plot_mid_text(cntr_pt, parent_pt, node_txt)
116     plot_node(first_str, cntr_pt, parent_pt, d_node)
117     second_dict = decision_tree[first_str]
118     plot_tree.yoff = plot_tree.yoff - 1.0/plot_tree.totalD
119     for k in second_dict.keys():
120         if isinstance(second_dict[k], dict):
121             plot_tree(second_dict[k], cntr_pt, k)
122         else:
123             plot_tree.xoff = plot_tree.xoff + 1.0/plot_tree.totalW
124             plot_node(second_dict[k], (plot_tree.xoff, plot_tree.yoff), cntr_pt, leaf_node)
125             plot_mid_text((plot_tree.xoff, plot_tree.yoff), cntr_pt, k)
126     plot_tree.yoff = plot_tree.yoff + 1.0/plot_tree.totalD
127 
128 def create_plot(dtree):
129     fig = plt.figure(1, facecolor='white')
130     fig.clf()
131     axprops = dict(xticks=[], yticks=[])
132     create_plot.ax1 = plt.subplot(111, frameon=False, **axprops)
133     plot_tree.totalW = float(get_num_leafs(dtree))
134     plot_tree.totalD = float(get_tree_depth(dtree))
135     plot_tree.xoff = -0.5/plot_tree.totalW
136     plot_tree.yoff = 1.0
137     plot_tree(dtree, (0.5, 1.0), '')
138     plt.show()
139 
140 if __name__ == '__main__':
141     my_data = pd.read_csv('./watermelon2.0.csv', encoding='gbk')
142     column_count = dict([(ds, list(pd.unique(my_data[ds]))) for ds in my_data.iloc[:, :-1].columns])
143     d_tree = create_tree(my_data, column_count)
144     create_plot(d_tree)

绘制的决策树如下:

 

posted on 2017-10-16 11:28  司徒道  阅读(11759)  评论(2编辑  收藏  举报