机器学习之决策树

完整代码如下:

import math
import pickle
 
from matplotlib import pyplot as plt
 
 
def calc_shang(dataset: list):
    """
    Compute the base-2 Shannon entropy of the class labels in a dataset.

    The class label of every record is read from its last element.

    :param dataset: list of records; only the last element of each is used
    :return: the entropy as a float (0.0 for an empty dataset)
    """
    total = len(dataset)
    # Tally how many records carry each class label.
    tally = {}
    for record in dataset:
        tally[record[-1]] = tally.get(record[-1], 0) + 1
    entropy = 0.0
    for occurrences in tally.values():
        fraction = occurrences / total
        # Each label contributes -p * log2(p).
        entropy -= fraction * math.log(fraction, 2)
    return entropy
 
 
def create_dataset():
    """
    Build the small demo dataset together with its feature names.

    Each record holds two 0/1 feature values followed by a "yes"/"no" class
    label; the returned label list names the two feature columns in order.

    :return: tuple ``(dataset, labels)``; both are freshly created lists
    """
    labels = ["no surfacing", "flippers"]
    rows = (
        (1, 1, "yes"),
        (1, 1, "yes"),
        (1, 0, "no"),
        (0, 1, "no"),
        (0, 1, "no"),
    )
    dataset = [list(row) for row in rows]
    return dataset, labels
 
 
def split_dataset(dataset, axis, value):
    """
    Select the records whose column ``axis`` equals ``value`` and drop that column.

    :param dataset: list of list records
    :param axis: index of the column to test and remove
    :param value: value the column must equal for a record to be kept
    :return: new list of new records without column ``axis``; the input
        records are not modified
    """
    return [
        record[:axis] + record[axis + 1:]
        for record in dataset
        if record[axis] == value
    ]
 
 
def choose_best_feature(dataset):
    """
    Pick the feature column whose split yields the largest information gain.

    Every column except the last (the class label) is a candidate. The gain of
    a column is the entropy of the whole dataset minus the size-weighted
    entropy of the subsets produced by splitting on that column.

    :param dataset: non-empty list of records, class label last
    :return: index of the best column, or -1 when no column has a gain
        strictly greater than zero
    """
    feature_count = len(dataset[0]) - 1
    base_entropy = calc_shang(dataset)
    total = len(dataset)
    winner, top_gain = -1, 0
    for column in range(feature_count):
        conditional_entropy = 0
        for candidate in {record[column] for record in dataset}:
            subset = split_dataset(dataset, column, candidate)
            # Weight each subset's entropy by its share of the records.
            conditional_entropy += len(subset) / total * calc_shang(subset)
        gain = base_entropy - conditional_entropy
        # Strict comparison: on ties the earliest column wins.
        if gain > top_gain:
            winner, top_gain = column, gain
    return winner
 
 
def classify(class_list):
    """
    Return the most frequent value in ``class_list`` (majority vote).

    When several values share the highest count, the one that appears first
    in ``class_list`` is returned.

    :param class_list: non-empty list of hashable class labels
    :return: the majority label
    :raises IndexError: if ``class_list`` is empty
    """
    votes = {}
    for label in class_list:
        votes[label] = votes.get(label, 0) + 1
    # The sort is stable, so equally frequent labels keep first-seen order.
    ranking = sorted(votes, key=votes.get, reverse=True)
    return ranking[0]
 
 
def create_tree(dataset, labels):
    """
    Recursively build a decision tree as nested dicts.

    A leaf is a class label. An inner node is ``{feature_name: {value: subtree}}``
    where ``feature_name`` comes from ``labels`` and ``value`` ranges over the
    distinct values of that feature in ``dataset``.

    Recursion stops when all records share one class, when no feature columns
    are left, or when no remaining feature gives a positive information gain;
    in the last two cases the majority class is returned.

    :param dataset: non-empty list of records, class label last
    :param labels: feature names, one per feature column; not modified
    :return: a class label (leaf) or a nested dict (inner node)
    :raises IndexError: if ``dataset`` is empty
    """
    class_list = [record[-1] for record in dataset]
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]
    # Only the class column is left: nothing to split on, so vote.
    if len(dataset[0]) == 1:
        return classify(class_list)
    best_feature = choose_best_feature(dataset)
    # -1 means no split improves purity. Indexing with it would pick the last
    # column (possibly the class column itself), so fall back to a vote.
    if best_feature < 0:
        return classify(class_list)
    best_class_label = labels[best_feature]
    # Build a new list instead of deleting in place so the caller's labels
    # stay usable afterwards (e.g. for classify_tree).
    remaining_labels = labels[:best_feature] + labels[best_feature + 1:]
    branches = {}
    for value in {record[best_feature] for record in dataset}:
        subset = split_dataset(dataset, best_feature, value)
        branches[value] = create_tree(subset, remaining_labels)
    return {best_class_label: branches}
 
 
def plot_tree(tree, root_name):
    """
    Draw a nested-dict decision tree with matplotlib and show the figure.

    :param tree: dict whose ``root_name`` entry maps edge values to either a
        nested ``{feature_name: {...}}`` dict or a leaf value
    :param root_name: key of the root node in ``tree``; also used as its caption
    :return: None
    """
    def _draw(axes, branches, name, x, y, x_span, y_step):
        # Caption the root: it is the only node drawn at the origin, since
        # every descendant is placed y_step lower than its parent.
        if name and x == 0 and y == 0:
            axes.text(0, 0, name, ha='center', va='center',
                      bbox={'facecolor': 'white', 'edgecolor': 'black'})
        if not isinstance(branches, dict):
            return
        for edge_value, node in branches.items():
            # Edge value 0 goes to the left, any other value to the right.
            if edge_value == 0:
                node_x = x - x_span / 2
            else:
                node_x = x + x_span / 2
            node_y = y - y_step

            has_subtree = isinstance(node, dict)
            node_name = list(node.keys())[0] if has_subtree else node

            # Connecting line, with the edge value shown at its midpoint.
            axes.plot([x, node_x], [y, node_y], 'k-')
            axes.text((x + node_x) / 2, (y + node_y) / 2, str(edge_value),
                      ha='center', va='center', fontsize=8,
                      bbox={'facecolor': 'yellow', 'edgecolor': 'black'})

            # The child node itself.
            axes.text(node_x, node_y, node_name, ha='center', va='center',
                      bbox={'facecolor': 'white', 'edgecolor': 'black'})

            # Descend, halving the horizontal spread at each level.
            if has_subtree:
                _draw(axes, node[node_name], node_name, node_x, node_y, x_span / 2, y_step)

    _, canvas = plt.subplots(figsize=(10, 8))
    canvas.set_xlim(-1, 1)
    canvas.set_ylim(-1.5, 0.5)
    canvas.axis('off')

    _draw(canvas, tree[root_name], root_name, 0, 0, 1, 0.5)

    plt.show()
 
 
def classify_tree(tree: dict, labels: list, test_vec):
    """
    Classify a feature vector by walking a nested-dict decision tree.

    :param tree: inner node of the form ``{feature_name: {value: subtree_or_leaf}}``
    :param labels: feature names; the position of a name gives the index of
        that feature in ``test_vec``
    :param test_vec: feature values of the sample to classify
    :return: the leaf value reached, or ``""`` when the sample's value for a
        feature has no matching branch
    :raises ValueError: if the node's feature name is not in ``labels``
    """
    feature_name = list(tree.keys())[0]
    column = labels.index(feature_name)
    for branch_value, outcome in tree[feature_name].items():
        if test_vec[column] == branch_value:
            if isinstance(outcome, dict):
                # Inner node: keep descending.
                return classify_tree(outcome, labels, test_vec)
            return outcome
    # No branch matched the sample's value for this feature.
    return ""
 
 
def store_tree(tree: dict, file_path: str):
    """
    Serialize a tree to disk with pickle.

    :param tree: the object to store
    :param file_path: destination path; the file is opened in binary write
        mode, so existing content is replaced
    :return: None
    """
    with open(file_path, "wb") as sink:
        pickle.dump(tree, sink)
 
 
def grab_tree(file_path):
    """
    Load a pickled object from disk.

    :param file_path: path of a file containing pickle data
    :return: the unpickled object
    """
    # NOTE: unpickling can execute arbitrary code; only load files from a
    # trusted source.
    with open(file_path, "rb") as source:
        restored = pickle.load(source)
    return restored
 
 
if __name__ == '__main__':
    # Demo: build a tree from the toy dataset and display it.
    mat, labels = create_dataset()
    tree = create_tree(dataset=mat, labels=labels)
    if isinstance(tree, dict):
        # Take the root caption from the tree itself rather than hard-coding
        # a feature name that has to match whatever create_tree selected.
        plot_tree(tree, next(iter(tree)))
    else:
        # The whole dataset collapsed to a single class: nothing to draw.
        print(tree)

其他决策树示例,以及基于主流机器学习框架实现的决策树代码,见:

https://gitee.com/navysummer/machine-learning/tree/master/decision_tree

  

posted @   NAVYSUMMER  阅读(9)  评论(0)  编辑  收藏  举报
交流群 编程书籍
点击右上角即可分享
微信分享提示