K--NN(近邻)模型算法
运行平台: Windows
Python版本: Python3.x
1.1 k-近邻法简介
k近邻法(k-nearest neighbor, k-NN)是1967年由Cover T和Hart P提出的一种基本分类与回归方法。它的工作原理是:存在一个样本数据集合,也称作为训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一个数据与所属分类的对应关系。输入没有标签的新数据后,将新的数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前k个最相似的数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数。最后,选择k个最相似数据中出现次数最多的分类,作为新数据的分类。
举个简单的例子,我们可以使用k-近邻算法分类一个电影是爱情片还是动作片。
电影名称 | 打斗镜头 | 接吻镜头 | 电影类型 |
电影1 | 1 | 101 | 爱情片 |
电影2 | 5 | 89 | 爱情片 |
电影3 | 108 | 5 | 动作片 |
电影4 | 115 | 8 | 动作片 |
这个数据集有两个特征,即打斗镜头数和接吻镜头数。除此之外,我们也知道每个电影的所属类型,即分类标签。用肉眼粗略地观察,接吻镜头多的,是爱情片。打斗镜头多的,是动作片。以我们多年的看片经验,这个分类还算合理。如果现在给我一部电影,你告诉我这个电影打斗镜头数和接吻镜头数。不告诉我这个电影类型,我可以根据你给我的信息进行判断,这个电影是属于爱情片还是动作片。而k-近邻算法也可以像我们人一样做到这一点,不同的地方在于,我们的经验更”牛逼”,而k-邻近算法是靠已有的数据。比如,你告诉我这个电影打斗镜头数为2,接吻镜头数为102,我的经验会告诉你这个是爱情片,k-近邻算法也会告诉你这个是爱情片。你又告诉我另一个电影打斗镜头数为49,接吻镜头数为51,我”邪恶”的经验可能会告诉你,这有可能是个”爱情动作片”,画面太美,我不敢想象。 (如果说,你不知道”爱情动作片”是什么?请评论留言与我联系,我需要你这样像我一样纯洁的朋友。) 但是k-近邻算法不会告诉你这些,因为在它的眼里,电影类型只有爱情片和动作片,它会提取样本集中特征最相似数据(最邻近)的分类标签,得到的结果可能是爱情片,也可能是动作片,但绝不会是”爱情动作片”。当然,这些取决于数据集的大小以及最近邻的判断标准等因素。
那么k-邻近算法是什么呢?k-近邻算法步骤如下:
- 计算已知类别数据集中的点与当前点之间的距离;
- 按照距离递增次序排序;
- 选取与当前点距离最小的k个点;
- 确定前k个点所在类别的出现频率;
- 返回前k个点所出现频率最高的类别作为当前点的预测分类。
比如,现在我这个k值取3,那么在电影例子中,按距离依次排序的三个点分别是动作片(108,5)、动作片(115,8)、爱情片(5,89)。在这三个点中,动作片出现的频率为三分之二,爱情片出现的频率为三分之一,所以该红色圆点标记的电影为动作片。这个判别过程就是k-近邻算法。
1.2 创建kNN_test01.py文件,K-NN的k值为3,编写代码如下:
1 # -*- coding: UTF-8 -*- 2 import numpy as np 3 import operator 4 5 def createDataSet(): 6 #四组二维特征 7 group = np.array([[1,101],[5,89],[108,5],[115,8]]) 8 #四组特征的标签 9 labels = ['爱情片','爱情片','动作片','动作片'] 10 return group, labels 11 12 """ 13 函数说明:kNN算法,分类器 14 15 Parameters: 16 inX - 用于分类的数据(测试集) 17 dataSet - 用于训练的数据(训练集) 18 labes - 分类标签 19 k - kNN算法参数,选择距离最小的k个点 20 Returns: 21 sortedClassCount[0][0] - 分类结果 22 23 Modify: 24 2017-07-13 25 """ 26 def classify0(inX, dataSet, labels, k): 27 #numpy函数shape[0]返回dataSet的行数 28 dataSetSize = dataSet.shape[0] 29 #在列向量方向上重复inX共1次(横向),行向量方向上重复inX共dataSetSize次(纵向) 30 diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet 31 #二维特征相减后平方 32 sqDiffMat = diffMat**2 33 #sum()所有元素相加,sum(0)列相加,sum(1)行相加 34 sqDistances = sqDiffMat.sum(axis=1) 35 #开方,计算出距离 36 distances = sqDistances**0.5 37 #返回distances中元素从小到大排序后的索引值 38 sortedDistIndices = distances.argsort() 39 #定一个记录类别次数的字典 40 classCount = {} 41 for i in range(k): 42 #取出前k个元素的类别 43 voteIlabel = labels[sortedDistIndices[i]] 44 #dict.get(key,default=None),字典的get()方法,返回指定键的值,如果值不在字典中返回默认值。 45 #计算类别次数 46 classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 47 #python3中用items()替换python2中的iteritems() 48 #key=operator.itemgetter(1)根据字典的值进行排序 49 #key=operator.itemgetter(0)根据字典的键进行排序 50 #reverse降序排序字典 51 sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True) 52 #返回次数最多的类别,即所要分类的类别 53 return sortedClassCount[0][0] 54 55 if __name__ == '__main__': 56 #创建数据集 57 group, labels = createDataSet() 58 #测试集 59 test = [101,20] 60 #kNN分类 61 test_class = classify0(test, group, labels, 3) 62 #打印分类结果 63 print(test_class)
测试结果:
print(test_class)为动作片。
2.1 k-近邻算法实战之约会网站配对效果判定
上一小结学习了简单的k-近邻算法的实现方法,但是这并不是完整的k-近邻算法流程,k-近邻算法的一般流程:
- 收集数据:可以使用爬虫进行数据的收集,也可以使用第三方提供的免费或收费的数据。一般来讲,数据放在txt文本文件中,按照一定的格式进行存储,便于解析及处理。
- 准备数据:使用Python解析、预处理数据。
- 分析数据:可以使用很多方法对数据进行分析,例如使用Matplotlib将数据可视化。
- 测试算法:计算错误率。
- 使用算法:错误率在可接受范围内,就可以运行k-近邻算法进行分类。
已经了解了k-近邻算法的一般流程,下面开始进入实战内容。
2.2 实战背景
海伦女士一直使用在线约会网站寻找适合自己的约会对象。尽管约会网站会推荐不同的任选,但她并不是喜欢每一个人。经过一番总结,她发现自己交往过的人可以进行如下分类:
- 不喜欢的人
- 魅力一般的人
- 极具魅力的人
海伦收集约会数据已经有了一段时间,她把这些数据存放在文本文件datingTestSet.txt中,每个样本数据占据一行,总共有1000行。
海伦收集的样本数据主要包含以下3种特征:
- 每年获得的飞行常客里程数
- 玩视频游戏所消耗时间百分比
- 每周消费的冰淇淋公升数
didntLike:不喜欢 largeDoses:极具魅力 smallDOses:魅力一般
数据下载地址:https://github.com/Jack-Cherish/Machine-Learning/tree/master/kNN/2.%E6%B5%B7%E4%BC%A6%E7%BA%A6%E4%BC%9A
2.3 分析数据:数据可视化
在kNN_test02.py文件中编写名为showdatas的函数,用来将数据可视化
1 import matplotlib.lines as mlines 2 import matplotlib.pyplot as plt 3 import numpy as np 4 5 6 """ 7 函数说明:打开并解析文件,对数据进行分类:1代表不喜欢,2代表魅力一般,3代表极具魅力 8 9 Parameters: 10 filename - 文件名 11 Returns: 12 returnMat - 特征矩阵 13 classLabelVector - 分类Label向量 14 15 Modify: 16 2017-12-28 17 """ 18 def file2matrix(filename): 19 #打开文件 20 fr = open(filename) 21 #读取文件所有内容 22 arrayOLines = fr.readlines() 23 #得到文件行数 24 numberOfLines = len(arrayOLines) 25 #返回的NumPy矩阵,解析完成的数据:numberOfLines行,3列 26 returnMat = np.zeros((numberOfLines,3)) 27 #返回的分类标签向量 28 classLabelVector = [] 29 #行的索引值 30 index = 0 31 for line in arrayOLines: 32 #s.strip(rm),当rm空时,默认删除空白符(包括'\n','\r','\t',' ') 33 line = line.strip() 34 #使用s.split(str="",num=string,cout(str))将字符串根据'\t'分隔符进行切片。 35 listFromLine = line.split('\t') 36 #将数据前三列提取出来,存放到returnMat的NumPy矩阵中,也就是特征矩阵 37 returnMat[index,:] = listFromLine[0:3] 38 #根据文本中标记的喜欢的程度进行分类,1代表不喜欢,2代表魅力一般,3代表极具魅力 39 if listFromLine[-1] == 'didntLike': 40 classLabelVector.append(1) 41 elif listFromLine[-1] == 'smallDoses': 42 classLabelVector.append(2) 43 elif listFromLine[-1] == 'largeDoses': 44 classLabelVector.append(3) 45 index += 1 46 #print(returnMat) 47 return returnMat, classLabelVector 48 49 def showdatas(returnMat, classLabelVector): 50 #设置汉字格式 51 plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 52 plt.rcParams['font.serif'] = ['SimHei'] 53 plt.rcParams['axes.unicode_minus'] = False # 用控制中文乱码 54 #将fig画布分隔成1行1列,不共享x轴和y轴,fig画布的大小为(13,8) 55 #当nrow=2,nclos=2时,代表fig画布被分为四个区域,axs[0][0]表示第一行第一个区域 56 fig, axs = plt.subplots(nrows=2, ncols=2,sharex=False, sharey=False, figsize=(13,8)) 57 numberOfLabels = len(classLabelVector) 58 LabelsColors = [] 59 for i in classLabelVector: 60 if i == 1: 61 LabelsColors.append('black') 62 if i == 2: 63 LabelsColors.append('orange') 64 if i == 3: 65 LabelsColors.append('red') 66 #画出散点图,以returnMat矩阵的第一(飞行常客例程)、第二列(玩游戏)数据画散点数据,散点大小为15,透明度为0.5 67 axs[0][0].scatter(x=returnMat[:,0], y=returnMat[:,1], color=LabelsColors,s=30, alpha=.5) 68 #设置标题,x轴label,y轴label 69 axs0_title_text = axs[0][0].set_title('每年获得的飞行常客里程数与玩视频游戏所消耗时间占比') 70 axs0_xlabel_text = axs[0][0].set_xlabel('每年获得的飞行常客里程数') 71 axs0_ylabel_text = axs[0][0].set_ylabel('玩视频游戏所消耗时间占') 72 plt.setp(axs0_title_text, size=12, weight='bold', color='red') 73 plt.setp(axs0_xlabel_text, size=9, weight='bold', color='black') 74 plt.setp(axs0_ylabel_text, size=9, weight='bold', color='black') 75 76 #画出散点图,以returnMat矩阵的第一(飞行常客例程)、第三列(冰激凌)数据画散点数据,散点大小为15,透明度为0.5 77 axs[0][1].scatter(x=returnMat[:,0], y=returnMat[:,2], color=LabelsColors,s=30, alpha=.5) 78 #设置标题,x轴label,y轴label 79 axs1_title_text = axs[0][1].set_title('每年获得的飞行常客里程数与每周消费的冰激淋公升数') 80 axs1_xlabel_text = axs[0][1].set_xlabel('每年获得的飞行常客里程数') 81 axs1_ylabel_text = axs[0][1].set_ylabel('每周消费的冰激淋公升数') 82 plt.setp(axs1_title_text, size=12, weight='bold', color='red') 83 plt.setp(axs1_xlabel_text, size=9, weight='bold', color='black') 84 plt.setp(axs1_ylabel_text, size=9, weight='bold', color='black') 85 86 #画出散点图,以returnMat矩阵的第二(玩游戏)、第三列(冰激凌)数据画散点数据,散点大小为15,透明度为0.5 87 axs[1][0].scatter(x=returnMat[:,1], y=returnMat[:,2], color=LabelsColors,s=30, alpha=.5) 88 #设置标题,x轴label,y轴label 89 axs2_title_text = axs[1][0].set_title('玩视频游戏所消耗时间占比与每周消费的冰激淋公升数') 90 axs2_xlabel_text = axs[1][0].set_xlabel('玩视频游戏所消耗时间占比') 91 axs2_ylabel_text = axs[1][0].set_ylabel('每周消费的冰激淋公升数') 92 plt.setp(axs2_title_text, size=12, weight='bold', color='red') 93 plt.setp(axs2_xlabel_text, size=9, weight='bold', color='black') 94 plt.setp(axs2_ylabel_text, size=9, weight='bold', color='black') 95 #设置图例 96 didntLike = mlines.Line2D([], [], color='black', marker='.', 97 markersize=6, label='didntLike') 98 smallDoses = mlines.Line2D([], [], color='orange', marker='.', 99 markersize=6, label='smallDoses') 100 largeDoses = mlines.Line2D([], [], color='red', marker='.', 101 markersize=6, label='largeDoses') 102 #添加图例 103 axs[0][0].legend(handles=[didntLike,smallDoses,largeDoses], loc='upper right') 104 axs[0][1].legend(handles=[didntLike,smallDoses,largeDoses], loc='upper right') 105 axs[1][0].legend(handles=[didntLike,smallDoses,largeDoses], loc='upper right') 106 #显示图片 107 plt.show() 108 109 """ 110 函数说明:main函数 111 112 Parameters: 113 无 114 Returns: 115 无 116 117 Modify: 118 2017-12-28 119 """ 120 if __name__ == '__main__': 121 #打开的文件名 122 filename = "datingTestSet.txt" 123 #打开并处理数据 124 returnMat, classLabelVector = file2matrix(filename) 125 showdatas(returnMat, classLabelVector)
运行结果:
通过数据可以很直观的发现数据的规律,比如以玩游戏所消耗时间占比与每年获得的飞行常客里程数,只考虑这二维的特征信息,给我的感觉就是海伦喜欢有生活质量的男人。为什么这么说呢?每年获得的飞行常客里程数表明,海伦喜欢能享受飞行常客奖励计划的男人,但是不能经常坐飞机,疲于奔波,满世界飞。同时,这个男人也要玩视频游戏,并且占一定时间比例。能到处飞,又能经常玩游戏的男人是什么样的男人?很显然,有生活质量,并且生活悠闲的人。我的分析,仅仅是通过可视化的数据总结的个人看法。我想,每个人的感受应该也是不尽相同。
2.4 使用算法:构建完整可用系统
我们可以给海伦一个小段程序,通过该程序海伦会在约会网站上找到某个人并输入他的信息。程序会给出她对男方喜欢程度的预测值。
代码如下:
1 #!/usr/bin/env python 2 # _*_ coding:utf-8 _*_ 3 4 import numpy as np 5 import operator 6 7 """ 8 函数说明:kNN算法,分类器 9 10 Parameters: 11 inX - 用于分类的数据(测试集) 12 dataSet - 用于训练的数据(训练集) 13 labes - 分类标签 14 k - kNN算法参数,选择距离最小的k个点 15 Returns: 16 sortedClassCount[0][0] - 分类结果 17 18 Modify: 19 2017-12-28 20 """ 21 def classify0(inX, dataSet, labels, k): 22 #numpy函数shape[0]返回dataSet的行数 23 dataSetSize = dataSet.shape[0] 24 #在列向量方向上重复inX共1次(横向),行向量方向上重复inX共dataSetSize次(纵向) 25 diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet 26 #二维特征相减后平方 27 sqDiffMat = diffMat**2 28 #sum()所有元素相加,sum(0)列相加,sum(1)行相加 29 sqDistances = sqDiffMat.sum(axis=1) 30 #开方,计算出距离 31 distances = sqDistances**0.5 32 #返回distances中元素从小到大排序后的索引值 33 sortedDistIndices = distances.argsort() 34 #定一个记录类别次数的字典 35 classCount = {} 36 for i in range(k): 37 #取出前k个元素的类别 38 voteIlabel = labels[sortedDistIndices[i]] 39 #dict.get(key,default=None),字典的get()方法,返回指定键的值,如果值不在字典中返回默认值。 40 #计算类别次数 41 classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 42 #python3中用items()替换python2中的iteritems() 43 #key=operator.itemgetter(1)根据字典的值进行排序 44 #key=operator.itemgetter(0)根据字典的键进行排序 45 #reverse降序排序字典 46 sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True) 47 #返回次数最多的类别,即所要分类的类别 48 return sortedClassCount[0][0] 49 50 51 """ 52 函数说明:打开并解析文件,对数据进行分类:1代表不喜欢,2代表魅力一般,3代表极具魅力 53 54 Parameters: 55 filename - 文件名 56 Returns: 57 returnMat - 特征矩阵 58 classLabelVector - 分类Label向量 59 60 Modify: 61 2017-12-28 62 """ 63 def file2matrix(filename): 64 #打开文件 65 fr = open(filename) 66 #读取文件所有内容 67 arrayOLines = fr.readlines() 68 #得到文件行数 69 numberOfLines = len(arrayOLines) 70 #返回的NumPy矩阵,解析完成的数据:numberOfLines行,3列 71 returnMat = np.zeros((numberOfLines,3)) 72 #返回的分类标签向量 73 classLabelVector = [] 74 #行的索引值 75 index = 0 76 for line in arrayOLines: 77 #s.strip(rm),当rm空时,默认删除空白符(包括'\n','\r','\t',' ') 78 line = line.strip() 79 #使用s.split(str="",num=string,cout(str))将字符串根据'\t'分隔符进行切片。 80 listFromLine = line.split('\t') 81 #将数据前三列提取出来,存放到returnMat的NumPy矩阵中,也就是特征矩阵 82 returnMat[index,:] = listFromLine[0:3] 83 #根据文本中标记的喜欢的程度进行分类,1代表不喜欢,2代表魅力一般,3代表极具魅力 84 if listFromLine[-1] == 'didntLike': 85 classLabelVector.append(1) 86 elif listFromLine[-1] == 'smallDoses': 87 classLabelVector.append(2) 88 elif listFromLine[-1] == 'largeDoses': 89 classLabelVector.append(3) 90 index += 1 91 return returnMat, classLabelVector 92 93 """ 94 函数说明:对数据进行归一化 95 96 Parameters: 97 dataSet - 特征矩阵 98 Returns: 99 normDataSet - 归一化后的特征矩阵 100 ranges - 数据范围 101 minVals - 数据最小值 102 103 Modify: 104 2017-12-28 105 """ 106 def autoNorm(dataSet): 107 #获得数据的最小值 108 minVals = dataSet.min(0) 109 maxVals = dataSet.max(0) 110 #最大值和最小值的范围 111 ranges = maxVals - minVals 112 #shape(dataSet)返回dataSet的矩阵行列数 113 normDataSet = np.zeros(np.shape(dataSet)) 114 #返回dataSet的行数 115 m = dataSet.shape[0] 116 #原始值减去最小值 117 normDataSet = dataSet - np.tile(minVals, (m, 1)) 118 #除以最大和最小值的差,得到归一化数据 119 normDataSet = normDataSet / np.tile(ranges, (m, 1)) 120 #返回归一化数据结果,数据范围,最小值 121 return normDataSet, ranges, minVals 122 123 """ 124 函数说明:通过输入一个人的三维特征,进行分类输出 125 126 Parameters: 127 无 128 Returns: 129 无 130 131 Modify: 132 2017-12-28 133 """ 134 def classifyPerson(): 135 #输出结果 136 resultList = ['讨厌','有些喜欢','非常喜欢'] 137 #三维特征用户输入 138 precentTats = float(input("每年获得的飞行常客里程数:")) 139 ffMiles = float(input("玩视频游戏所耗时间百分比:")) 140 iceCream = float(input("每周消费的冰激淋公升数:")) 141 #打开的文件名 142 filename = "datingTestSet.txt" 143 #打开并处理数据 144 datingDataMat, datingLabels = file2matrix(filename) 145 #训练集归一化 146 normMat, ranges, minVals = autoNorm(datingDataMat) 147 #生成NumPy数组,测试集 148 inArr = np.array([precentTats, ffMiles, iceCream]) 149 150 #测试集归一化 151 norminArr = (inArr - minVals) / ranges 152 153 #返回分类结果 154 classifierResult = classify0(norminArr, normMat, datingLabels, 3) 155 156 #打印结果 157 print("你可能%s这个人" % (resultList[classifierResult-1])) 158 159 """ 160 函数说明:main函数 161 162 Parameters: 163 无 164 Returns: 165 无 166 167 Modify: 168 2017-12-28 169 """ 170 if __name__ == '__main__': 171 classifyPerson()
并输入数据(44000,12,,0.5),预测结果是”你可能有些喜欢这个人”,也就是这个人魅力一般。一共有三个档次:讨厌、有些喜欢、非常喜欢,对应着不喜欢的人、魅力一般的人、极具魅力的人
3.1 kNN算法的优缺点
优点
- 简单好用,容易理解,精度高,理论成熟,既可以用来做分类也可以用来做回归;
- 可用于数值型数据和离散型数据;
- 训练时间复杂度为O(n);无数据输入假定;
- 对异常值不敏感。
缺点:
- 计算复杂性高;空间复杂性高;
- 样本不平衡问题(即有些类别的样本数量很多,而其它样本的数量很少);
- 一般数值很大的时候不用这个,计算量太大。但是单个样本又不能太少,否则容易发生误分。
- 最大的缺点是无法给出数据的内在含义。
博客转至:http://blog.csdn.net/c406495762/article/details/75172850
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 一个奇形怪状的面试题:Bean中的CHM要不要加volatile?
· [.NET]调用本地 Deepseek 模型
· .NET 10 首个预览版发布,跨平台开发与性能全面提升
· 全程使用 AI 从 0 到 1 写了个小工具
· 快收藏!一个技巧从此不再搞混缓存穿透和缓存击穿
· AI 插件第二弹,更强更好用
· Blazor Hybrid适配到HarmonyOS系统