Machine Learning in Action, Chapter 2: the k-Nearest Neighbors Algorithm (Notes)

1. How kNN (k-Nearest Neighbors) Works

In one sentence: find the k training samples that are closest to the test point, then take the class that appears most often among those k samples as the class of the test point.

There is a sample data set (the training set) in which every sample carries a label, i.e., we know which class each sample in the set belongs to.

When an unlabeled data point arrives, each of its features is compared against the corresponding features of the samples in the training set, and the class labels of the most similar samples (the nearest neighbors) are extracted.

Only the k most similar samples are kept (this is where the k in k-nearest neighbors comes from; k is usually an integer no larger than 20).

The class that occurs most often among those k samples is assigned to the new data point.
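The same decision rule is available off the shelf. As a quick sanity check of the idea above, here is a minimal sketch using scikit-learn's KNeighborsClassifier on the book's four-point toy data; this assumes scikit-learn is installed and is not part of the book's code:

from sklearn.neighbors import KNeighborsClassifier
import numpy as np

# the same four-point toy set that createDataSet() builds in section 3
group = np.array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
labels = ['A', 'A', 'B', 'B']

clf = KNeighborsClassifier(n_neighbors=3)   # k = 3, majority vote
clf.fit(group, labels)
print(clf.predict([[0, 0.2]]))              # two of the three nearest neighbors are 'B', so this prints ['B']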

2. Pros and Cons

Pros:

1. The theory is mature and the idea is simple, easy to understand and to implement.

2. Can be used for classification (including non-linear problems) as well as regression.

3. Computation time and memory are linear in the size of the training set (the training-time complexity is lower than that of algorithms such as support vector machines, roughly O(n)).

4. Suited to single-label multi-class problems as well as multi-label classification.

5. Works reasonably well on sample sets whose class regions intersect or overlap heavily.

 

Cons:

1. Computationally expensive at prediction time, especially when the number of features is large.

2. Performs poorly on imbalanced data sets (where the number of samples per class differs greatly); distance-weighted voting can help, as shown in the sketch after this list.

3. The choice of k has a large effect on the result (a small k is sensitive to noise), so the best k has to be estimated.

4. Limited interpretability.
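A minimal sketch of the distance-weighted voting mentioned in point 2 above; this is my own variation for illustration, not the book's code, and the name classify_weighted is made up:

import numpy as np

def classify_weighted(inX, dataSet, labels, k):
    # plain kNN, but each of the k nearest neighbors votes with weight 1/distance,
    # so close neighbors of a rare class are not drowned out by a large majority class
    diff = np.asarray(dataSet) - np.asarray(inX)     # broadcasting instead of tile()
    distances = np.sqrt((diff ** 2).sum(axis=1))
    votes = {}
    for idx in distances.argsort()[:k]:
        weight = 1.0 / (distances[idx] + 1e-6)       # small epsilon avoids division by zero
        votes[labels[idx]] = votes.get(labels[idx], 0.0) + weight
    return max(votes, key=votes.get)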

3. Code

The code comes from Machine Learning in Action, with a few small comments and some test code added.

#encoding:utf-8
'''
Created on Sep 16, 2010
kNN: k Nearest Neighbors

Input:      inX: vector to compare to existing dataset (1xN)
            dataSet: size m data set of known vectors (NxM)
            labels: data set labels (1xM vector)
            k: number of neighbors to use for comparison (should be an odd number)

Output:     the most popular class label

@author: pbharrin
'''
from numpy import *
import operator
from os import listdir

# inX: the input vector to be classified
# dataSet: the training sample set
# labels: the label vector
def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]  # number of rows, i.e. how many training samples there are
    diffMat = tile(inX, (dataSetSize,1)) - dataSet  # tile stacks inX dataSetSize times along the rows (columns unchanged); diffMat is the difference between the input and every training sample
    sqDiffMat = diffMat**2  # square each element
    sqDistances = sqDiffMat.sum(axis=1)  # sum along axis 1, giving the squared distance to each training sample
    distances = sqDistances**0.5  # take the square root to get the distances
    sortedDistIndices = distances.argsort()  # indices that sort the distances in ascending order
    classCount = {}  # vote count per class label
    for i in range(k):
        voteIlabel = labels[sortedDistIndices[i]]  # walk through the sorted indices and fetch the label of each of the k nearest samples
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1  # current count for that label (default 0 if unseen) plus one
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createDataSet():
    group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
    labels = ['A','A','B','B']
    return group, labels

def file2matrix(filename):
    fr = open(filename)
    numberOfLines = len(fr.readlines())         #get the number of lines in the file
    returnMat = zeros((numberOfLines,3))        #prepare matrix to return
    print(type(returnMat))
    classLabelVector = []                       #prepare labels return
    fr = open(filename)
    index = 0
    for line in fr.readlines():
        line = line.strip()
        listFromLine = line.split('\t')
        returnMat[index,:] = listFromLine[0:3]  # assign a whole row of the numpy array at once; the strings are cast to floats
        classLabelVector.append(int(listFromLine[-1]))
        index += 1
    return returnMat,classLabelVector

def autoNorm(dataSet):
    minVals = dataSet.min(0)  # axis 0 works column-wise, so this keeps the minimum of every feature
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    normDataSet = zeros(shape(dataSet))
    m = dataSet.shape[0]
    normDataSet = dataSet - tile(minVals, (m,1))
    normDataSet = normDataSet/tile(ranges, (m,1))   #element wise divide
    return normDataSet, ranges, minVals

def datingClassTest():
    hoRatio = 0.10      #hold out 10%
    datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')       #load data set from file
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m*hoRatio)
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
        if (classifierResult != datingLabels[i]): errorCount += 1.0
    print("the total error rate is: %f" % (errorCount/float(numTestVecs)))
    print(errorCount)

def img2vector(filename):
    returnVect = zeros((1,1024))  # flatten a 32x32 text image into a 1x1024 vector
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0,32*i+j] = int(lineStr[j])
    return returnVect

def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir('trainingDigits')           #load the training set
    m = len(trainingFileList)
    trainingMat = zeros((m,1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]     #take off .txt
        classNumStr = int(fileStr.split('_')[0])    #the digit is encoded in the file name, e.g. 9_45.txt
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
    testFileList = listdir('testDigits')        #iterate through the test set
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]     #take off .txt
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
        if (classifierResult != classNumStr): errorCount += 1.0
    print("\nthe total number of errors is: %d" % errorCount)
    print("\nthe total error rate is: %f" % (errorCount/float(mTest)))


if __name__ == '__main__':
    trainingMat = zeros((10,4))
    print(trainingMat)
    m = len(trainingMat)
    for i in range(m):
        trainingMat[i,:] = [1,2,3,4]
        print('-----------------')
        print(trainingMat)

    a = array([[3,2,3,4],[3,4,5,6]])
    a = a**2
    print(a)
    print('-----------')
    sqDistances = a.sum(axis=0)  # sum down the columns (axis 0)
    print(sqDistances)
    sqDistances2 = a.sum(axis=1)  # sum across each row (axis 1)
    print(sqDistances2)
    distances = sqDistances**0.5
    print(distances)
    print(distances.argsort())
    print(distances[0])

    x = array([1,4,3,-1,6,9])
    print(x.argsort()[-1])
    y = x.argsort()
    print(x)
    print(y)
    print(x[y[0]])

    classCount = {0:3, 5:2, 4:6}
    print(classCount)
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    print(sortedClassCount)
    print(sortedClassCount[0][0])

    handwritingClassTest()  # handwritten digit classification test
    datingClassTest()  # dating partner classification test
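If the listing above is saved as kNN.py (with the book's datingTestSet2.txt, trainingDigits/ and testDigits/ in the working directory), it can also be exercised from an interactive session; a minimal sketch:

import kNN

group, labels = kNN.createDataSet()
print(kNN.classify0([0, 0], group, labels, 3))   # the point (0, 0) should come back as 'B'

kNN.datingClassTest()        # hold-out test on the dating data
kNN.handwritingClassTest()   # hold-out test on the handwritten digits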

 

4. Scatter Plot of Frequent Flyer Miles vs. Percentage of Time Spent Playing Video Games

The code is as follows:

'''
Created on Oct 6, 2010

@author: Peter
'''
from numpy import *
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle


n = 1000 #number of points to create
xcord1 = []; ycord1 = []
xcord2 = []; ycord2 = []
xcord3 = []; ycord3 = []
markers =[]
colors =[]
fw = open('testSet.txt','w')
for i in range(n):
    [r0,r1] = random.standard_normal(2) # draw two samples from a standard normal distribution
    myClass = random.uniform(0,1) # uniform random value in [0, 1) that decides which of the groups below this point falls into
    if (myClass <= 0.16):
        fFlyer = random.uniform(22000, 60000) # made-up frequent flyer miles
        tats = 3 + 1.6*r1 # made-up formula for video game time (this one is quite fun)
        markers.append(20) # marker size
        colors.append(2.1) # marker color
        classLabel = 1 #'didntLike'
        xcord1.append(fFlyer); ycord1.append(tats)
    elif ((myClass > 0.16) and (myClass <= 0.33)):
        fFlyer = 6000*r0 + 70000
        tats = 10 + 3*r1 + 2*r0
        markers.append(20)
        colors.append(1.1)
        classLabel = 1 #'didntLike'
        if (tats < 0): tats =0 # clamp negative values to 0
        if (fFlyer < 0): fFlyer =0 # clamp negative values to 0
        xcord1.append(fFlyer); ycord1.append(tats)
    elif ((myClass > 0.33) and (myClass <= 0.66)):
        fFlyer = 5000*r0 + 10000
        tats = 3 + 2.8*r1
        markers.append(30)
        colors.append(1.1)
        classLabel = 2 #'smallDoses'
        if (tats < 0): tats =0
        if (fFlyer < 0): fFlyer =0
        xcord2.append(fFlyer); ycord2.append(tats)
    else:
        fFlyer = 10000*r0 + 35000 # made-up rule: lots of miles plus lots of gaming means "liked in large doses" (lots of miles, so they know how to have fun)
        tats = 10 + 2.0*r1 # made-up rule: plenty of gaming time, so they can play together
        markers.append(50)
        colors.append(0.1)
        classLabel = 3 #'largeDoses'
        if (tats < 0): tats =0
        if (fFlyer < 0): fFlyer =0
        xcord3.append(fFlyer); ycord3.append(tats)    

fw.close()
fig = plt.figure()
ax = fig.add_subplot(111)
#ax.scatter(xcord,ycord, c=colors, s=markers)
type1 = ax.scatter(xcord1, ycord1, s=20, c='red')
type2 = ax.scatter(xcord2, ycord2, s=30, c='green')
type3 = ax.scatter(xcord3, ycord3, s=50, c='blue')
ax.legend([type1, type2, type3], ["Did Not Like", "Liked in Small Doses", "Liked in Large Doses"], loc=2)
ax.axis([-5000,100000,-2,25]) # set the x-axis and y-axis ranges
plt.xlabel('Frequent Flyer Miles Earned Per Year') # x-axis label
plt.ylabel('Percentage of Time Spent Playing Video Games') # y-axis label
plt.show()
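For reference, a shorter variant that scatters the real datingTestSet2.txt data instead of synthetic points, colored by label; a rough sketch, assuming the data file and the kNN.py module from section 3 are available and the columns are ordered miles, game-time percentage, ice cream, label:

import numpy as np
import matplotlib.pyplot as plt
import kNN

datingDataMat, datingLabels = kNN.file2matrix('datingTestSet2.txt')
fig, ax = plt.subplots()
# column 0: frequent flyer miles per year, column 1: percentage of time playing video games
ax.scatter(datingDataMat[:, 0], datingDataMat[:, 1],
           s=15.0 * np.array(datingLabels), c=np.array(datingLabels))
plt.xlabel('Frequent Flyer Miles Earned Per Year')
plt.ylabel('Percentage of Time Spent Playing Video Games')
plt.show()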

 

 

Reference: "Data Mining: pros and cons of various classification algorithms"
