import math
import pickle
from collections import Counter

try:
    from matplotlib import pyplot as plt
except ImportError:  # plotting is optional; building/classifying trees does not need it
    plt = None


def calc_shang(dataset: list):
    """
    Compute the Shannon entropy (in bits) of the class labels of a dataset.

    The class label of each sample is its last element.

    :param dataset: list of samples; each sample is a sequence whose last item is the class label
    :return: entropy in bits; 0.0 for an empty or single-class dataset
    """
    length = len(dataset)
    label_counts = Counter(item[-1] for item in dataset)
    shang = 0.0
    for count in label_counts.values():
        prob = count / length
        shang -= prob * math.log2(prob)
    return shang


def create_dataset():
    """
    Return a small toy dataset and its feature names.

    Each sample is ``[feature_0, feature_1, class_label]`` with 0/1 features and
    "yes"/"no" labels; ``labels`` names the two feature columns in order.
    """
    dataset = [
        [1, 1, "yes"],
        [1, 1, "yes"],
        [1, 0, "no"],
        [0, 1, "no"],
        [0, 1, "no"],
    ]
    labels = ["no surfacing", "flippers"]
    return dataset, labels


def split_dataset(dataset, axis, value):
    """
    Select the samples whose feature ``axis`` equals ``value`` and drop that feature.

    :param dataset: list of samples (lists or tuples)
    :param axis: index of the feature column to split on
    :param value: feature value to keep
    :return: new list of samples (as lists); the input samples are not modified
    """
    return [
        list(item[:axis]) + list(item[axis + 1:])
        for item in dataset
        if item[axis] == value
    ]


def choose_best_feature(dataset):
    """
    Pick the feature with the highest information gain (ID3 criterion).

    :param dataset: non-empty list of samples; the last element of each sample is the class label
    :return: index of the best feature, or -1 when no feature gives a positive
        information gain (including when no feature columns remain)
    """
    num_features = len(dataset[0]) - 1
    base_entropy = calc_shang(dataset)
    total = len(dataset)
    best_info_gain = 0.0
    best_feature = -1
    for i in range(num_features):
        # Conditional entropy: entropy of each value's subset, weighted by subset size.
        cond_entropy = 0.0
        for feat in {item[i] for item in dataset}:
            sub_dataset = split_dataset(dataset, i, feat)
            cond_entropy += len(sub_dataset) / total * calc_shang(sub_dataset)
        info_gain = base_entropy - cond_entropy
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_feature = i
    return best_feature


def classify(class_list):
    """
    Return the most common label in ``class_list`` (majority vote).

    Ties are resolved in favour of the label encountered first.

    :raises IndexError: if ``class_list`` is empty
    """
    return Counter(class_list).most_common(1)[0][0]


def create_tree(dataset, labels):
    """
    Build an ID3 decision tree.

    :param dataset: list of samples; each sample holds the feature values followed by the class label
    :param labels: feature names, one per feature column of ``dataset``; NOT modified
    :return: a class label (leaf) or a nested dict ``{feature_name: {feature_value: subtree}}``
    """
    class_list = [item[-1] for item in dataset]
    # Every sample has the same class: this node is a leaf.
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]
    best_feature = choose_best_feature(dataset)
    # No features left, or none separates the classes: fall back to a majority vote.
    # (Using -1 as an index would otherwise pick the class-label column as a "feature".)
    if best_feature == -1:
        return classify(class_list)
    best_label = labels[best_feature]
    # Build the reduced label list instead of `del`-ing from the caller's list, because
    # classify_tree needs the full, unmodified label list afterwards.
    sub_labels = labels[:best_feature] + labels[best_feature + 1:]
    tree = {best_label: {}}
    for value in {item[best_feature] for item in dataset}:
        tree[best_label][value] = create_tree(
            split_dataset(dataset, best_feature, value), sub_labels
        )
    return tree


def plot_tree(tree, root_name=None):
    """
    Draw a decision tree built by create_tree using matplotlib.

    The children of a node are spread evenly beneath it, ordered by feature value,
    and each edge is annotated with the feature value it represents.

    :param tree: nested dict ``{feature_name: {feature_value: subtree}}`` or a bare class label
    :param root_name: name of the root feature; defaults to the tree's single top-level key
    :raises ImportError: if matplotlib is not installed
    """
    if plt is None:
        raise ImportError("plot_tree requires matplotlib")

    dy = 0.5  # vertical distance between tree levels

    def _depth(node):
        # Number of decision levels at and below this node (a leaf has depth 0).
        if not isinstance(node, dict):
            return 0
        branches = next(iter(node.values()))
        return 1 + max((_depth(child) for child in branches.values()), default=0)

    def _draw_node(ax, x, y, text):
        ax.text(x, y, text, ha='center', va='center',
                bbox=dict(facecolor='white', edgecolor='black'))

    def _sorted_branches(branches):
        # Order by feature value so the layout is stable (e.g. 0 left, 1 right);
        # fall back to repr() ordering when the values are not mutually comparable.
        try:
            return sorted(branches.items(), key=lambda kv: kv[0])
        except TypeError:
            return sorted(branches.items(), key=lambda kv: repr(kv[0]))

    def _plot_tree(ax, branches, parent_x, parent_y, dx):
        items = _sorted_branches(branches)
        n = len(items)
        for i, (edge_label, child) in enumerate(items):
            # Spread the children evenly across [parent_x - dx/2, parent_x + dx/2].
            child_x = parent_x if n == 1 else parent_x - dx / 2 + i * dx / (n - 1)
            child_y = parent_y - dy
            child_name = next(iter(child)) if isinstance(child, dict) else child
            # Edge, plus a label showing the feature value that leads to this child.
            ax.plot([parent_x, child_x], [parent_y, child_y], 'k-')
            ax.text((parent_x + child_x) / 2, (parent_y + child_y) / 2, str(edge_label),
                    ha='center', va='center', fontsize=8,
                    bbox=dict(facecolor='yellow', edgecolor='black'))
            _draw_node(ax, child_x, child_y, child_name)
            if isinstance(child, dict):
                # Narrow the horizontal span for the next level so siblings don't overlap.
                _plot_tree(ax, child[child_name], child_x, child_y, dx / max(n, 2))

    _, ax = plt.subplots(figsize=(10, 8))
    ax.set_xlim(-1, 1)
    # Leave room for every level of the tree rather than a fixed depth.
    ax.set_ylim(-dy * _depth(tree) - 0.5, 0.5)
    ax.axis('off')
    if isinstance(tree, dict):
        if root_name is None:
            root_name = next(iter(tree))
        _draw_node(ax, 0, 0, root_name)
        _plot_tree(ax, tree[root_name], 0, 0, 1)
    else:
        # Single-leaf tree: every training sample had the same class.
        _draw_node(ax, 0, 0, tree)
    plt.show()


def classify_tree(tree: dict, labels: list, test_vec):
    """
    Classify one sample by walking a tree built by create_tree.

    :param tree: nested dict from create_tree, or a bare class label (single-leaf tree)
    :param labels: the full list of feature names, in the column order of ``test_vec``
    :param test_vec: the sample's feature values (without a class label)
    :return: the predicted class label, or "" if the sample has a feature value
        that no branch of the tree matches
    """
    if not isinstance(tree, dict):
        return tree
    feature_name = next(iter(tree))
    value = test_vec[labels.index(feature_name)]
    for key, subtree in tree[feature_name].items():
        if value == key:
            return classify_tree(subtree, labels, test_vec)
    return ""


def store_tree(tree: dict, file_path: str):
    """Serialize ``tree`` to ``file_path`` using pickle."""
    with open(file_path, "wb") as f:
        pickle.dump(tree, f)


def grab_tree(file_path):
    """
    Load a tree previously saved with store_tree.

    Security: pickle.load can execute arbitrary code while loading, so only
    load files you created yourself, never untrusted input.
    """
    with open(file_path, "rb") as f:
        return pickle.load(f)


if __name__ == '__main__':
    mat, labels = create_dataset()
    tree = create_tree(dataset=mat, labels=labels)
    plot_tree(tree, 'no surfacing')
https://gitee.com/navysummer/machine-learning/tree/master/decision_tree
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步