Decision Tree Algorithm
1. Theory
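Only a brief note here. scikit-learn's DecisionTreeClassifier (used below with criterion='entropy') grows a binary CART-style tree and scores candidate splits by information gain. For a node whose samples have class proportions p_k, the entropy is H(D) = -Σ_k p_k * log2(p_k), and a split of D into subsets D_v is chosen to maximize the gain H(D) - Σ_v (|D_v|/|D|) * H(D_v). A small worked calculation on the training data is sketched right after the table in 2.1.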
2. Code
2.1 Training Data
RID,age,income,student,credit_rating,class:buy_computer
1,youth,high,no,fair,no
2,youth,high,no,excellent,no
3,middle_age,high,no,fair,yes
4,senior,medium,no,fair,yes
5,senior,low,yes,fair,yes
6,senior,low,yes,excellent,no
7,middle_age,low,yes,excellent,yes
8,youth,medium,no,fair,no
9,youth,low,yes,fair,yes
10,senior,medium,yes,fair,yes
11,youth,medium,yes,excellent,yes
12,middle_age,medium,no,excellent,yes
13,middle_age,high,yes,fair,yes
14,senior,medium,no,excellent,no
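As a quick sanity check on the numbers that appear in the exported tree of section 2.3, the snippet below (a minimal standalone sketch, not part of the script in 2.2) computes the entropy of the class column: the 14 rows contain 9 "yes" and 5 "no", giving roughly 0.940, which matches the entropy reported at the root node.

import math

def entropy(counts):
    # entropy in bits of a label distribution given as class counts
    total = sum(counts)
    return -sum(c / total * math.log2(c / total) for c in counts if c > 0)

# class:buy_computer column above: 9 rows are "yes", 5 are "no"
print(entropy([9, 5]))   # ~0.940, the value shown at the root node in section 2.3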
2.2 Code
from sklearn.feature_extraction import DictVectorizer
import csv
from sklearn import preprocessing
from sklearn import tree
#Read Data
Data = open(r'D:\lern\ML\Data.csv','r', encoding="utf-8")
reader = csv.reader(Data)
headers = next(reader)   # first row: the feature names
print(headers)
featureList = []
labelList = []
for row in reader:
    labelList.append(row[len(row) - 1])   # last column is the class label
    rowDict = {}
    for i in range(1, len(row) - 1):      # skip column 0 (RID) and the label column
        rowDict[headers[i]] = row[i]
    featureList.append(rowDict)
print(featureList)
#vectorize features
vec = DictVectorizer()
dummyX=vec.fit_transform(featureList).toarray()
print("dummyX:"+str(dummyX))
print(vec.get_feature_names_out())   # one-hot feature names (get_feature_names was removed in recent scikit-learn)
print("labelList:"+str(labelList))
#vectorize class labels
lb= preprocessing.LabelBinarizer()
dummyY=lb.fit_transform(labelList)
print("dummyY:"+str(dummyY))
# use a decision tree classifier with the entropy (information gain) criterion
clf =tree.DecisionTreeClassifier(criterion='entropy')
clf=clf.fit(dummyX,dummyY)
print("clf:"+str(clf))
# visualize the model: export the fitted tree in Graphviz .dot format
with open("DT.dot", 'w') as f:
    tree.export_graphviz(clf, feature_names=vec.get_feature_names_out(), out_file=f)
# predict on a modified copy of the first encoded training row
oneRowX = dummyX[0, :]
print("oneRowX: " + str(oneRowX))
newRowX = oneRowX.copy()   # copy so the original encoded row is left untouched
newRowX[0] = 1
newRowX[2] = 0
print("newRowX: " + str(newRowX))
predictedY = clf.predict([newRowX])
print("predictedY: " + str(predictedY))
2.3 Results
digraph Tree {
node [shape=box] ;
0 [label="age=middle_age <= 0.5\nentropy = 0.94\nsamples = 14\nvalue = [5, 9]"] ;
1 [label="student=no <= 0.5\nentropy = 1.0\nsamples = 10\nvalue = [5, 5]"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="credit_rating=excellent <= 0.5\nentropy = 0.722\nsamples = 5\nvalue = [1, 4]"] ;
1 -> 2 ;
3 [label="entropy = 0.0\nsamples = 3\nvalue = [0, 3]"] ;
2 -> 3 ;
4 [label="income=low <= 0.5\nentropy = 1.0\nsamples = 2\nvalue = [1, 1]"] ;
2 -> 4 ;
5 [label="entropy = 0.0\nsamples = 1\nvalue = [0, 1]"] ;
4 -> 5 ;
6 [label="entropy = 0.0\nsamples = 1\nvalue = [1, 0]"] ;
4 -> 6 ;
7 [label="age=youth <= 0.5\nentropy = 0.722\nsamples = 5\nvalue = [4, 1]"] ;
1 -> 7 ;
8 [label="credit_rating=fair <= 0.5\nentropy = 1.0\nsamples = 2\nvalue = [1, 1]"] ;
7 -> 8 ;
9 [label="entropy = 0.0\nsamples = 1\nvalue = [1, 0]"] ;
8 -> 9 ;
10 [label="entropy = 0.0\nsamples = 1\nvalue = [0, 1]"] ;
8 -> 10 ;
11 [label="entropy = 0.0\nsamples = 3\nvalue = [3, 0]"] ;
7 -> 11 ;
12 [label="entropy = 0.0\nsamples = 4\nvalue = [0, 4]"] ;
0 -> 12 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
}
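In the listing above, "value = [a, b]" counts samples per class in the order of clf.classes_; because LabelBinarizer maps "no" to 0 and "yes" to 1, the first number is the "no" count and the second the "yes" count (the root's [5, 9] matches the 5/9 split in the training data). This can be checked with a one-liner, assuming the fitted objects from 2.2:

print(lb.classes_, clf.classes_)   # expected: ['no' 'yes'] and [0 1]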
2.3.1 Visualizing the result: the Graphviz tool
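The DT.dot file written in 2.2 (its contents are listed in 2.3) can be rendered with the Graphviz dot command-line tool, assuming Graphviz is installed and on the PATH, for example:

dot -Tpng DT.dot -o DT.png   # or -Tpdf for a PDF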