1 def classify0(inX, dataSet, labels, k): 2 dataSetSize = dataSet.shape[0] 3 diffMat = tile(inX, (dataSetSize,1)) - dataSet 4 sqDiffMat = diffMat**2 5 sqDistances = sqDiffMat.sum(axis=1) 6 distances = sqDistances**0.5
7 sortedDistIndicies = distances.argsort() 8 classCount={}
9 for i in range(k): 10 voteIlabel = labels[sortedDistIndicies[i]] 11 classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 12 sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) 13 return sortedClassCount[0][0]
1-6是计算欧式距离
7-8是根据距离排序,注意argsort 返回的list是索引值
9-13是计算距离最小的k个中不同label数目。
其中 tile argsort均在另外的随笔中介绍过了
classCount是一个dict ,dict.get(key,default),返回的是key为xx的个数,如果没有就返回default 值,这里设置为了0.
所以这一步是在计数。
熟悉 sort的用法。operator.itemgetter(1) 命令返回的是一个函数,key=func之后就按照func来处理每一个数据 然后再进行排序。
operator.itemgetter(1)
在导入txt文件时出现问题了。网上都说要将kNN.py 和 txt放同一个目录下,我就是啊 为什么不行!
所以以后还是老实的 C:\.....\....\datingSet.txt 吧
datingDataMat,datingLabels=kNN.file2matrix('datingTestSet2.txt') fig = plt.figure() ax = fig.add_subplot(111) ax.scatter(datingDataMat[:,1],datingDataMat[:,2],15.0*numpy.array(datingLabels),15.0*numpy.array(datingLabels)) plt.show()
scatter中有四个参数 ,前两个就是x y参数,后两个是利用他们labels的值不同 给他们赋予不同的大小和颜色。 第三个参数 s(size) 第四个 c(color)
可以简单理解为,a[1,:] 输出的是矩阵的第二行。