08.手写KNN算法测试
导入库
import numpy as np from sklearn import datasets import matplotlib.pyplot as plt
导入数据
# Load the iris dataset as a Bunch object; the next cell reads its
# .data and .target attributes.
iris = datasets.load_iris(return_X_y=False)
数据准备
# Feature matrix and class-label vector of the iris dataset.
X, y = iris.data, iris.target
X.shape, y.shape
((150, 4), (150,))
数据分割(训练集:测试集 = 8:2)
# X and y are separate arrays, so shuffling only one of them would break the
# pairing between each sample and its label. Append the labels as an extra
# column, shuffle whole rows, then split. (Because X holds floats, the labels
# become floats in the joined array.)
X_join_y = np.concatenate((X, y.reshape(-1, 1)), axis=1)
# Shuffling makes the split differ on every run; seeding keeps it reproducible,
# which helps when debugging. Note this seeds NumPy's *global* generator, so any
# later code that draws from np.random inherits the resulting state.
np.random.seed(1)
np.random.shuffle(X_join_y)
# First 80% of the shuffled rows -> train, remaining 20% -> test.
train, test = np.split(X_join_y, [int(0.8 * len(X_join_y))])
train.shape, test.shape
((120, 5), (30, 5))
准备data和target
# Split the shuffled blocks back into features and labels:
# X_train / y_train (training data + labels), X_test / y_test (test data + labels).
# Every column except the last is a feature; slicing with :-1 instead of a
# hard-coded 0:4 stays correct for any number of features.
# Joining the labels onto the float feature matrix promoted them to float, so
# cast them back to the dtype of the original label vector y.
X_train, y_train = train[:, :-1], train[:, -1].astype(y.dtype)
X_test, y_test = test[:, :-1], test[:, -1].astype(y.dtype)
KNN手写算法
import numpy as np from math import sqrt from collections import Counter
class KNNClassifier:
    """Brute-force k-nearest-neighbours classifier with Euclidean distance.

    Each query point is labelled by majority vote among the ``k`` training
    samples closest to it.
    """

    def __init__(self, k):
        """Initialise the classifier.

        Args:
            k: Number of neighbours that vote on each prediction; must be >= 1.

        Raises:
            ValueError: If ``k`` is less than 1.
        """
        if k < 1:
            raise ValueError("k must be >= 1")
        self.k = k
        self._X_train = None
        self._y_train = None

    def fit(self, X_train, y_train):
        """Store the training set (KNN does no work at fit time).

        Args:
            X_train: 2-D array-like of shape (n_samples, n_features).
            y_train: 1-D array-like holding n_samples labels.

        Returns:
            self, so calls can be chained.

        Raises:
            ValueError: If the inputs have inconsistent shapes or ``k`` exceeds
                the number of training samples.
        """
        X_train = np.asarray(X_train)
        y_train = np.asarray(y_train)
        if X_train.ndim != 2:
            raise ValueError("X_train must be a 2-D array")
        if y_train.ndim != 1 or y_train.shape[0] != X_train.shape[0]:
            raise ValueError(
                "y_train must be 1-D with one label per row of X_train"
            )
        # Fail early instead of silently letting every point vote.
        if self.k > X_train.shape[0]:
            raise ValueError("k must not exceed the number of training samples")
        self._X_train = X_train
        self._y_train = y_train
        return self

    def predict(self, X_predict):
        """Predict a label for every row of ``X_predict``.

        Args:
            X_predict: 2-D array-like of shape (n_queries, n_features).

        Returns:
            1-D ndarray with one predicted label per query row.

        Raises:
            RuntimeError: If called before ``fit``.
            ValueError: If ``X_predict`` is not 2-D or its feature count differs
                from the training data.
        """
        if self._X_train is None or self._y_train is None:
            raise RuntimeError("must call fit before predict")
        X_predict = np.asarray(X_predict)
        # Guard against silent broadcasting of a mismatched query shape.
        if X_predict.ndim != 2 or X_predict.shape[1] != self._X_train.shape[1]:
            raise ValueError(
                "X_predict must be 2-D with the same number of features as X_train"
            )
        y_predict = [self._predict(x) for x in X_predict]
        return np.array(y_predict)

    def _predict(self, x):
        """Return the majority label among the k training points nearest to ``x``."""
        # Vectorised Euclidean distance from x to every training row at once.
        distances = np.sqrt(np.sum((self._X_train - x) ** 2, axis=1))
        nearest = np.argsort(distances)
        top_k_y = self._y_train[nearest[:self.k]]
        votes = Counter(top_k_y)
        return votes.most_common(1)[0][0]

    def __repr__(self):
        return "KNN=(%d)" % self.k
from sklearn.model_selection import train_test_split

# train_test_split returns [X_train, X_test, y_train, y_test].
# random_state is left at None, so the shuffle draws from NumPy's global RNG;
# the split shown below therefore depends on the np.random.seed(1) /
# np.random.shuffle calls made earlier in this notebook.
result = list(train_test_split(X, y))
result
[array([[7.2, 3. , 5.8, 1.6], [5.4, 3.9, 1.3, 0.4], [6.5, 3.2, 5.1, 2. ], [6.1, 3. , 4.6, 1.4], [4.6, 3.2, 1.4, 0.2], [6.9, 3.2, 5.7, 2.3], [6.1, 2.8, 4. , 1.3], [5.7, 3. , 4.2, 1.2], [5.8, 2.7, 4.1, 1. ], [5.5, 2.5, 4. , 1.3], [5.7, 2.5, 5. , 2. ], [4.6, 3.4, 1.4, 0.3], [5.9, 3.2, 4.8, 1.8], [6.3, 2.9, 5.6, 1.8], [6.8, 3. , 5.5, 2.1], [6.4, 2.7, 5.3, 1.9], [6. , 2.9, 4.5, 1.5], [6. , 2.2, 4. , 1. ], [4.8, 3. , 1.4, 0.1], [5.6, 2.5, 3.9, 1.1], [7.1, 3. , 5.9, 2.1], [6.7, 3.3, 5.7, 2.1], [5.5, 2.6, 4.4, 1.2], [6.3, 3.3, 4.7, 1.6], [6.7, 3.1, 4.7, 1.5], [4.3, 3. , 1.1, 0.1], [4.8, 3.4, 1.9, 0.2], [6.7, 3.3, 5.7, 2.5], [6. , 2.7, 5.1, 1.6], [6.5, 3. , 5.5, 1.8], [4.9, 2.5, 4.5, 1.7], [5. , 3.5, 1.3, 0.3], [5.9, 3. , 4.2, 1.5], [5.5, 2.4, 3.8, 1.1], [6.2, 2.2, 4.5, 1.5], [6.3, 2.7, 4.9, 1.8], [4.4, 3. , 1.3, 0.2], [7.7, 3. , 6.1, 2.3], [7. , 3.2, 4.7, 1.4], [6.4, 2.8, 5.6, 2.2], [5.7, 2.8, 4.5, 1.3], [6.4, 2.9, 4.3, 1.3], [5.6, 3. , 4.1, 1.3], [6.3, 2.8, 5.1, 1.5], [4.9, 3.6, 1.4, 0.1], [6. , 3.4, 4.5, 1.6], [5.7, 4.4, 1.5, 0.4], [4.8, 3. , 1.4, 0.3], [5.4, 3.7, 1.5, 0.2], [5.4, 3.4, 1.5, 0.4], [5. , 2.3, 3.3, 1. ], [6.9, 3.1, 4.9, 1.5], [5.1, 3.8, 1.9, 0.4], [6.4, 2.8, 5.6, 2.1], [5.1, 3.8, 1.5, 0.3], [5. , 3.4, 1.5, 0.2], [5.1, 3.3, 1.7, 0.5], [5.2, 2.7, 3.9, 1.4], [6.1, 2.6, 5.6, 1.4], [7.7, 2.8, 6.7, 2. ], [5.8, 2.7, 5.1, 1.9], [6.8, 2.8, 4.8, 1.4], [4.4, 3.2, 1.3, 0.2], [5.3, 3.7, 1.5, 0.2], [6.9, 3.1, 5.4, 2.1], [5.1, 2.5, 3. , 1.1], [5.7, 2.8, 4.1, 1.3], [6.4, 3.1, 5.5, 1.8], [6.2, 3.4, 5.4, 2.3], [5.8, 2.7, 5.1, 1.9], [6.3, 2.5, 4.9, 1.5], [5.8, 2.6, 4. , 1.2], [4.6, 3.1, 1.5, 0.2], [4.9, 3.1, 1.5, 0.2], [5.6, 2.9, 3.6, 1.3], [5.1, 3.7, 1.5, 0.4], [5. , 3.2, 1.2, 0.2], [6.5, 3. , 5.8, 2.2], [7.3, 2.9, 6.3, 1.8], [5.2, 3.4, 1.4, 0.2], [4.5, 2.3, 1.3, 0.3], [5.5, 2.3, 4. , 1.3], [6.5, 3. , 5.2, 2. ], [5.5, 2.4, 3.7, 1. ], [7.6, 3. , 6.6, 2.1], [5. , 3.6, 1.4, 0.2], [5.9, 3. , 5.1, 1.8], [6.3, 2.5, 5. , 1.9], [6.1, 3. , 4.9, 1.8], [4.9, 3. , 1.4, 0.2], [6.7, 3. 
, 5.2, 2.3], [5.1, 3.5, 1.4, 0.3], [6.3, 2.3, 4.4, 1.3], [4.4, 2.9, 1.4, 0.2], [6.8, 3.2, 5.9, 2.3], [5.1, 3.8, 1.6, 0.2], [7.2, 3.6, 6.1, 2.5], [5.7, 3.8, 1.7, 0.3], [5. , 2. , 3.5, 1. ], [5. , 3. , 1.6, 0.2], [4.8, 3.4, 1.6, 0.2], [4.8, 3.1, 1.6, 0.2], [6.7, 3.1, 5.6, 2.4], [5.8, 2.8, 5.1, 2.4], [5.8, 4. , 1.2, 0.2], [6.1, 2.8, 4.7, 1.2], [5.4, 3.9, 1.7, 0.4], [6.5, 2.8, 4.6, 1.5], [4.9, 3.1, 1.5, 0.1], [5.4, 3.4, 1.7, 0.2], [4.9, 2.4, 3.3, 1. ], [5.1, 3.4, 1.5, 0.2]]), array([[6.2, 2.9, 4.3, 1.3], [6.7, 3. , 5. , 1.7], [5.2, 4.1, 1.5, 0.1], [5.7, 2.6, 3.5, 1. ], [7.4, 2.8, 6.1, 1.9], [5.6, 3. , 4.5, 1.5], [6.9, 3.1, 5.1, 2.3], [6. , 2.2, 5. , 1.5], [5.5, 3.5, 1.3, 0.2], [6.7, 2.5, 5.8, 1.8], [7.2, 3.2, 6. , 1.8], [6. , 3. , 4.8, 1.8], [5.2, 3.5, 1.5, 0.2], [5.1, 3.5, 1.4, 0.2], [5. , 3.3, 1.4, 0.2], [5.6, 2.8, 4.9, 2. ], [5.6, 2.7, 4.2, 1.3], [5. , 3.5, 1.6, 0.6], [7.9, 3.8, 6.4, 2. ], [6.3, 3.4, 5.6, 2.4], [5. , 3.4, 1.6, 0.4], [6.2, 2.8, 4.8, 1.8], [5.4, 3. , 4.5, 1.5], [5.5, 4.2, 1.4, 0.2], [4.6, 3.6, 1. , 0.2], [6.1, 2.9, 4.7, 1.4], [6.4, 3.2, 5.3, 2.3], [5.7, 2.9, 4.2, 1.3], [7.7, 2.6, 6.9, 2.3], [7.7, 3.8, 6.7, 2.2], [6.3, 3.3, 6. , 2.5], [5.8, 2.7, 3.9, 1.2], [6.6, 2.9, 4.6, 1.3], [4.7, 3.2, 1.6, 0.2], [6.7, 3.1, 4.4, 1.4], [6.4, 3.2, 4.5, 1.5], [4.7, 3.2, 1.3, 0.2], [6.6, 3. , 4.4, 1.4]]), array([2, 0, 2, 1, 0, 2, 1, 1, 1, 1, 2, 0, 1, 2, 2, 2, 1, 1, 0, 1, 2, 2, 1, 1, 1, 0, 0, 2, 1, 2, 2, 0, 1, 1, 1, 2, 0, 2, 1, 2, 1, 1, 1, 2, 0, 1, 0, 0, 0, 0, 1, 1, 0, 2, 0, 0, 0, 1, 2, 2, 2, 1, 0, 0, 2, 1, 1, 2, 2, 2, 1, 1, 0, 0, 1, 0, 0, 2, 2, 0, 0, 1, 2, 1, 2, 0, 2, 2, 2, 0, 2, 0, 1, 0, 2, 0, 2, 0, 1, 0, 0, 0, 2, 2, 0, 1, 0, 1, 0, 0, 1, 0]), array([1, 1, 0, 1, 2, 1, 2, 2, 0, 2, 2, 2, 0, 0, 0, 2, 1, 0, 2, 2, 0, 2, 1, 0, 0, 1, 2, 1, 2, 2, 2, 1, 1, 0, 1, 1, 0, 1])]
# 3-nearest-neighbour classifier trained on the sklearn split
# (result[0] is X_train, result[2] is y_train). fit returns the classifier
# itself, so the cell displays its repr.
my_knn_clf = KNNClassifier(k=3).fit(result[0], result[2])
my_knn_clf
KNN=(3)
# Predict the held-out samples (result[1] is X_test) and score them against
# result[3] (y_test).
y_predict = my_knn_clf.predict(result[1])
# Boolean hit mask: True where the prediction matches the true label.
is_correct = y_predict == result[3]
# A notebook cell only displays its last expression, so the correct count
# must be printed explicitly rather than computed and discarded.
print("correct:", is_correct.sum(), "/", len(is_correct))
# Accuracy = fraction of correct predictions.
is_correct.mean()