kNN(k近邻)算法代码实现
目标:预测未知数据(或测试数据)X的分类y
批量kNN算法
1.输入一个待预测的X(一维或多维)给训练数据集,计算出训练集X_train中的每一个样本与其的距离
2.找到前k个距离该数据最近的样本-->所属的分类y_train
3.将前k近的样本进行统计,哪个分类多,则我们将x分类为哪个分类
# 准备阶段: import numpy as np # import matplotlib.pyplot as plt raw_data_X = [[3.393533211, 2.331273381], [3.110073483, 1.781539638], [1.343808831, 3.368360954], [3.582294042, 4.679179110], [2.280362439, 2.866990263], [7.423436942, 4.696522875], [5.745051997, 3.533989803], [9.172168622, 2.511101045], [7.792783481, 3.424088941], [7.939820817, 0.791637231] ] raw_data_y = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] X_train = np.array(raw_data_X) y_train = np.array(raw_data_y) x = np.array([8.093607318, 3.365731514])
核心代码:
目标:预测未知数据(或测试数据)X的分类y 批量kNN算法 1.输入一个待预测的X(一维或多维)给训练数据集,计算出训练集X_train中的每一个样本与其的距离 2.找到前k个距离该数据最近的样本-->所属的分类y_train 3.将前k近的样本进行统计,哪个分类多,则我们将x分类为哪个分类 from math import sqrt from collections import Counter # 已知X_train,y_train # 预测x的分类 def predict(x, k=5): # 计算训练集每个样本与x的距离 distances = [sqrt(np.sum((x-x_train)**2)) for x_train in X_train] # 这里用了numpy的fancy方法,np.sum((x-x_train)**2) # 获得距离对应的索引,可以通过这些索引找到其所属分类y_train nearest = np.argsort(distances) # 得到前k近的分类y topK_y = [y_train[neighbor] for neighbor in nearest[:k]] # 投票的方式,得到一个字典,key是分类,value数个数 votes = Counter(topK_y) # 取出得票第一名的分类 return votes.most_common(1)[0][0] # 得到y_predict predict(x, k=6)
面向对象的方式,模仿sklearn中的方法实现kNN算法:
import numpy as np from math import sqrt from collections import Counter class kNN_classify: def __init__(self, n_neighbor=5): self.k = n_neighbor self._X_train = None self._y_train = None def fit(self, X_train, y_train): self._X_train = X_train self._y_train = y_train return self def predict(self, X): '''接收多维数据,返回y_predict也是多维的''' y_predict = [self._predict(x) for x in X] # return y_predict return np.array(y_predict) # 返回array的格式 def _predict(self, x): '''接收一个待预测的x,返回y_predict''' distances = [sqrt(np.sum((x-x_train)**2)) for x_train in self._X_train] nearest = np.argsort(distances) topK_y = [self._y_train[neighbor] for neighbor in nearest[:self.k]] votes = Counter(topK_y) return votes.most_common(1)[0][0] def __repr__(self): return 'kNN_clf(k=%d)' % self.k
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· winform 绘制太阳,地球,月球 运作规律
· AI与.NET技术实操系列(五):向量存储与相似性搜索在 .NET 中的实现
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)