k近邻算法

介绍#

k近邻算法(KNN)属于监督学习的分类算法,通过测量不同特征值之间的距离进行分类,算法过程如下

  • 计算数据点与已知数据集中每个点的距离
  • 对距离从小到大进行排序
  • 选取前k个距离值
  • 确定前k个距离值所在类别的出现的概率
  • 将前k个点出现频率最高的类别作为当前数据的预测分类

主要代码如下

Copy
def classfiy(inData, dataSet, labels, k): dataSize = dataSet.shape[0] # 得到数组的行维度,即数据的个数 # 先通过tile将输入的数据扩展为与dataSet相同维度的数组,再通过距离公式计算距离 distance = (((tile(inData, (dataSize, 1)) - dataSet) ** 2).sum(axis=1)) ** 0.5 sortIndex = distance.argsort() # 返回数组值从小到大的索引值 classCount = {} for i in range(k): # 只对前k个计数 headLabel = labels[sortIndex[i]] classCount[headLabel] = classCount.get(headLabel, 0) + 1 # 统计前k个中出现标签的次数 # 对字典按照第二个值(即出现的次数)进行排序,用reverse指定从大到小排 sortCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True) return sortCount[0][0] # 返回第一个的标签

其中距离计算,通过公式,如(x1,y1)(x2,y2)两点的距离d为d=(x1x2)2+(y1y2)2

用KNN识别数字图片中的数字#

只是个玩具程序

收集数据#

每个数字准备了10张图片,分别存在digit中的以各个数字命名的文件夹下

又为每个数据准备了5张图片,以同样的规则存在digit2的各个文件夹下

准备数据#

缩放图像
采用了pillow中的resize函数,同一将图像缩放为50*50
newImg = img.resize((50, 50))
二值化图像
开始想直接通过convet('1')直接将图像二值化,但出现了很多噪音
所以通过以下程序将图像二值化。其中230为设定的阀值,多次尝试,发现230效果较好

Copy
for i in range(rows): for j in range(cols): if (imgArray[i, j] <= 230): imgArray[i, j] = 0 else: imgArray[i, j] = 255

转化为一维向量
将读取的处理后的图片的像素值转化为一维向量

测试#

通过读取测试集中的数据,进行预测,和实际的类别比对,查看正确率

程序#

Copy
from PIL import Image from numpy import * import os import operator #缩放为相同大小 def toSame(img): newImg = img.resize((50, 50)) return newImg #二值化处理 def toBinarry(img): imgArray = array(img) rows, cols = imgArray.shape for i in range(rows): for j in range(cols): if (imgArray[i, j] <= 230): imgArray[i, j] = 0 else: imgArray[i, j] = 255 return imgArray #读取每个文件夹下的每张图片 def readImage(filePath): dataList = [] labels = [] for i in range(10): imagePath = filePath + '/' + str(i) files = os.listdir(imagePath) for j in files: labels.append(j.split('_')[0])#因为每张图片采用‘数字_第几张的命名方式’,所以通过下横线分割,取得前面的作为图片的分类标签 img = Image.open(imagePath + '/' + j).convert('L')#先灰度化处理 imgArray = toBinarry(toSame(img)) dataList.append(imgArray.ravel())#转变为一维后加入列表 dataSet = array(dataList) return dataSet, labels #分类算法 def classfiy(inData, dataSet, labels, k): dataSize = dataSet.shape[0] # 得到数组的行维度,即数据的个数 # 先通过tile将输入的数据扩展为与dataSet相同维度的数组,再通过距离公式计算距离 distance = (((tile(inData, (dataSize, 1)) - dataSet) ** 2).sum(axis=1)) ** 0.5 sortIndex = distance.argsort() # 返回数组值从小到大的索引值 classCount = {} for i in range(k): # 只对前k个计数 headLabel = labels[sortIndex[i]] classCount[headLabel] = classCount.get(headLabel, 0) + 1 # 统计前k个中出现标签的次数 # 对字典按照第二个值(即出现的次数)进行排序,用reverse指定从大到小排 sortCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True) return sortCount[0][0] # 返回第一个的标签 # 进行测试 dataSet, labels = readImage('./digit') dataSet2, labels2 = readImage('./digit2') n = 0 for i in range(len(dataSet2)): predict = classfiy(dataSet2[i], dataSet, labels, 10) print(predict + ' ' + labels2[i]) if (predict == labels2[i]): n = n + 1 # 查看准确率 print(n / len(dataSet2))

运行结果#


发现准确率只有0.62

总结#

  • 准确率如此低,可能是数据不足,也可能对图像处理不好。在二值化时,效果其实并不完美。也可能需要对图像进行一些裁剪。在二值化时,本程序也只适合一些浅色底子的数字图片
  • 采用不同的k,预测的效果也是不同,也需要找到一个最佳的k

其它#

  • 在处理数据时,通常用到的归一化
Copy
def toNormal(dataSet): # 归一化 min = dataSet.min(0) max = dataSet.max(0) # 公式normal=(x-min)/(max-min) normalArray = (dataSet - tile(min, (dataSet.shape[0], 1))) / tile(max - min, (dataSet.shape[0], 1)) return normalArray
Copy
def toClear(imgArray): rows, cols = imgArray.shape for y in range(1, cols - 1): for x in range(1, rows - 1): count = 0 if imgArray[x, y - 1] == 255: # 上 count = count + 1 if imgArray[x, y + 1] == 255: # 下 count = count + 1 if imgArray[x - 1, y] == 255: # 左 count = count + 1 if imgArray[x + 1, y] == 255: # 右 count = count + 1 if imgArray[x - 1, y - 1] == 255: # 左上 count = count + 1 if imgArray[x - 1, y + 1] == 255: # 左下 count = count + 1 if imgArray[x + 1, y - 1] == 255: # 右上 count = count + 1 if imgArray[x + 1, y + 1] == 255: # 右下 count = count + 1 if count > 4: imgArray[x, y] = 255 return imgArray
posted @   启林O_o  阅读(266)  评论(0编辑  收藏  举报
编辑推荐:
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 一个奇形怪状的面试题:Bean中的CHM要不要加volatile?
· [.NET]调用本地 Deepseek 模型
· 一个费力不讨好的项目,让我损失了近一半的绩效!
阅读排行:
· PowerShell开发游戏 · 打蜜蜂
· 在鹅厂做java开发是什么体验
· 百万级群聊的设计实践
· WPF到Web的无缝过渡:英雄联盟客户端的OpenSilver迁移实战
· 永远不要相信用户的输入:从 SQL 注入攻防看输入验证的重要性
点击右上角即可分享
微信分享提示
CONTENTS