学习knn算法笔记
定义
kNN == k-NearestNeighbor k个最近的邻居
核心思想——如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。
最大特点——kNN方法在类别决策时,只与极少量的相邻样本有关。
适用情况——类域的交叉或重叠较多的待分样本集
例子
已知条件如下
一、聪明人用蓝色方块 表示
二、笨人用红色三角形 表示
三、有个村庄,里面只有两种人:聪明人&&笨人,凭据是小学第一次语文数学考试成绩,但是具体规则只有村长知道
四、一天村里来了一个外乡人,但是村长不在,村民们就炸了毛,不知道ta是聪明人?笨人?
五、从户籍转入申请书上,我们获得了外乡人的语文数学考试成绩 问 外乡算什么人??
算法描述
1)计算测试数据和各个训练数据之间的距离
2)按照距离的递增关系进行排序
3)选取距离最小的k的个点
4)确定前k个点所有类别的出现频率
5)返回前k个点中出现频率最高的类别作为测试数据的预测分类
衍生
怎么选出最接近的k个人?
计算出全体已分类成员与待分类成员的欧式距离(公式: ),
取距离最小的前k个 怎么选出最恰当的k值?
首先确定一系列k值(一般从1取到min(20,已分类成员总数开方)),在已分类的成员里面选出一小部分作为待分类成员,拿到取不同k值时的误差率,将误差最小的k值视为最恰当k值
1 import numpy as np 2 import matplotlib.pyplot as plt 3 4 5 amount = 12 6 wtnum = [5,7] 7 bgnums = np.zeros((amount,3)) 8 bgnums[0] = [17, 7, 1] 9 bgnums[1] = [4, 12, 2] 10 bgnums[2] = [5, 17, 2] 11 bgnums[3] = [8, 11, 1] 12 bgnums[4] = [13, 9, 1] 13 bgnums[5] = [16, 9, 2] 14 bgnums[6] = [14, 7, 1] 15 bgnums[7] = [10, 15, 2] 16 bgnums[8] = [15, 18, 2] 17 bgnums[9] = [11, 17, 2] 18 bgnums[10] = [13, 19, 1] 19 bgnums[11] = [12, 10, 2] 20 colmap = {1:'r',2:'g'} 21 for i in range(amount): 22 plt.scatter(bgnums[i][0],bgnums[i][1],color = colmap[bgnums[i][2]]) 23 plt.scatter(wtnum[0],wtnum[1],color = 'k',alpha = 0.8) 24 plt.xlim(0,20) 25 plt.ylim(0,20) 26 plt.show()
distance = np.zeros((amount,2)) for i in range(amount): distance[i][0] = (bgnums[i][0] - wtnum[0])**2 + (bgnums[i][1] - wtnum[1])**2 distance[i][1] = bgnums[i][2] print(distance) for i in range(amount): for j in range(amount - 1): if distance[j][0] > distance[j+1][0]: distance[j][0],distance[j+1][0] = distance[j+1][0],distance[j][0] distance[j][1],distance[j+1][1] = distance[j+1][1],distance[j][1] print(distance)
k = 3 num1 = 0 num2 = 0 for i in range(k): if distance[i][1] == 1: num1 += 1 elif distance[i][1] == 2: num2 += 1 if num1 >= num2: print("该测试点归为第1类") elif num1 < num2: print("该测试点归为第2类") else: print("something2 wrong has happened")
import numpy as np import matplotlib.pyplot as plt def belong_which( bgnums, amount, wtnum, k ): for i in range(amount): plt.scatter(bgnums[i][0],bgnums[i][1],color = colmap[bgnums[i][2]]) plt.scatter(wtnum[0],wtnum[1],color = 'k',alpha = 0.8) plt.xlim(0,20) plt.ylim(0,20) plt.show() distance = np.zeros((amount,2)) for i in range(amount): distance[i][0] = (bgnums[i][0] - wtnum[0])**2 + (bgnums[i][1] - wtnum[1])**2 distance[i][1] = bgnums[i][2] print(distance) for i in range(amount): for j in range(amount - 1): if distance[j][0] > distance[j+1][0]: distance[j][0],distance[j+1][0] = distance[j+1][0],distance[j][0] distance[j][1],distance[j+1][1] = distance[j+1][1],distance[j][1] print(distance) num1 = 0 num2 = 0 for i in range(k): if distance[i][1] == 1: num1 += 1 elif distance[i][1] == 2: num2 += 1 if num1 >= num2: print("该测试点归为第1类") elif num1 < num2: print("该测试点归为第2类") else: print("something2 wrong has happened") amount = 12 wtnum = [5,7] bgnums = np.zeros((amount,3)) bgnums[0] = [17, 7, 1] bgnums[1] = [4, 12, 2] bgnums[2] = [5, 17, 2] bgnums[3] = [8, 11, 1] bgnums[4] = [13, 9, 1] bgnums[5] = [16, 9, 2] bgnums[6] = [14, 7, 1] bgnums[7] = [10, 15, 2] bgnums[8] = [15, 18, 2] bgnums[9] = [11, 17, 2] bgnums[10] = [13, 19, 1] bgnums[11] = [12, 10, 2] colmap = {1:'r',2:'g'} k = 3 belong_which( bgnums, amount, wtnum, k )
import numpy as np import matplotlib.pyplot as plt def belong_which( bgnums, amount, wtnum, k ): distance = np.zeros((amount,2)) for i in range(amount): distance[i][0] = (bgnums[i][0] - wtnum[0])**2 + (bgnums[i][1] - wtnum[1])**2 distance[i][1] = bgnums[i][2] for i in range(amount): for j in range(amount - 1): if distance[j][0] > distance[j+1][0]: distance[j][0],distance[j+1][0] = distance[j+1][0],distance[j][0] distance[j][1],distance[j+1][1] = distance[j+1][1],distance[j][1] num1 = 0 num2 = 0 num3 = 0 for i in range(k): if distance[i][1] == 1: num1 += 1 elif distance[i][1] == 2: num2 += 1 elif distance[i][1] == 3: num3 += 1 if num1 >= num2: if num1 >= num3: #print("该测试点归为第1类") return 1 else: if num1 < num3: #print("该测试点归为第3类") return 3 else: print("something wrong has happened") else: if num2 >= num3: #print("该测试点归为第2类") return 2 else: if num2 < num3: #print("该测试点归为第3类") return 3 else: print("something2 wrong has happened") amount = 26 wtnum = [12,16] bgnums = np.zeros((amount,3)) bgnums[0] = [84, 91, 1] bgnums[1] = [55, 82, 2] bgnums[2] = [81, 70, 1] bgnums[3] = [70, 80, 1] bgnums[4] = [75, 75, 1] bgnums[5] = [75, 74, 2] bgnums[6] = [85, 99, 1] bgnums[7] = [56, 28, 3] bgnums[8] = [96, 100, 1] bgnums[9] = [31, 33, 3] bgnums[10] = [62, 35, 3] bgnums[11] = [96, 76, 1] bgnums[12] = [39, 55, 3] bgnums[13] = [60, 64, 2] bgnums[14] = [75, 75, 1] bgnums[15] = [74, 73, 2] bgnums[16] = [43, 47, 3] bgnums[17] = [56, 71, 2] bgnums[18] = [90, 89, 1] bgnums[19] = [45, 67, 2] bgnums[20] = [77, 88, 1] bgnums[21] = [65, 55, 2] bgnums[22] = [55, 45, 3] bgnums[23] = [78, 85, 1] bgnums[24] = [44, 32, 3] bgnums[25] = [77, 98, 1] colmap = {1:'r',2:'g',3:'b'} k_num = 5 test_num = 6 amount_t = amount - test_num wtnum_t = np.zeros((test_num,2)) bgnums_t = np.zeros((amount_t,3)) error = np.zeros(test_num) error_rate = np.zeros(k_num) for i in range(test_num, amount): plt.scatter(bgnums[i][0], bgnums[i][1], color = colmap[bgnums[i][2]]) for i in range(test_num): plt.scatter(bgnums[i][0], bgnums[i][1], s = 10, color = 'k', alpha = 0.5) plt.xlim(0,100) plt.ylim(0,100) plt.show() for i in range(test_num): wtnum_t[i][0] = bgnums[i][0] wtnum_t[i][1] = bgnums[i][1] for i in range(amount_t): bgnums_t[i] = bgnums[i + test_num] start_num = 1 for j in range(start_num, start_num + k_num): print("\n\n现在的k值为", j) for i in range(test_num): print("\n现在的待归类村民成绩为", wtnum_t[i][0], wtnum_t[i][1]) m = belong_which(bgnums_t, amount_t, wtnum_t[i], j) print("根据knn算法,该村民属于", m, ",实际上该村民属于", bgnums[i][2],) if m == bgnums[i][2]: error[i] = 0 else: error[i] = 1 print("k值为", j, "时,error的列表为", error) print("--------------------------------------") error_rate[j - start_num] = sum(error)/test_num print("error_rate is", error_rate) print("错误率最低的k为", error_rate.tolist().index(min(error_rate))+start_num)
现在的k值为 1 现在的待归类村民成绩为 84.0 91.0 根据knn算法,该村民属于 1 ,实际上该村民属于 1.0 现在的待归类村民成绩为 55.0 82.0 根据knn算法,该村民属于 2 ,实际上该村民属于 2.0 现在的待归类村民成绩为 81.0 70.0 根据knn算法,该村民属于 2 ,实际上该村民属于 1.0 现在的待归类村民成绩为 70.0 80.0 根据knn算法,该村民属于 1 ,实际上该村民属于 1.0 现在的待归类村民成绩为 75.0 75.0 根据knn算法,该村民属于 1 ,实际上该村民属于 1.0 现在的待归类村民成绩为 75.0 74.0 根据knn算法,该村民属于 1 ,实际上该村民属于 2.0 k值为 1 时,error的列表为 [0. 0. 1. 0. 0. 1.] -------------------------------------- 现在的k值为 2 现在的待归类村民成绩为 84.0 91.0 根据knn算法,该村民属于 1 ,实际上该村民属于 1.0 现在的待归类村民成绩为 55.0 82.0 根据knn算法,该村民属于 2 ,实际上该村民属于 2.0 现在的待归类村民成绩为 81.0 70.0 根据knn算法,该村民属于 1 ,实际上该村民属于 1.0 现在的待归类村民成绩为 70.0 80.0 根据knn算法,该村民属于 1 ,实际上该村民属于 1.0 现在的待归类村民成绩为 75.0 75.0 根据knn算法,该村民属于 1 ,实际上该村民属于 1.0 现在的待归类村民成绩为 75.0 74.0 根据knn算法,该村民属于 1 ,实际上该村民属于 2.0 k值为 2 时,error的列表为 [0. 0. 0. 0. 0. 1.] -------------------------------------- 现在的k值为 3 现在的待归类村民成绩为 84.0 91.0 根据knn算法,该村民属于 1 ,实际上该村民属于 1.0 现在的待归类村民成绩为 55.0 82.0 根据knn算法,该村民属于 2 ,实际上该村民属于 2.0 现在的待归类村民成绩为 81.0 70.0 根据knn算法,该村民属于 1 ,实际上该村民属于 1.0 现在的待归类村民成绩为 70.0 80.0 根据knn算法,该村民属于 1 ,实际上该村民属于 1.0 现在的待归类村民成绩为 75.0 75.0 根据knn算法,该村民属于 1 ,实际上该村民属于 1.0 现在的待归类村民成绩为 75.0 74.0 根据knn算法,该村民属于 1 ,实际上该村民属于 2.0 k值为 3 时,error的列表为 [0. 0. 0. 0. 0. 1.] -------------------------------------- 现在的k值为 4 现在的待归类村民成绩为 84.0 91.0 根据knn算法,该村民属于 1 ,实际上该村民属于 1.0 现在的待归类村民成绩为 55.0 82.0 根据knn算法,该村民属于 2 ,实际上该村民属于 2.0 现在的待归类村民成绩为 81.0 70.0 根据knn算法,该村民属于 1 ,实际上该村民属于 1.0 现在的待归类村民成绩为 70.0 80.0 根据knn算法,该村民属于 1 ,实际上该村民属于 1.0 现在的待归类村民成绩为 75.0 75.0 根据knn算法,该村民属于 1 ,实际上该村民属于 1.0 现在的待归类村民成绩为 75.0 74.0 根据knn算法,该村民属于 1 ,实际上该村民属于 2.0 k值为 4 时,error的列表为 [0. 0. 0. 0. 0. 1.] -------------------------------------- 现在的k值为 5 现在的待归类村民成绩为 84.0 91.0 根据knn算法,该村民属于 1 ,实际上该村民属于 1.0 现在的待归类村民成绩为 55.0 82.0 根据knn算法,该村民属于 2 ,实际上该村民属于 2.0 现在的待归类村民成绩为 81.0 70.0 根据knn算法,该村民属于 1 ,实际上该村民属于 1.0 现在的待归类村民成绩为 70.0 80.0 根据knn算法,该村民属于 1 ,实际上该村民属于 1.0 现在的待归类村民成绩为 75.0 75.0 根据knn算法,该村民属于 1 ,实际上该村民属于 1.0 现在的待归类村民成绩为 75.0 74.0 根据knn算法,该村民属于 1 ,实际上该村民属于 2.0 k值为 5 时,error的列表为 [0. 0. 0. 0. 0. 1.] -------------------------------------- error_rate is [0.33333333 0.16666667 0.16666667 0.16666667 0.16666667] 错误率最低的k为 2