【机器学习实战】k-近邻算法2.2约会网站预测函数
《机器学习实战》学习
书中使用 Python 2 进行代码演示,我这里将其转换为 Python 3,并添加了一些注释。建议学会使用断点调试,调试代码时会方便很多。
下面的代码是书中2.2节使用k-近邻算法改进约会网站的配对效果的完整测试代码:
import operator

from numpy import array, tile, zeros

# Default path of the dating data set used by the demo/test helpers below.
DATING_FILE = "./datingTestSet2.txt"


def createDataSet():
    """Return a tiny toy data set: a 4x2 feature array and its 4 labels."""
    group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels


def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` with the k-nearest-neighbours rule.

    :param inX: input vector to classify (list or 1-D array with as many
        entries as ``dataSet`` has columns)
    :param dataSet: training samples, one sample per row
    :param labels: label of each training sample, indexed like the rows
    :param k: number of nearest neighbours that vote; a value larger than
        the number of samples is clamped to that number
    :return: the label with the most votes among the k nearest neighbours;
        on a tie the label whose first voter is nearest wins
    :raises ValueError: if ``k`` < 1 or ``dataSet`` has no rows
    """
    dataSetSize = dataSet.shape[0]
    if k < 1:
        raise ValueError("k must be at least 1")
    if dataSetSize == 0:
        raise ValueError("dataSet must contain at least one sample")
    # Euclidean distance from inX to every training sample.
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    distances = (diffMat ** 2).sum(axis=1) ** 0.5
    # argsort returns the indices ordered from smallest to largest distance.
    sortedDistIndicies = distances.argsort()
    classCount = {}
    # Slicing clamps k to the number of samples instead of raising IndexError.
    for i in sortedDistIndicies[:k]:
        voteIlabel = labels[i]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # Python 3: dict.items() replaces Python 2's dict.iteritems().
    # sorted() is stable, so among equal counts the first-inserted label wins.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def file2matrix(filename):
    """Parse a tab-separated text file into a feature matrix and label list.

    On each non-blank line the first three tab-separated fields are numeric
    features and the last field is an integer label.

    :param filename: path of the data file
    :return: (returnMat, classLabelVector) - an (n, 3) float array and a
        list of n ints, where n is the number of non-blank lines
    """
    # ``with`` closes the file even if parsing fails.
    with open(filename, encoding="utf-8") as fr:
        # Blank lines (e.g. a trailing empty line) are skipped; they cannot
        # be parsed and would otherwise leave a bogus row in the matrix.
        arrayOLines = [line.strip() for line in fr if line.strip()]
    numberOfLines = len(arrayOLines)
    print(numberOfLines)
    returnMat = zeros((numberOfLines, 3))  # the 3 features
    classLabelVector = []  # the labels
    for index, line in enumerate(arrayOLines):
        listFromLine = line.split('\t')
        returnMat[index, :] = [float(x) for x in listFromLine[0:3]]
        classLabelVector.append(int(listFromLine[-1]))
    return returnMat, classLabelVector


def autoNum(dataSet):
    """Min-max normalise every column of ``dataSet`` into [0, 1].

    newValue = (oldValue - min) / (max - min), computed per column
    (``min(0)`` / ``max(0)`` give the per-column extremes; see
    https://blog.csdn.net/qq_41800366/article/details/86313052).

    :param dataSet: 2-D array, one sample per row
    :return: (normDataSet, ranges, minVals); new samples are normalised the
        same way with ``(x - minVals) / ranges``. A constant column has its
        range reported as 1, so it normalises to 0 instead of NaN.
    """
    minVals = dataSet.min(0)  # per-column minimum
    maxVals = dataSet.max(0)  # per-column maximum
    ranges = maxVals - minVals
    ranges[ranges == 0] = 1  # avoid 0/0 for constant columns
    m = dataSet.shape[0]
    # tile repeats the 1-D row m times so it lines up with every sample row.
    normDataSet = (dataSet - tile(minVals, (m, 1))) / tile(ranges, (m, 1))
    return normDataSet, ranges, minVals


# The book names this function autoNorm; keep both names available.
autoNorm = autoNum


def _loadNormalizedDating(filename=DATING_FILE):
    """Load the dating file and return (normMat, ranges, minVals, labels)."""
    datingDataMat, datingLabels = file2matrix(filename)
    normMat, ranges, minVals = autoNum(datingDataMat)
    return normMat, ranges, minVals, datingLabels


def datingClassTest(normMat=None, datingLabels=None, hoRatio=0.10, k=3):
    """Hold-out test: classify the first ``hoRatio`` of rows against the rest.

    ``normMat`` / ``datingLabels`` used to be read from globals that exist
    only when the file runs as a script. They can now be passed in; when
    either is omitted, both are loaded from DATING_FILE.

    :param normMat: normalised feature matrix, one sample per row
    :param datingLabels: label of each row
    :param hoRatio: fraction of rows (taken from the top) used as test set
    :param k: number of neighbours passed to classify0
    :return: the error rate on the test rows (also printed)
    :raises ValueError: if ``hoRatio`` selects no test rows
    """
    if normMat is None or datingLabels is None:
        normMat, _, _, datingLabels = _loadNormalizedDating()
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)  # number of test rows
    if numTestVecs == 0:
        raise ValueError("hoRatio selects no test vectors")
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :],
                                     datingLabels[numTestVecs:m], k)
        print("the classifierResult came back with: %d,the real answer is: %d"
              % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]:
            errorCount += 1.0
    errorRate = errorCount / float(numTestVecs)
    print("the total error rate is: %f" % errorRate)
    return errorRate


def classifyPerson(normMat=None, ranges=None, minVals=None, datingLabels=None):
    """Prompt for three feature values and print the predicted liking.

    The normalised training data may be passed in. If any piece is missing,
    everything is loaded from DATING_FILE instead of relying on script
    globals.
    """
    if normMat is None or ranges is None or minVals is None or datingLabels is None:
        normMat, ranges, minVals, datingLabels = _loadNormalizedDating()
    resultList = ['not at all', 'in small doses', 'in large doses']
    # Python 3's input() always returns a str (raw_input was removed), hence float().
    percentTats = float(input("percentage of time spent playing video games?"))
    ffMiles = float(input("frequent flier miles earned per year?"))
    iceCream = float(input("liters of ice cream consumed per year?"))
    # NOTE(review): this order presumably matches the data file's columns - verify.
    inArr = array([ffMiles, percentTats, iceCream])
    classifierResult = classify0((inArr - minVals) / ranges, normMat, datingLabels, 3)
    # Labels are assumed to be 1..3, while resultList is 0-indexed.
    print("You will probably like this person:", resultList[classifierResult - 1])


if __name__ == "__main__":
    # Toy example:
    # group, labels = createDataSet()
    # print(classify0([0, 0], group, labels, 3))
    import matplotlib.pyplot as plt  # only the demo plot needs matplotlib

    datingDataMat, datingLabels = file2matrix(DATING_FILE)  # load the data
    fig = plt.figure()
    ax = fig.add_subplot(111)
    # Column 0 vs column 1; marker size and colour both scale with the label.
    ax.scatter(datingDataMat[:, 0], datingDataMat[:, 1],
               15.0 * array(datingLabels), 15.0 * array(datingLabels))
    plt.show()
    normMat, ranges, minVals = autoNum(datingDataMat)  # normalise the inputs
    # datingClassTest(normMat, datingLabels)
    classifyPerson(normMat, ranges, minVals, datingLabels)  # classify
运行结果:
percentage of time spent playing video games?10
frequent flier miles earned per year?4000
liters of ice cream consumed per year?1
You will probably like this person: in small doses
注:Python 2 中字典的 iteritems() 方法在 Python 3 中需改为 items()。