# (blog-export artifact: "click to view code" — kept as a comment so the file parses)
from sklearn import datasets # 自带数据集
from sklearn.model_selection import train_test_split # 数据集划分
from sklearn.preprocessing import StandardScaler # 标准化
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score
# Load the bundled iris dataset (150 samples, 4 features, 3 balanced classes).
iris = datasets.load_iris()
X = iris.data
target = iris.target
def knn_model_train(X, y):
    """Select the KNN hyperparameter k via 10-fold cross-validation.

    Splits (X, y) into train/test, then for each k in range(1, 31) runs
    10-fold cross-validation on the training portion and records the mean
    accuracy.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Feature matrix.
    y : array-like of shape (n_samples,)
        Class labels.

    Returns
    -------
    list of float
        Mean 10-fold CV accuracy for each k in range(1, 31)
        (index i corresponds to k = i + 1).
    """
    # BUG FIX: the original split on the module-level `target` instead of the
    # `y` parameter, silently ignoring the labels the caller passed in.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=111, shuffle=True, stratify=y)
    k_range = range(1, 31)  # candidate values for n_neighbors
    cv_scores = []  # mean 10-fold CV accuracy for each candidate k
    for k in k_range:
        knn = KNeighborsClassifier(n_neighbors=k)
        # Classification scoring metric "accuracy"; see
        # https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter
        scores = cross_val_score(
            knn, X_train, y_train, scoring="accuracy", cv=10)
        cv_scores.append(scores.mean())
    return cv_scores
def plt_knn_scores(cv_scores):
    """Plot mean cross-validation accuracy against the k value that produced it.

    BUG FIX: the original plotted `cv_scores` against its list index (0..n-1),
    while k starts at 1 — so every k read off the x-axis was one too small.
    We now plot against k = 1..len(cv_scores).

    Parameters
    ----------
    cv_scores : sequence of float
        Mean CV accuracy per k, as returned by ``knn_model_train``.
    """
    k_values = range(1, len(cv_scores) + 1)  # k starts at 1, not 0
    plt.figure(figsize=(8, 6))
    plt.plot(k_values, cv_scores, "-o")
    plt.xlabel("knn-k value", fontsize=12)
    plt.ylabel("mean accuracy", fontsize=12)
    plt.grid()
    plt.title("KNN super-args of mean accuracy", fontsize=14)
    plt.show()
# Run the hyperparameter sweep, then visualise mean CV accuracy per k.
cv_scores = knn_model_train(X, target)
plt_knn_scores(cv_scores)
# (blog-export artifact: "click to view code" — kept as a comment so the file parses)
# Refit at the best k found by the sweep (k = 13) and report held-out accuracy.
X_train, X_test, y_train, y_test = train_test_split(
    X, target, test_size=0.3,
    random_state=111, shuffle=True, stratify=target)
k = 13
knn_best = KNeighborsClassifier(n_neighbors=k)
knn_best.fit(X_train, y_train)
# Generalization accuracy on the 30% held-out test split.
print("泛化精度是,%.5f" % knn_best.score(X_test, y_test))