[3] python:使用KNN识别手写数字

数据是很重要的,机器学习实战的源代码提供了数据,这点是非常好的

将图像转化为文本,读进向量里,就可以使用前面写的分类器

from numpy import *
import operator 
from os import listdir
def classify0(inX, dataset,labels,k):
    # 计算输入数据和已有所有数据的距离
    dataSetSize=dataset.shape[0]
    diffMat=tile(inX,(dataSetSize,1))-dataset
    sqDiffMat=diffMat**2
    sqDistances=sqDiffMat.sum(axis=1) #没有axis参数表示全部相加,axis=0表示按列相加,axis=1表示按照行的方向相加
    distances=sqDistances**0.5
    
    #排序
    sortedDistIndex=distances.argsort()  #argsort将数据从小到大排列,并返回其索引值
    # 选择距离最小的k个点
    classCount={} #字典类型
    
    for i in range(k):
        votelabel=labels[sortedDistIndex[i]]
        classCount[votelabel]=classCount.get(votelabel,0)+1
    sortedClasscount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    
    return sortedClasscount[0][0] 



# 将图像转换为向量,32*32的二进制图像矩阵转换成1*1024的向量
def img2vector(filename):
    returnVector=zeros((1,1024))
    fr=open(filename)
    for i in range(32):
        lineStr=fr.readline()
        for j in range(32):
            returnVector[0,32*i+j]=int(lineStr[j])
    return returnVector


def handClassTest():
    hwLabels=[]
    trainingFileList=listdir('trainingDigits')
    m=len(trainingFileList)
    trainingMat=zeros((m,1024))
    # 从训练数据的文件的所有txt文件中读取出每条数据,记录在向量中
    for i in range(m):
        filenameStr=trainingFileList[i]
        fileStr=filenameStr.split('.')[0]
        classNumStr=int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:]=img2vector('trainingDigits/%s' %(filenameStr))
    testFileList=listdir('testDigits')
    errorCount=0.0
    mTest=len(testFileList)
    for j in range(mTest):
        filenameStr=testFileList[j]
        fileStr=filenameStr.split('.')[0]
        classNumStr=int(fileStr.split('_')[0])
        vectorUnderTest=img2vector('testDigits/%s' %(filenameStr))
        classifyResult=classify0(vectorUnderTest, trainingMat,hwLabels,3)
        print( "分类结果:%d,真实的类别:%d" %(classifyResult, classNumStr))
        if(classifyResult!=classNumStr): errorCount+=1
    print("错误率:%f" %(errorCount/float(mTest)))

理解程序没有什么问题,很多函数前面也都学习过了,这里学习一下读取一个文件夹里的所有文件名

使用listdir()需要从os模块导入

from os import listdir

os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。这个列表以字母顺序。 它不包括 '.' 和'..' 即使它在文件夹中。

只支持在 Unix, Windows 下使用。

具体学习一下KNN算法:

https://zhuanlan.zhihu.com/p/25994179

posted @ 2018-04-12 19:46  mzhourr  阅读(272)  评论(0编辑  收藏  举报