决策树预测鸢尾花数据

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn import tree

def make_meshgrid(x, y, h=.02):
    """Build a dense 2-D grid of points covering the range of ``x`` and ``y``.

    Each axis spans from one unit below the minimum of the corresponding
    input to one unit above its maximum, sampled with step ``h``.

    Args:
        x: Array of first-axis coordinates; only its min/max are used.
        y: Array of second-axis coordinates; only its min/max are used.
        h: Step between neighbouring grid points along both axes.

    Returns:
        A pair ``(xx, yy)`` of coordinate matrices as produced by
        ``np.meshgrid``: ``xx`` holds the first-axis coordinate of every
        grid point and ``yy`` the second-axis coordinate.
    """
    # Pad both ends by 1 so the grid extends past the outermost samples.
    x_axis = np.arange(x.min() - 1, x.max() + 1, h)
    y_axis = np.arange(y.min() - 1, y.max() + 1, h)

    # Expand the two 1-D axes into matching 2-D coordinate matrices.
    grid_x, grid_y = np.meshgrid(x_axis, y_axis)
    return grid_x, grid_y

def plot_test_results(ax, clf, xx, yy, **params):
    """Predict a label for every grid point and draw filled contours.

    Args:
        ax: Object providing ``contourf`` (e.g. a matplotlib Axes).
        clf: Object providing ``predict`` that accepts an ``(n, 2)`` array.
        xx: Grid matrix of first-axis coordinates.
        yy: Grid matrix of second-axis coordinates, same shape as ``xx``.
        **params: Extra keyword arguments forwarded to ``ax.contourf``.
    """
    # Flatten the grid into one (n_points, 2) array of coordinate pairs.
    grid_points = np.column_stack((xx.ravel(), yy.ravel()))

    # Predict, then fold the flat label vector back into the grid's shape.
    labels = clf.predict(grid_points).reshape(xx.shape)

    ax.contourf(xx, yy, labels, **params)


if __name__ == '__main__':
    # Load the iris dataset and keep only the first two features so the
    # decision regions can be drawn in a 2-D plane.
    iris = datasets.load_iris()
    X = iris.data[:, :2]
    # Class label of every sample.
    y = iris.target

    # Create and train the decision tree. random_state is fixed because
    # scikit-learn permutes the features at every split, so equally good
    # splits may otherwise be chosen differently from run to run and the
    # plotted regions would not be reproducible.
    clf = tree.DecisionTreeClassifier(random_state=0)
    clf.fit(X, y)

    title = 'DecisionTreeClassifier'

    # One axes only, so there is no inter-subplot spacing to tune.
    fig, ax = plt.subplots(figsize=(5, 5))

    X0, X1 = X[:, 0], X[:, 1]
    # Generate the grid of points on which the classifier is evaluated.
    xx, yy = make_meshgrid(X0, X1)

    # Draw the predicted class of every grid point as filled regions.
    plot_test_results(ax, clf, xx, yy, cmap=plt.cm.coolwarm, alpha=0.8)
    # Overlay the training samples, coloured by their true label.
    ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=20, edgecolors='k')
    ax.set_xlim(xx.min(), xx.max())
    ax.set_ylim(yy.min(), yy.max())
    ax.set_xlabel('x1')
    ax.set_ylabel('x2')
    # Hide tick marks on both axes.
    ax.set_xticks(())
    ax.set_yticks(())
    ax.set_title(title)
    plt.show()

 

posted @ 2020-01-08 15:08  喵小喵~  阅读(385)  评论(0)  编辑  收藏  举报