Python简单实现KNN算法

__author__ = '糖衣豆豆'
from numpy import  *
from os import listdir
import operator
#从列方向扩展
#tile(a,(size,1))
#实现KNN算法,需要指定k,需要测试数据集,需要训练数据集,类别名(标签),
def knn(k,testdata,traindata,labels):
    #通过shape获得行数
    traindatasize=traindata.shape[0]
    #扩展testdata的维数,tile函数可以扩展testdata和traindata相同的行数,然后和traindata的向量相减计算测试机和训练集的差值
    dif=tile(testdata,(traindatasize,1))-traindata
    #计算差值的平方
    sqdif=dif**2
    #计算平方和,每一行的各列求和,axis=1每一行的各列求和
    sumsqdif=sqdif.sum(axis=1)
    #开方
    distance=sumsqdif**0.5
    #排序
    sortdistance=distance.argsort()
    #空字典
    count={}
    #选择距离最短的k
    for i in range(0,k):
        #获取类别,下标决定属于哪一类
        vote=labels[sortdistance[i]]
        #整理为一定格式,得到类别vote,每出现一次统计一次
        count[vote]=count.get(vote,0)+1
    #取出最多的类别,reverse=True表示降序
    sortcount=sorted(count.items(),key=operator.itemgetter(1),reverse=True)
    return sortcount[0][0]

#图片处理
#先将图片转为固定宽高,比如32*32,然后再转为文本
'''
from PIL import Image
im=Image.open("~/Downloads/123.png")
fh=open("~/Downloads/123_txt","a")
width=im.size[0]
height=im.size[1]
#k=im.getpixel((1,9))
#print(k)
for i in range(0,width):
    for j in range(0,height):
        cl=im.getpixel((i,j))
        clall=cl[0]+cl[1]+cl[2]
        if(clall==0):
            #黑色
            fh.write("1")
        else:
            fh.write("0")
    fh.write("\n")
fh.close()
'''

#加载数据
#将数据转为数组
def datatoarray(fname):
    arr=[]
    fh=open(fname)
    #图片是32*32的横轴每次读取32
    for i in range(0,32):
        thisline=fh.readline()
        #读每一行
        for j in range(0,32):
            #读入到数组里
            arr.append(int(thisline[j]))
    return arr
arr1=datatoarray("~/coding/python/data/testandtraindata/testdata/0_74.txt")
#print(arr1)
#建立一个函数,取文件的前缀
def seplabel(fname):
    filestr=fname.split(".")[0]
    label=int(filestr.split("_")[0])
    return label
#建立训练数据
def traindata():
    #存储类别
    labels=[]
    #得到训练目录下所有的文件
    trainfile=listdir("~/coding/python/data/testandtraindata/traindata")
    #取当前文件有多少个
    num=len(trainfile)
    #生成一个多少行多少列的向量,行的长度应该是32*32=1024(列),每一行存储一个文件
    #用一个数组存储所有训练数据,行:文件总数,列:1024
    trainarr=zeros((num,1024))
    #第一层循环文件
    for i in range(0,num):
        thisfname=trainfile[i]
        #调用seplabel函数
        thislabel=seplabel(thisfname)
        #存到数组里
        labels.append(thislabel)
        #调用datatoarray函数,i,:处理重复读取
        trainarr[i,:]=datatoarray("~/coding/python/data/testandtraindata/traindata/"+thisfname)
    return trainarr,labels
#用测试数据条用KNN算法去测试,看是否能够准确识别
def datatest():
    trainarr,labels=traindata()
    testlist=listdir("~/coding/python/data/testandtraindata/testdata")
    tnum=len(testlist)
    for i in range(0,tnum):
        thistestfile=testlist[i]
        testarr=datatoarray("~/coding/python/data/testandtraindata/testdata/"+thistestfile)
        rknn=knn(3,testarr,trainarr,labels)
        print(rknn)
#a=datatest()
#print(a)
#抽某一个文件测试文件出来进行验证
trainarr,labels=traindata()
thistestfile="8_15.txt"
testarr=datatoarray("~/coding/python/data/testandtraindata/testdata/"+thistestfile)
rknn=knn(3,testarr,trainarr,labels)
print(rknn)

 

posted @ 2018-08-07 19:27  小武aj  阅读(523)  评论(0编辑  收藏  举报