机器学习实战SKlearn之KNN手写数字识别
1 # -*- coding: UTF-8 -*- 2 import numpy as np 3 import operator 4 from os import listdir 5 from sklearn.neighbors import KNeighborsClassifier as kNN 6 7 """ 8 函数说明:将32x32的二进制图像转换为1x1024向量。 9 10 Parameters: 11 filename - 文件名 12 Returns: 13 returnVect - 返回二进制图像的1x1024向量 14 """ 15 def img2vector(filename): 16 #创建1x1024零向量,np.zeros((a,b)),a代表第一层括号(最外层)看向元素个数,b代表第二层括号内向量元素个数即内层 17 #注意这里zeros(1,1024)产生的是二维数组[[0,0,....]] 18 return_vect = np.zeros((1, 1024)) # 数组的索引都是从0开始,但是size/shape中都是实际数目。 19 #打开文件 20 fr = open(filename) 21 #按行读取 22 for i in range(32): 23 #读一行数据 24 lineStr = fr.readline() 25 #每一行的前32个元素依次添加到returnVect中 26 for j in range(32): #range(32):0,1,2,....31 27 return_vect[0,32*i+j] = int(lineStr[j]) #0:向量第一层内的第一个分向量 28 #返回转换后的1x1024向量 29 return return_vect 30 31 """ 32 函数说明:手写数字分类测试 33 """ 34 def handwriting_ClassTest(): 35 #测试集的Labels 36 hwLabels = [] 37 #返回trainingDigits目录下的文件名 38 trainingFileList = listdir('C:/Users/Administrator/Desktop/data/Ch02/digits/trainingDigits') 39 #返回文件夹下文件的个数 40 m = len(trainingFileList) 41 #初始化训练的Mat矩阵,测试集 42 trainingMat = np.zeros((m, 1024)) 43 #从文件名中解析出训练集的类别 44 for i in range(m): 45 #获得文件的名字 46 fileName_Str = trainingFileList[i] 47 #获得分类的数字 48 classNumber = int(fileName_Str.split('_')[0]) 49 #将获得的类别添加到hwLabels中 50 hwLabels.append(classNumber) 51 #将每一个文件的1x1024数据存储到trainingMat矩阵中 52 trainingMat[i,:] = img2vector('C:/Users/Administrator/Desktop/data/Ch02/digits/trainingDigits/%s' % (fileName_Str)) 53 #构建kNN分类器 54 neigh = kNN(n_neighbors = 3, algorithm = 'auto') 55 #拟合模型, trainingMat为测试矩阵,hwLabels为对应的标签 56 neigh.fit(trainingMat, hwLabels) 57 #返回testDigits目录下的文件列表 58 testFileList = listdir('C:/Users/Administrator/Desktop/data/Ch02/digits/testDigits') 59 #错误检测计数 60 errorCount = 0.0 61 #测试数据的数量 62 mTest = len(testFileList) 63 #从文件中解析出测试集的类别并进行分类测试 64 for i in range(mTest): 65 #获得文件的名字 66 fileName_Str = testFileList[i] 67 #获得分类的数字 68 classNumber = int(fileName_Str.split('_')[0]) 69 #获得测试集的1x1024向量,用于训练 70 vector_UnderTest = img2vector('C:/Users/Administrator/Desktop/data/Ch02/digits/testDigits/%s' % (fileName_Str)) 71 #获得预测结果 72 # classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) 73 classifierResult = neigh.predict(vector_UnderTest) 74 print("分类返回结果为%d\t真实结果为%d" % (classifierResult, classNumber)) 75 if(classifierResult != classNumber): 76 errorCount += 1.0 77 print("总共错了%d个数据\n错误率为%f%%" % (errorCount, errorCount/mTest * 100)) 78 79 80 """ 81 函数说明:main函数 82 """ 83 if __name__ == '__main__': 84 handwriting_ClassTest()