《机器学习实战》菜鸟学习笔记(三)kNN手写识别系统

目的:利用kNN识别数字0-9

材料:32*32的数字方阵(保存形式是文本文件)

#-*-coding:utf-8-*-
from numpy import *

def img2vector(filename):
    #生成一个1*1024的array(zeros是numpy的函数,至于array与list区别这里就不多介绍了)
    returnVect = zeros((1,1024))
    #使用open函数打开一个文本文件
    fr = open(filename)
    #循环读取文件内容
    for i in range(32):
        #读取一行,返回字符串
        linestr = fr.readline()
        for j in range(32):
            #读取字符串0 或者 1
            returnVect[0,32*i+j] = int(linestr[j])
    #返回这个array
    return returnVect

这个程序很清晰,不做什么解释了。再看一下分类器是怎么实现的:

 

#定义测试代码
def handwringClassTest():
    #定义一个list,用于记录分类
    hwLabels = []
    #前面的Python os.listdir 可以列出 dir 里面的所有文件和目录,但不包括子目录中的内容。
    #os.walk 可以遍历下面的所有目录,包括子目录。
    trainingFileList = listdir('trainingDigits')
    #求出文件的长度
    m = len(trainningFileList)
    #生成m*1024的array,每个文件分配1024个0
    trainingMat = zeros((m,1024))
    #循环,对每一个file
    for i in range(m):
        #当前文件
        fileNameStr = trainingFileList[i]
        #理解这段代码要知道文件的命名方式,这里是这样命名的9_45.txt,9表示分类,45表示第45个。
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        #调用img2vector,将原文件写入trainingMat
        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
    #找到testDigits中的文件
    testFileList = listdir('testDigits')
    #计算误差
    errorCount = 0.0
    #多少个文件
    mTest = len(testFileList)
    #遍历test文件
    for i inrange(mTest):
        #test文件
        fileNameStr = testFileList[i]
        #分类
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        #转换成1*1024
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        # 调用knn分类
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        #输出
        print "the classifier came back with:%d, the real anwer is : %d" % (classifierResult, classNumStr)
        #计算误差
        if (classifierResult != classNumStr): errorCount += 1.0
    print "\n the total numbe of error is: %d" % errorCount 
    print "\nthe total error rate is: %f" % (errorCount/flaot(mTest))

总结

kNN是一种最简单最有效的算法。但是kNN必须保留所有的数据集,如果训练数据集的很大,必须使用大量的存储空间,此外,需要对每一个数据计算距离,非常耗时。另外,它无法给出任何数据的基础结构信息(目前我还不能理解这句话,待更新。。。)。

posted @ 2014-10-03 23:29  程序员阿力  阅读(3705)  评论(0编辑  收藏  举报