"""Compare linear regression and k-NN regression on the Iris dataset.

Predict petal width (column 3 of the data) from sepal length and sepal
width (columns 0 and 1), print the in-sample mean squared error of both
models, plot their predictions, and inspect the predictions for the
setosa class.
"""
from sklearn.datasets import load_iris
from sklearn.linear_model import LinearRegression
from sklearn.neighbors import KNeighborsRegressor
import numpy as np
import matplotlib.pyplot as plt


def mse(y_true, y_pred):
    """Return the mean squared error between *y_true* and *y_pred*."""
    return np.mean((np.asarray(y_true) - np.asarray(y_pred)) ** 2)


iris = load_iris()
iris_data = iris.data
iris_target = iris.target
print(iris.feature_names)
# ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']

# Features: sepal length and sepal width.  Target: petal width (column 3).
# A plain linear regression is fitted too, to see how well k-NN regression
# does in comparison.
X = iris_data[:, 0:2]
y = iris_data[:, 3]

# Linear regression baseline.
lr = LinearRegression()
lr.fit(X, y)
lr_pred = lr.predict(X)
print("The MSE is: {:.2}".format(mse(y, lr_pred)))

# k-NN regression: each prediction averages the targets of the 10 nearest
# training points.
knnr = KNeighborsRegressor(n_neighbors=10)
knnr.fit(X, y)
knnr_pred = knnr.predict(X)
print("The MSE is: {:.2}".format(mse(y, knnr_pred)))

# Only shows how to predict a single new sample; the model expects a 2-D
# (n_samples, n_features) array, hence the reshape.
print(knnr.predict(np.array([3.0, 5.0]).reshape(1, -1)))

# Plot both models' predictions.  Marker area is proportional to the
# predicted petal width.  Linear-regression output is not bounded below, so
# clip at zero to avoid invalid (negative) marker sizes.  k-NN predictions
# are means of non-negative targets and need no clipping.
f, ax = plt.subplots(nrows=2, figsize=(7, 10))
ax[0].set_title("Predictions")
ax[0].scatter(X[:, 0], X[:, 1], s=np.clip(lr_pred, 0, None) * 80,
              label='LRPredictions', color='c', edgecolors='black')
ax[1].scatter(X[:, 0], X[:, 1], s=knnr_pred * 80,
              label='k-NNPredictions', color='m', edgecolors='black')
ax[0].legend()
ax[1].legend()

# Compare the true values with both models' predictions on one class
# (setosa), where k-NN tracks the true values more closely than the
# linear model does.
setosa_idx = list(iris.target_names).index('setosa')
setosa_mask = iris_target == setosa_idx
print(y[setosa_mask][:20])
print(knnr_pred[setosa_mask][:20])
print(lr_pred[setosa_mask][:20])

# Display the figure.  In a non-interactive session this blocks until the
# window is closed.
plt.show()
# Manually reproduce the k-NN prediction for a single point.  A k-NN
# regression prediction is the mean target value of the k training points
# closest to the query point.
example_point = X[0]
# Ground truth for this point: X[0] == [5.1, 3.5], y[0] == 0.2.

from sklearn.metrics import pairwise

# Distances from the example point to all 150 training points, including
# itself.  Only this one row is needed, so pass the point as a 1-row
# matrix instead of building the full 150 x 150 distance matrix.
distances_to_example = pairwise.pairwise_distances(
    example_point.reshape(1, -1), X)[0]
# Training-point indices ordered from nearest to farthest, computed once.
order = np.argsort(distances_to_example)
ten_closest_points = X[order[:10]]
# Targets of the 10 nearest points.  Their mean is the k-NN prediction.
ten_closest_y = y[order[:10]]
print(ten_closest_y.mean())
# Compare with the fitted model's own prediction for the same point.  The
# two should agree closely; they may differ slightly if several points lie
# at exactly the same distance and the tie is broken differently.
print(knnr.predict(example_point.reshape(1, -1)))
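# Sample output of the script above: feature names, linear-regression MSE,
# k-NN MSE, the k-NN prediction for [3.0, 5.0], then for the first 20 setosa
# rows the true petal widths, the k-NN predictions and the linear-regression
# predictions.
# ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
# The MSE is: 0.15
# The MSE is: 0.069
# [ 0.2]
# [ 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 0.2 0.2 0.1 0.1 0.2
# 0.4 0.4 0.3 0.3 0.3]
# [ 0.28 0.17 0.21 0.2 0.31 0.27 0.21 0.31 0.19 0.17 0.29 0.28
# 0.17 0.19 0.26 0.27 0.27 0.28 0.27 0.31]
# [ 0.44636645 0.53893889 0.29846368 0.27338255 0.32612885 0.47403161
# 0.13064785 0.42128532 0.22322028 0.49136065 0.56918808 0.27596658
# 0.46627952 0.10298268 0.71709085 0.45411854 0.47403161 0.44636645
# 0.73958795 0.30363175]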
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步