K近邻算法(k-nearest neighbor, kNN)

K近邻算法(K-nearest neighbor, KNN)

KNN是一种分类和回归方法。

  • KNN简介
  • KNN模型3要素
  • KNN优缺点
  • KNN应用
  • 参考文献

KNN简介

KNN思想

给定一个训练集 T={(x1,y1),(x2,y2),...,(xN,yN)} T = { ( x 1 , y 1 ) , ( x 2 , y 2 ) , . . . , ( x N , y N ) } ,对新输入的实例 x x ,在训练集中找到与实例 x 最近的k个实例,根据k个实例中大多数类所属的类作为实例 x x <script type="math/tex" id="MathJax-Element-4">x</script> 所属的类。

KNN算法

这里写图片描述

KNN模型3要素

K值得选择、距离度量方法选择、分类决策规则选择 

K值得选择
应用中,一般选择较小的k值,交叉验证可以选择最优的k值。
k值减小,模型变复杂,容易过拟合(原因:选择较小k值时,近似误差减小,估计误差增大)。
近似误差:即对现有训练集的训练误差,更关注“训练”。
估计误差:即对测试集的测试误差,更关注“测试”。
距离度量方法选择
欧氏距离
曼哈顿距离
切比雪夫距离 等等
分类决策规则选择
最常用的是,大多数原则:即由输入实例的k个近邻样本中大多数的类别确定输入实例的类。

KNN优缺点

优点
简单、精度高
缺点
计算时间、空间复杂度高

KNN应用

使用knn算法识别手写数字数据集,链接:https://pan.baidu.com/s/1rgiGBLTMiybCCSUnzR1lYw 密码:yse7

# -*-coding:utf-8-*-

from numpy import *
import operator
from os import listdir


def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]  # shape[0]读取矩阵第一维的长度
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet  # numpy.tile(A,B)函数重复A, B次
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances**0.5
    #print(type(distances))
    sortedDistIndicies = distances.argsort()

    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

    sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)
    return sortedClassCount[0][0]


def img2vector(filename):
    returnVect = zeros((1, 1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0, 32*i + j] = int(lineStr[j])
    return returnVect


def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir('digits/trainingDigits')           # 加载训练集
    m = len(trainingFileList)
    trainingMat = zeros((m,1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]     # 提取文件名
        classNumStr = int(fileStr.split('_')[0])  # 提取类别标签
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vector('digits/trainingDigits/%s' % fileNameStr)
    testFileList = listdir('digits/testDigits')        # 加载测试集
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('digits/testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)
        if (classifierResult != classNumStr): errorCount += 1.0
    print ("\nthe total number of errors is: %d" % errorCount)
    print ("\nthe total error rate is: %f" % (errorCount/float(mTest)))


if __name__ == '__main__':
    handwritingClassTest()

程序运行结果:
这里写图片描述

参考文献

[1]李航. 统计学习方法[M]. 清华大学出版社, 2012.
[2]Peter Harrington. 机器学习实战[M]. 人民邮电出版社, 2013.

posted @ 2022-11-30 18:30  风兮177  阅读(82)  评论(0编辑  收藏  举报