Building a Classification Tree with the CART Algorithm
Original post. If you repost it, please cite the source: https://www.cnblogs.com/agilestyle/p/12718314.html
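CART Classification Trees
A CART classification tree chooses its attribute splits by Gini impurity (often called the Gini index). In Python's sklearn, you can create one directly with the DecisionTreeClassifier class. Its criterion parameter defaults to gini, so splits are selected by Gini impurity and the default tree is a CART classification tree.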
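To make the splitting criterion concrete, here is a minimal sketch of how Gini impurity can be computed for a set of labels, using the usual definition 1 - sum(p_k^2) over the class proportions p_k. The helper name gini_impurity is made up for this illustration and is not part of sklearn; the last lines only check the default value of criterion.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier


def gini_impurity(labels):
    """Gini impurity of a label collection: 1 - sum_k p_k^2 (illustrative helper)."""
    _, counts = np.unique(labels, return_counts=True)
    proportions = counts / counts.sum()
    return 1.0 - np.sum(proportions ** 2)


pure = gini_impurity([0, 0, 0, 0])    # only one class present
mixed = gini_impurity([0, 0, 1, 1])   # two classes, evenly split

# The criterion used when none is passed explicitly
default_criterion = DecisionTreeClassifier().criterion

print(pure, mixed, default_criterion)
```

A node containing a single class has impurity 0, and the impurity grows as the classes become more evenly mixed, which is why the tree prefers splits that lower it.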
Preparing the Data
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
features = iris.data
labels = iris.target

# (150, 4)
features.shape
# (150,)
labels.shape
Splitting into Training and Test Sets
X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.33, random_state=0)
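As a quick sanity check on the split, the sketch below repeats the same call and prints the resulting shapes. With 150 samples and test_size=0.33, sklearn rounds the test set up to 50 rows and leaves 100 for training.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()

# Same arguments as in the article: 33% held out, fixed seed for repeatability
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.33, random_state=0)

print(X_train.shape, X_test.shape)
print(y_train.shape, y_test.shape)
```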
Building and Training the Model
clf = DecisionTreeClassifier(criterion='gini')

# DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
#                        max_features=None, max_leaf_nodes=None,
#                        min_impurity_decrease=0.0, min_impurity_split=None,
#                        min_samples_leaf=1, min_samples_split=2,
#                        min_weight_fraction_leaf=0.0, presort=False,
#                        random_state=None, splitter='best')
clf.fit(X_train, y_train)
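Once the tree is fitted, a few attributes give a rough picture of what it learned. The sketch below is one way to look at them; note that random_state=0 on the classifier is an addition for repeatability and is not set in the article's code.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.33, random_state=0)

# random_state=0 is an assumption added here, not part of the original code
clf = DecisionTreeClassifier(criterion='gini', random_state=0)
clf.fit(X_train, y_train)

depth = clf.get_depth()          # number of levels below the root
n_leaves = clf.get_n_leaves()    # number of terminal nodes

# Relative contribution of each feature to the impurity reduction
importances = dict(zip(iris.feature_names, clf.feature_importances_))

print(depth, n_leaves)
for name, value in importances.items():
    print(f"{name}: {value:.3f}")
```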
Evaluating the Model
train_score = clf.score(X_train, y_train)
test_score = clf.score(X_test, y_test)

# 1.0
train_score
# 0.96
test_score
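For a classifier, score returns the mean accuracy, that is, the fraction of samples predicted correctly. The sketch below shows that equivalence by computing the same number by hand and with accuracy_score; as above, random_state=0 on the classifier is an added assumption for repeatability.

```python
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.33, random_state=0)

# random_state=0 is an assumption added here, not part of the original code
clf = DecisionTreeClassifier(criterion='gini', random_state=0)
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)

test_score = clf.score(X_test, y_test)         # built-in mean accuracy
manual_score = (y_pred == y_test).mean()       # fraction of correct predictions
metric_score = accuracy_score(y_test, y_pred)  # same value via sklearn.metrics

print(test_score, manual_score, metric_score)
```

A training score well above the test score is a common sign that an unpruned tree has memorized the training set, so comparing the two numbers is worthwhile.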
Visualizing the Decision Tree
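from sklearn.tree import export_graphviz

# Write the tree in Graphviz DOT format; export_graphviz returns None
# when out_file is given, so its result is not assigned
with open('iris.dot', 'w') as f:
    export_graphviz(clf, out_file=f)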
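The exported graph is easier to read when the nodes carry feature and class names. The sketch below is one possible variant: passing out_file=None makes export_graphviz return the DOT source as a string instead of writing a file. The styling options filled and rounded, and random_state=0, are choices made for this illustration rather than anything the article uses.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.33, random_state=0)

# random_state=0 is an assumption added here, not part of the original code
clf = DecisionTreeClassifier(criterion='gini', random_state=0)
clf.fit(X_train, y_train)

# out_file=None returns the DOT source as a string
dot_source = export_graphviz(
    clf,
    out_file=None,
    feature_names=list(iris.feature_names),
    class_names=list(iris.target_names),
    filled=True,    # color nodes by majority class
    rounded=True,   # rounded node boxes
)

print(dot_source[:200])
```

The string can be saved to iris.dot and rendered in the same way as the unlabeled version.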
Note: Graphviz must be installed.
brew install graphviz
cp iris.dot /usr/local/Cellar/graphviz/2.42.3/bin
dot -Tpng iris.dot -o iris.png
cp iris.png ~/Desktop
Running the commands above produces iris.png, a diagram of the trained tree.
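If installing Graphviz is not convenient, a plain-text view of the same tree may be enough. The sketch below uses sklearn.tree.export_text, which is available in scikit-learn 0.21 and later; this is an alternative to the article's approach, not part of it, and random_state=0 is again an added assumption.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.33, random_state=0)

# random_state=0 is an assumption added here, not part of the original code
clf = DecisionTreeClassifier(criterion='gini', random_state=0)
clf.fit(X_train, y_train)

# Indented text rendering of the split rules, one line per node
tree_text = export_text(clf, feature_names=list(iris.feature_names))
print(tree_text)
```

Each line shows either a threshold test on a feature or, at a leaf, the predicted class index.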
Reference
https://time.geekbang.org/column/article/78659