KNN算法原理实现

import numpy as np
import matplotlib.pyplot as plt
import operator

# 给出训练数据以及对应的类别
def createDataSet():
    group = np.array([[1.0,2.0],[1.2,0.1],[0.1,1.4],[0.3,3.5],[1.1,1.0],[0.5,1.5]])
    labels = np.array(['A','A','B','B','A','B'])
    return group,labels


def kNN_classify(k,dis,X_train,x_train,Y_test):
    assert dis == 'E' or dis == 'M', 'dis must E or M,E代表欧式距离,M代表曼哈顿距离'
    num_test = Y_test.shape[0] # 测试样本的数量
    labellist = []
    # 欧氏距离
    if (dis == 'E'):
        for i in range(num_test):
            distances = np.sqrt(np.sum(((X_train - np.tile(Y_test[i],(X_train.shape[0], 1)))**2),axis=1))
            nearest_k = np.argsort(distances)  #距离由小到大进行排序,并返回index值
            topK = nearest_k[:k]  #选取前 k个距离
            classCount = {}
            for i in topK:
                #统计每个类别的个数
                classCount[x_train[i]] = classCount.get(x_train[i],0) + 1
                sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
                labellist.append(sortedClassCount[0][0])
        return np.array(labellist)
    
    # 曼哈顿距离
    elif (dis == 'M'):
        for i in range(num_test):
            distances = np.sum(abs(X_train - np.tile(Y_test[i],(X_train.shape[0], 1))), axis=1)
            nearest_k = np.argsort(distances)  #距离由小到大进行排序,并返回index值
            topK = nearest_k[:k]  #选取前 k个距离
            classCount = {}
            for i in topK:
                #统计每个类别的个数
                classCount[x_train[i]] = classCount.get(x_train[i],0) + 1
                sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
                labellist.append(sortedClassCount[0][0])
        return np.array(labellist)


if __name__ == '__main__':
    group, labels = createDataSet()
    y_test = np.array([[1.0,2.1], [0.2,1.5]])
    #对于类别为A的数据集我们使用红色*表示
    plt.scatter(group[labels=='A',0],group[labels=='A',1],color = 'r', marker='*')
    #对于类别为B的数据集我们使用绿色+表示
    plt.scatter(group[labels=='B',0],group[labels=='B',1],color = 'g', marker='+')
    plt.scatter(y_test[:,0],y_test[:,1],color = 'black', s =10)
    plt.show()

    y_test_pred = kNN_classify(1, 'M', group, labels, y_test)
    print(y_test_pred)

posted @   YI颗白菜  阅读(5)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· 阿里巴巴 QwQ-32B真的超越了 DeepSeek R-1吗?
· 如何调用 DeepSeek 的自然语言处理 API 接口并集成到在线客服系统
· 【译】Visual Studio 中新的强大生产力特性
· 2025年我用 Compose 写了一个 Todo App
点击右上角即可分享
微信分享提示