iris数据集 决策树实现分类并画出决策树

 1 # coding=utf-8
 2 
 3 import pandas as pd
 4 from sklearn.model_selection import train_test_split
 5 from sklearn import tree
 6 from sklearn.metrics import precision_recall_curve  #准确率与召回率
 7 import numpy as np
 8 #import graphviz
 9 
10 import os
# Extend PATH with the Graphviz 2.38 install directory so its executables can
# be located (presumably required by pydotplus when rendering the PDF).
os.environ["PATH"] = os.pathsep.join(
    [os.environ["PATH"], 'C:/Program Files (x86)/Graphviz2.38/bin/']
)
12 
13 
14 
def get_data():
    """Train a decision tree on the Iris spreadsheet and export it as a PDF.

    Reads ``Iris.xlsx`` from the current working directory, using the last
    column as the class label and every other column as a feature. The rows
    are split with the defaults of ``train_test_split``, a
    ``DecisionTreeClassifier`` is fitted on the training part, and the fitted
    tree is written to ``workiris.pdf``.

    Several intermediate values are printed along the way: the column count,
    the generated feature names, the distinct labels, the training features
    and the test labels, and finally the classification report.

    Returns:
        The text produced by ``sklearn.metrics.classification_report`` for
        the held-out test split.
    """
    from collections import Counter
    from sklearn import metrics

    file_path = "Iris.xlsx"
    data = pd.read_excel(file_path)

    ncol = len(data.columns)
    print(ncol)

    # Generic names used to label the feature columns in the exported tree.
    feature1 = ["feature" + str(i) for i in range(ncol - 1)]
    print(feature1)

    # read_excel has already consumed the header row as column names, so
    # every remaining row is a sample; slicing from row 1 would silently
    # discard the first observation.
    iris_x = data.iloc[:, :ncol - 1]  # all columns except the last
    iris_y = data.iloc[:, ncol - 1]   # last column only: the class label

    # Show which distinct classes are present in the label column.
    counter = Counter(iris_y)
    print(counter.keys())

    x_train, x_test, y_train, y_test = train_test_split(iris_x, iris_y)
    print(x_train)
    print(y_test)

    clf = tree.DecisionTreeClassifier()
    clf = clf.fit(x_train, y_train)

    # Evaluate on the held-out split only; scores measured on the training
    # rows say nothing about how the tree generalises.
    classify_report = metrics.classification_report(y_test, clf.predict(x_test))
    print(classify_report)

    # Imported here so the report above is still printed when pydotplus is
    # not installed.
    import pydotplus
    dot_data = tree.export_graphviz(
        clf,
        out_file=None,
        feature_names=feature1,
        filled=True,
        rounded=True,
        special_characters=True,
        precision=4,
    )
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_pdf("workiris.pdf")
    return classify_report
70 
71 
if __name__ == "__main__":
    # Script entry point: run the full load / train / export pipeline once.
    # The returned classification report is kept for interactive inspection.
    report = get_data()

数据地址:http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data

将数据保存为 Iris.xlsx 后，注意在第一行补上表头：原始的 iris.data 文件没有表头，而 pd.read_excel 默认会把第一行当作列名，最后一列应为类别标签。

posted @ 2018-01-05 15:47  shizhenqiang  阅读(4644)  评论(0编辑  收藏  举报