【机器学习】手写数字识别算法
1.数据准备
样本数据获取忽略,实际上就是将32*32的图片上数字格式化成一个向量,如下:
本demo所有样本数据都是基于这种格式的
训练数据:将图片数据转成1*1024的数组,作为一个训练数据。
训练数据集:https://github.com/zimuqi/machine_Learning/tree/master/ch02/trainingDigits
测试数据集:https://github.com/zimuqi/machine_Learning/tree/master/ch02/testDigits
样本的文件名格式为:真实值_xxx.txt
转换代码:
1 def img2vector(filename): 2 returnVect=zeros((1,1024)) 3 fr=open(filename) 4 for i in range(32): 5 lineStr=fr.readline() 6 for j in range(32): 7 returnVect[0,32*i+j]=int(lineStr[j]) 8 return returnVect
2.测试算法
1 def handwritingClassTest(): 2 hwLabels=[] # 训练样本的标签数组 3 traningFileList=listdir("trainingDigits") # 获取所有的训练样本目录下的文件名 4 m=len(traningFileList) 5 traningMat=zeros((m,1024)) # 初始化训练样本数列 6 7 for i in range(m): 8 fileNameStr=traningFileList[i] # 获取文件名 9 fileStr=fileNameStr.split(".")[0] 10 clasNumStr=int(fileStr.split("_")[0]) # 获取样本的实际值 放入标签数组 11 hwLabels.append(clasNumStr) 12 traningMat[i,:]=img2vector("trainingDigits/{}".format(fileNameStr)) # 将样本转化成1*1024的行放入训练样本数列 13 14 testFileList=listdir("testDigits") # 测试样本目录 15 error=0 16 mtest=len(testFileList) 17 for i in range(mtest): 18 fileNameStr=testFileList[i] 19 fileStr=fileNameStr.split(".")[0] 20 clasNumStr=int(fileStr.split("_")[0]) 21 testMat=img2vector("testDigits/{}".format(fileNameStr)) 22 res=classify(testMat,traningMat,hwLabels,3) # 使用分类器分类 23 print "came bank with:{} the real anwser is:{}".format(clasNumStr,res) 24 if clasNumStr!=res: # 对比与真实的结果 计算错误率 25 error+=1 26 27 print "total:{}".format(mtest) 28 print "error:{}".format(error) 29 print "error:{}".format(float(error/mtest))
这个案例中 算法的识别率为:98.84%
classify是分类器 上上一篇文章中有写到,具体了解可以点击这里