import numpy as np
import operator
import os

#KNN算法
def knn(k,testdata,traindata,labels):#(k,测试样本,训练集,分类)
    traindatasize=traindata.shape[0]#行数
    #测试样本和训练集样本数可能不一样,因此需要将测试集样本数扩展成和训练集一样多
    #从行方向扩展 tile(a,(size,1))
    dif=np.tile(testdata,(traindatasize,1))-traindata
    #计算距离
    sqdif=dif**2
    sumsqdif=sqdif.sum(axis=1)
    distance=sumsqdif**0.5

    sortdistance=distance.argsort()#从小到大排列,结果返回元素位置
    count={}
    for i in range(k):
        vote=labels[sortdistance[i]]
        #统计每一类列样本的数量
        count[vote]=count.get(vote,0)+1
    sortcount=sorted(count.items(),key=operator.itemgetter(1),reverse=True)
    #取包含样本数量最多的那一类别
    return sortcount[0][0]


#加载数据,将文件转化为数组形式
def datatoarray(filename):
    arr=[]
    fh=open(filename)
    for i in range(32):
        thisline=fh.readline()
        for j in range(32):
            arr.append(int(thisline[j]))
    return arr

#获取文件的lable
def get_labels(filename):
    label=int(filename.split('_')[0])
    return label

#建立训练数据
def train_data():
    labels=[]
    trainlist=os.listdir('traindata/')
    num=len(trainlist)
    #长度1024(列),每一行存储一个文件
    #用一个数组存储所有训练数据,行:文件总数,列:1024
    trainarr=np.zeros((num,1024))
    for i in range(num):
        thisfile=trainlist[i]
        labels.append(get_labels(thisfile))
        trainarr[i,:]=datatoarray("traindata/"+thisfile)
    return trainarr,labels

#用测试数据调用KNN算法进行测试
def datatest():
    a=[]#准确结果
    b=[]#预测结果
    traindata,labels=train_data()
    testlist=os.listdir('testdata/')
    fh=open('result_knn.csv','a')
    for test in testlist:
        testfile='testdata/'+test
        testdata=datatoarray(testfile)
        result=knn(3,testdata,traindata,labels)
        #将预测结果存在文本中
        fh.write(test+'-----------'+str(result)+'\n')
        a.append(int(test.split('_')[0]))
        b.append(int(result))
    fh.close()
    return a,b

if __name__=='__main__':
    a,b=datatest()
    num=0
    for i in range(len(a)):
        if(a[i]==b[i]):
            num+=1
        else:
            print("预测失误:",a[i],"预测为",b[i])
    print("测试样本数为:",len(a))
    print("预测成功数为:",num)
    print("模型准确率为:",num/len(a))