
DataScience && DataMining && BigData

Implementing the KNN Algorithm in Python, with Handwritten Digit Recognition

1. Implementing the KNN algorithm in Python

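Input:  inX: the vector to compare against the existing data set (1xN)
        dataSet: the data set of M known vectors (MxN, one vector per row)
        labels: the data set labels (1xM vector)
        k: the number of neighbors used for comparison (should be an odd number)
Output: the most popular class label (a classification problem)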
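
For reference, the classify0 listing that follows measures closeness with the Euclidean distance and then takes a majority vote among the k nearest known vectors. A compact way to write this, where x stands for inX, x^{(i)} for row i of dataSet and y_i for its label (the symbols \mathcal{N}_k and \hat{y} are introduced here only for illustration):

```latex
d\left(x, x^{(i)}\right) = \sqrt{\sum_{j=1}^{N} \left(x_j - x^{(i)}_j\right)^2},
\qquad
\hat{y} = \arg\max_{c} \sum_{i \in \mathcal{N}_k(x)} \mathbf{1}\left[\, y_i = c \,\right]
```

Here \mathcal{N}_k(x) is the set of indices of the k rows with the smallest distance to x.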

# -*- coding: utf-8 -*-
"""
Created on Sun Apr 16 23:01:54 2017

@author: SimonsZhao

kNN: k Nearest Neighbors

Input:      inX: vector to compare to existing dataset (1xN)
            dataSet: data set of M known vectors (MxN, one vector per row)
            labels: data set labels (1xM vector)
            k: number of neighbors to use for comparison (should be an odd number)

Output:     the most popular class label
"""
from numpy import array, shape, tile, zeros
import operator
from os import listdir

def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    # Euclidean distance from inX to every row of dataSet
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances**0.5
    sortedDistIndicies = distances.argsort()    # row indices, nearest first
    classCount = {}
    for i in range(k):                          # count the labels of the k nearest rows
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # items() replaces the Python 2-only iteritems()
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createDataSet():
    group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels

def file2matrix(filename):
    with open(filename) as fr:                  # read the file once and close it
        lines = fr.readlines()
    numberOfLines = len(lines)                  # get the number of lines in the file
    returnMat = zeros((numberOfLines, 3))       # prepare matrix to return
    classLabelVector = []                       # prepare labels to return
    for index, line in enumerate(lines):
        listFromLine = line.strip().split('\t')
        returnMat[index, :] = [float(x) for x in listFromLine[0:3]]    # first three columns are features
        classLabelVector.append(int(listFromLine[-1]))                 # last column is the class label
    return returnMat, classLabelVector

def autoNorm(dataSet):
    # min-max normalization: rescale every column to the range [0, 1]
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    normDataSet = zeros(shape(dataSet))
    m = dataSet.shape[0]
    normDataSet = dataSet - tile(minVals, (m, 1))
    normDataSet = normDataSet/tile(ranges, (m, 1))   # element wise divide
    return normDataSet, ranges, minVals

def datingClassTest():
    hoRatio = 0.50      # hold out 50% of the rows for testing
    datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')       # load data set from file
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m*hoRatio)
    errorCount = 0.0
    for i in range(numTestVecs):
        # the first numTestVecs rows are test vectors, the remaining rows are the known data
        classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]:
            errorCount += 1.0
    print("the total error rate is: %f" % (errorCount/float(numTestVecs)))
    print(errorCount)

def img2vector(filename):
    # flatten a 32x32 text image into a 1x1024 vector
    returnVect = zeros((1, 1024))
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32*i+j] = int(lineStr[j])
    return returnVect

def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir('trainingDigits')           # load the training set
    m = len(trainingFileList)
    trainingMat = zeros((m, 1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]     # take off .txt
        classNumStr = int(fileStr.split('_')[0])    # the digit is the part before '_'
        hwLabels.append(classNumStr)
        trainingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr)
    testFileList = listdir('testDigits')        # iterate through the test set
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]     # take off .txt
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0
    print("\nthe total number of errors is: %d" % errorCount)
    print("\nthe total error rate is: %f" % (errorCount/float(mTest)))
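
As a quick check of the classifier on the toy data from createDataSet(), here is a minimal self-contained sketch. It is not the listing itself: knn_classify is a hypothetical name, and it swaps tile() for NumPy broadcasting and the dictionary count for collections.Counter, assuming Python 3 and NumPy are available.

```python
from collections import Counter

import numpy as np


def knn_classify(query, data, labels, k):
    """Return the majority label among the k rows of `data` closest to `query`."""
    # Broadcasting subtracts `query` from every row, so no tile() call is needed.
    distances = np.sqrt(((data - query) ** 2).sum(axis=1))
    nearest = distances.argsort()[:k]          # indices of the k smallest distances
    votes = Counter(labels[i] for i in nearest)
    return votes.most_common(1)[0][0]


# Toy data copied from createDataSet() in the listing.
group = np.array([[1.0, 1.1], [1.0, 1.0], [0.0, 0.0], [0.0, 0.1]])
labels = ['A', 'A', 'B', 'B']

print(knn_classify(np.array([0.0, 0.0]), group, labels, 3))  # B
print(knn_classify(np.array([1.0, 1.2]), group, labels, 3))  # A
```

With k = 3, the point [0, 0] has two 'B' neighbors and one 'A' neighbor, so the vote goes to 'B'. An odd k avoids a tie when there are only two classes.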

2. Data sets (test set and training set)
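
Judging from the code in section 1, the digit data lives in two folders, trainingDigits and testDigits. Each file holds one image as 32 lines of 32 digit characters, and the class is the part of the file name before the underscore. The sketch below shows that layout on a made-up file. The file name 1_0.txt, the vertical-bar image and the helper names are assumptions for illustration, and the use of only 0 and 1 as pixel values is also an assumption (the code only requires digit characters).

```python
import os
import tempfile

import numpy as np


def image_file_to_vector(path, size=32):
    """Flatten a size x size grid of digit characters into a 1 x size*size vector."""
    vector = np.zeros((1, size * size))
    with open(path) as handle:
        for row in range(size):
            line = handle.readline()
            for col in range(size):
                vector[0, size * row + col] = int(line[col])
    return vector


def label_from_filename(filename):
    """'7_12.txt' -> 7: the class is the part of the name before the underscore."""
    return int(filename.split('.')[0].split('_')[0])


# Synthetic stand-in for one digit file: a vertical bar in columns 14..17.
rows = ['0' * 14 + '1' * 4 + '0' * 14 for _ in range(32)]
with tempfile.TemporaryDirectory() as folder:
    name = '1_0.txt'  # hypothetical name following the <digit>_<index>.txt pattern
    path = os.path.join(folder, name)
    with open(path, 'w') as handle:
        handle.write('\n'.join(rows) + '\n')
    vector = image_file_to_vector(path)
    label = label_from_filename(name)

print(vector.shape, int(vector.sum()), label)  # (1, 1024) 128 1
```

Each image becomes one row of 1024 values, so a folder of m files turns into an m x 1024 matrix that can be passed to the classifier as the known data.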

 

3. KNN test results
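
The test functions in section 1 report an error rate: the share of held-out vectors whose predicted label differs from the real one. The sketch below reproduces that procedure (min-max normalization, then testing the first rows against the rest) on a few made-up rows. All names and data here are assumptions for illustration, and the number it prints comes from the made-up rows; it says nothing about the dating or digit data sets.

```python
from collections import Counter

import numpy as np


def min_max_normalize(data):
    """Rescale every column to [0, 1], the same idea as autoNorm()."""
    min_vals = data.min(axis=0)
    ranges = data.max(axis=0) - min_vals
    return (data - min_vals) / ranges


def knn_classify(query, data, labels, k):
    """Majority label among the k rows of `data` nearest to `query`."""
    distances = np.sqrt(((data - query) ** 2).sum(axis=1))
    votes = Counter(labels[i] for i in distances.argsort()[:k])
    return votes.most_common(1)[0][0]


def hold_out_error_rate(data, labels, hold_out_ratio, k):
    """Test on the first hold_out_ratio share of rows, use the rest as known data."""
    norm = min_max_normalize(data)
    num_test = int(len(norm) * hold_out_ratio)
    errors = 0
    for i in range(num_test):
        predicted = knn_classify(norm[i], norm[num_test:], labels[num_test:], k)
        if predicted != labels[i]:
            errors += 1
    return errors / num_test


# Made-up rows (not from datingTestSet2.txt); the two columns differ in scale on purpose.
data = np.array([
    [1000.0, 1.0], [5000.0, 9.0], [1200.0, 1.2], [5200.0, 8.5],
    [900.0, 0.8], [4800.0, 9.5], [1100.0, 0.9], [5100.0, 8.8],
])
labels = [1, 2, 1, 2, 1, 2, 1, 2]

print(hold_out_error_rate(data, labels, hold_out_ratio=0.25, k=3))  # 0.0
```

Without the normalization step, the first column would dominate the distance simply because its values are about a thousand times larger.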

 

posted @ 2017-05-03 09:33  CJZhaoSimons