Python实现KNN算法及手写数字识别
1.Python实现KNN算法
输入:inX:待与现有数据集进行比较的向量(1xN)
dataSet:由m个已知向量组成的数据集(NxM)
labels:数据集标签(1xM向量)
k:用于比较的邻居数(应为奇数)
输出:出现次数最多的类别标签(分类问题)
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 16 23:01:54 2017

@author: SimonsZhao

kNN: k Nearest Neighbors

Input:      inX: vector to compare to existing dataset (1xN)
            dataSet: size m data set of known vectors (NxM)
            labels: data set labels (1xM vector)
            k: number of neighbors to use for comparison (should be an odd number)

Output:     the most popular class label
"""
from numpy import *
import operator
from os import listdir


def classify0(inX, dataSet, labels, k):
    """Return the majority label among the k rows of dataSet nearest to inX.

    Distance is Euclidean. Among labels with equal vote counts the winner is
    whichever sorts first in a stable sort of the vote dictionary; on
    interpreters whose dicts keep insertion order that is the label that was
    seen first, i.e. the one belonging to the nearer neighbour.

    Args:
        inX: the query vector; a flat sequence or a single-row 2-D array.
        dataSet: 2-D array-like with one known sample per row.
        labels: sequence indexable by row number, giving each row's label.
        k: how many nearest rows take part in the vote.

    Returns:
        The label that received the most votes.

    Raises:
        IndexError: if k is smaller than 1 or larger than the number of rows.
    """
    dataSet = asarray(dataSet, dtype=float)
    # Broadcasting subtracts inX from every row without materialising a tiled
    # copy of it; the sign of the difference is irrelevant once squared.
    diffMat = dataSet - asarray(inX, dtype=float)
    distances = (diffMat ** 2).sum(axis=1) ** 0.5
    sortedDistIndicies = distances.argsort()
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # .items() exists on both Python 2 and 3, unlike .iteritems().
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createDataSet():
    """Return a tiny hard-coded sample set: a 4x2 array and its four labels."""
    group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels


def file2matrix(filename):
    """Load a tab-separated data file into a feature matrix and a label list.

    The first three tab-separated fields of each record are read as floats and
    the last field as an integer label. Lines that are empty after stripping
    whitespace are skipped.

    Args:
        filename: path of the text file to read.

    Returns:
        A tuple ``(returnMat, classLabelVector)`` where ``returnMat`` is an
        array with one row of three floats per record and
        ``classLabelVector`` is the list of integer labels in file order.

    Raises:
        ValueError: if a field cannot be parsed as a number.
    """
    # Read the file once and let the context manager close it.
    with open(filename) as fr:
        stripped = [line.strip() for line in fr]
    # Blank lines (for example a trailing newline) carry no record.
    records = [line for line in stripped if line]

    returnMat = zeros((len(records), 3))
    classLabelVector = []
    for index, line in enumerate(records):
        listFromLine = line.split('\t')
        returnMat[index, :] = [float(value) for value in listFromLine[0:3]]
        classLabelVector.append(int(listFromLine[-1]))
    return returnMat, classLabelVector


def autoNorm(dataSet):
    """Min-max normalise every column of dataSet.

    Each column is shifted by its minimum and divided by its range. A column
    whose values are all equal has a range of zero; it is normalised to all
    zeros rather than dividing by zero.

    Args:
        dataSet: 2-D numpy array with one sample per row.

    Returns:
        A tuple ``(normDataSet, ranges, minVals)``: the normalised array, the
        per-column ``max - min`` (zero entries are reported as zero), and the
        per-column minimum.
    """
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    # Divide by 1 where the range is 0 so constant columns map to 0, not NaN.
    divisors = where(ranges == 0, 1, ranges)
    normDataSet = (dataSet - minVals) / divisors    # element wise divide
    return normDataSet, ranges, minVals

def datingClassTest():
    """Measure the kNN error rate on the dating data set and print it.

    Loads 'datingTestSet2.txt', normalises it, classifies the first
    ``hoRatio`` fraction of the rows against the remaining rows with k=3, and
    prints each prediction, the error rate and the raw error count.
    """
    hoRatio = 0.50      # hold out the first 50% of the rows as test vectors
    datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')     # load data set from file
    normMat = autoNorm(datingDataMat)[0]
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)
    errorCount = 0.0
    for i in range(numTestVecs):
        # Rows from numTestVecs onward serve as the known samples.
        classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :],
                                     datingLabels[numTestVecs:m], 3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]:
            errorCount += 1.0
    print("the total error rate is: %f" % (errorCount / float(numTestVecs)))
    print(errorCount)


def img2vector(filename):
    """Read a 32x32 grid of digit characters into a 1x1024 row vector.

    Args:
        filename: path of a text file whose first 32 lines each start with 32
            characters that ``int()`` accepts.

    Returns:
        A numpy array of shape (1, 1024), filled row by row from the file.
    """
    returnVect = zeros((1, 1024))
    # The context manager closes the file even if a line is malformed.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect


def _classNumFromFileName(fileNameStr):
    """Return the integer before the first '_' of a file name, extension dropped."""
    fileStr = fileNameStr.split('.')[0]     # take off .txt
    return int(fileStr.split('_')[0])


def handwritingClassTest():
    """Classify every file in 'testDigits' against 'trainingDigits' and print the errors.

    Each training file is converted with img2vector and labelled with the
    number that precedes the first underscore of its file name. Every test
    file is then classified with k=3; each prediction, the number of errors
    and the error rate are printed.
    """
    hwLabels = []
    trainingFileList = listdir('trainingDigits')    # load the training set
    m = len(trainingFileList)
    trainingMat = zeros((m, 1024))
    for i, fileNameStr in enumerate(trainingFileList):
        hwLabels.append(_classNumFromFileName(fileNameStr))
        trainingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr)

    testFileList = listdir('testDigits')            # iterate through the test set
    errorCount = 0.0
    mTest = len(testFileList)
    for fileNameStr in testFileList:
        classNumStr = _classNumFromFileName(fileNameStr)
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0
    print("\nthe total number of errors is: %d" % errorCount)
    print("\nthe total error rate is: %f" % (errorCount / float(mTest)))
2.数据集(测试集和训练集)
3.KNN测试结果
博客地址:http://www.cnblogs.com/jackchen-Net/