手写字识别
Python实现手写识别
#coding=utf-8 from numpy import * import operator from os import listdir #k近邻算法 def classify0(inX, dataSet, labels, k): dataSetSize = dataSet.shape[0] #返回dataset这个array的行数 diffMat = tile(inX, (dataSetSize,1)) - dataSet #tile(A,reps)将A补成reps规格 sqDiffMat = diffMat**2 #平方 sqDistances = sqDiffMat.sum(axis=1) #默认的axis=0 就是普通的相加 而当加入axis=1以后就是将一个矩阵的每一行向量相加 distances = sqDistances**0.5 #开方 sortedDistIndicies = distances.argsort() #argsort其实是返回array排序后的下标(或索引) classCount={} #新建一个字典 for i in range(k): voteIlabel = labels[sortedDistIndicies[i]] #依次查询cclassCount中是否有该key,有则将取出value再+1,没有则返回添加该key并置value为0,再+1 classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 #统计得到各个标签的个数 #按classCount字典的第2个元素(即类别出现的次数)从大到小排序,即获得得票最高的标签 sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) return sortedClassCount[0][0] #从文本文件解析数据 def file2matrix(filename): fr = open(filename) numberOfLines = len(fr.readlines()) #get the number of lines in the file returnMat = zeros((numberOfLines,3)) #prepare matrix to return classLabelVector = [] #prepare labels return fr = open(filename) index = 0 for line in fr.readlines(): line = line.strip() #删除文本行line后的回车符 listFromLine = line.split('\t') #使用’\t’分割字符串str,返回一个列表 returnMat[index,:] = listFromLine[0:3] classLabelVector.append(int(listFromLine[-1])) index += 1 return returnMat,classLabelVector #归一化特征值 def autoNorm(dataSet): minVals = dataSet.min(0) maxVals = dataSet.max(0) ranges = maxVals - minVals normDataSet = zeros(shape(dataSet)) #shape数组或矩阵的各个维的大小 m = dataSet.shape[0] #返回dataset这个array的行数 normDataSet = dataSet - tile(minVals, (m,1)) normDataSet = normDataSet/tile(ranges, (m,1)) #element wise divide return normDataSet, ranges, minVals def datingClassTest(): hoRatio = 0.10 #hold out 10% datingDataMat,datingLabels = file2matrix('datingTestSet2.txt') #load data setfrom file normMat, ranges, minVals = autoNorm(datingDataMat) m = normMat.shape[0] numTestVecs = int(m*hoRatio) errorCount = 0.0 for i in range(numTestVecs): classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3) print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]) if (classifierResult != datingLabels[i]): errorCount += 1.0 print "the total error rate is: %f" % (errorCount/float(numTestVecs)) print errorCount #将图像转换为向量函数 def img2vector(filename): returnVect = zeros((1,1024)) fr = open(filename) for i in range(32): lineStr = fr.readline() #读取文件对象fr的当前行,返回字符串 for j in range(32): returnVect[0,32*i+j] = int(lineStr[j]) return returnVect def handwritingClassTest(): hwLabels = [] trainingFileList = listdir('trainingDigits') #load the training set m = len(trainingFileList) trainingMat = zeros((m,1024)) for i in range(m): fileNameStr = trainingFileList[i] fileStr = fileNameStr.split('.')[0] #take off .txt classNumStr = int(fileStr.split('_')[0]) hwLabels.append(classNumStr) trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr) testFileList = listdir('testDigits') #iterate through the test set errorCount = 0.0 mTest = len(testFileList) for i in range(mTest): fileNameStr = testFileList[i] fileStr = fileNameStr.split('.')[0] #take off .txt classNumStr = int(fileStr.split('_')[0]) vectorUnderTest = img2vector('testDigits/%s' % fileNameStr) classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr) if (classifierResult != classNumStr): errorCount += 1.0 print "\nthe total number of errors is: %d" % errorCount print "\nthe total error rate is: %f" % (errorCount/float(mTest)) def main(): #datingClassTest() handwritingClassTest() main() # 函数名/属性 功能 # array() 创建一个数组 # shape 数组或矩阵的各个维的大小 # tile(A, reps) 将数组A,根据数组reps沿各个维度重复多次,构成一个新的数组。reps的数字从后往前分别对应A的第N个维度的重复次数。 # sum(arr,axis=1) 根据行列(轴),求和 # max(arr,axis=1) 根据行列(轴),求最大值 # min(arr,axis=1) 根据行列(轴),求最小值 # mean(arr,axis=1) 根据行列(轴),求平均值 # argsort() 得到矩阵中每个元素的排序序号 # dict.get(key,default) 获取字典中,一个给定的key对应的值。若key不存在,则返回默认值default。 # sorted(iterable[, key][, reverse]) 第一个参数是一个iterable,返回值是一个对iterable中元素进行排序后的列表(list)。 # open(filename) 返回一个文件对象 # fr.readlines() 读取文件对象fr中的所有行,返回数组 # fr.readline() 读取文件对象fr的当前行,返回字符串 # len(arr) 返回数组的长度 # zeros((n,m)) 创建一个n*m的矩阵,用0填充 # line.strip() 删除文本行line后的回车符 # str.spit(‘\t’) 使用’\t’分割字符串str,返回一个列表 # list[-1] 获取列表的最后一个元素 # vec.append(item) 在向量、列表vec后追加元素item # mat[index, :] 获取矩阵/数组的第index行的所有元素 # list[m:n] 获取列表索引m到n的元素的值 # plt.figure() 创建画布? # fig.add_subplot((m,n,x)) 把画布分割成m*n的区块,在第x块上绘图 # scatter() 绘制散点 # print 格式化输出 # raw_input(“prompt string”) 显示提示字符串,将用户的输入转换成string # input(“prompt string”) 会根据用户输入变换相应的类型,而且如果要输入字符和字符串的时候必须要用引号包起来 # range() range(1,5) #代表从1到5(不包含5); range(1,5,2) #代表从1到5,间隔2(不包含5); range(5) #代表从0到5(不包含5) # listdir(‘folder’) from os import listdir,获取给定文件夹下的文件名列表,不含文件路径