Iris 数据集：用决策树实现分类并画出决策树
# coding=utf-8

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn.metrics import precision_recall_curve  # precision and recall (currently unused)
import numpy as np

import os

# Make the Graphviz binaries (needed by pydotplus to render the PDF) discoverable.
# NOTE(review): hard-coded Windows install path -- adjust on other machines.
os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin/'


def _load_features_and_labels(file_path):
    """Read the spreadsheet and split it into features (all but last column) and labels (last column).

    Raises:
        ValueError: if the sheet has fewer than two columns.
    """
    # read_excel uses the first row as the header by default (header=0),
    # so data row 0 is already a real sample -- slicing must start at 0.
    data = pd.read_excel(file_path)
    if data.shape[1] < 2:
        raise ValueError(
            "%s must contain at least one feature column and a label column" % file_path)
    iris_x = data.iloc[:, :-1]  # every column except the last: features
    iris_y = data.iloc[:, -1]   # last column only: class label
    return iris_x, iris_y


def get_data(file_path="Iris.xlsx", pdf_path="workiris.pdf",
             test_size=0.25, random_state=None):
    """Train a decision tree on a spreadsheet dataset, print its evaluation and draw it.

    Every column except the last is used as a feature and the last column as
    the class label. The header row supplies the feature names shown in the
    rendered tree.

    Args:
        file_path: Excel file to read (first row is the header).
        pdf_path: Output path of the rendered decision-tree PDF.
        test_size: Fraction of rows held out for evaluation
            (0.25 matches train_test_split's default).
        random_state: Seed for both the split and the tree; None (the
            default) gives a different result on every run.

    Returns:
        The sklearn classification report (str) computed on the held-out set.
    """
    from sklearn import metrics
    import pydotplus

    iris_x, iris_y = _load_features_and_labels(file_path)
    feature_names = [str(col) for col in iris_x.columns]

    x_train, x_test, y_train, y_test = train_test_split(
        iris_x, iris_y, test_size=test_size, random_state=random_state)

    clf = tree.DecisionTreeClassifier(random_state=random_state)
    clf.fit(x_train, y_train)

    # Predict once and reuse it. Every metric is computed on the held-out split:
    # an unpruned tree typically scores ~100% on its own training data,
    # which says nothing about generalisation.
    y_pred = clf.predict(x_test)
    classify_report = metrics.classification_report(y_test, y_pred)
    confusion_matrix = metrics.confusion_matrix(y_test, y_pred, labels=clf.classes_)
    overall_accuracy = metrics.accuracy_score(y_test, y_pred)
    print(classify_report)
    print(confusion_matrix)
    print("accuracy: %.4f" % overall_accuracy)

    # export_graphviz expects class names in ascending order, which is
    # exactly the order of clf.classes_ (not Counter's insertion order).
    dot_data = tree.export_graphviz(
        clf, out_file=None,
        feature_names=feature_names,
        class_names=[str(c) for c in clf.classes_],
        filled=True, rounded=True, special_characters=True, precision=4)
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_pdf(pdf_path)
    return classify_report


if __name__ == "__main__":
    get_data()
数据地址:http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
下载后请将数据另存为 Iris.xlsx，并在第一行填写表头（各特征列名，最后一列为类别列）：原始的 iris.data 没有表头，而程序用 pandas.read_excel 读取时默认把第一行当作表头。