【Python学习】基于 KNN 模型的葡萄酒种类预测附代码
算法原理
K最近邻(KNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。
给定测试样本,基于某种距离度量找出训练集中与其最靠近的K个训练样本,然后基于这 K个"邻居"的信息来进行预测。
KNN 算法的核心思想是如果一个样本在特征空间中的 K 个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。
具体步骤
给定训练样本集和一组类属性 ,对样本进行分类,KNN 算法的基本步骤为:
(1)先求出 t 与 S 中所有训练样本的距离 ,并对所有求出的值递增排序;
(2)选取与待测样本距离最小的 K 个样本,组成集合 N;
(3)统计 N 中 K 个样本所属类别现的频率;
(4)频率最高的类别作为待测样本的类别。
举例说明
如果没看懂上面在说什么,没关系,举个例子这样能更好的理解一下,这里采用的是欧氏距离
假设 测试集有2和11,其中K=4;对测试集分别计算其与训练集的欧氏距离,得到如下结果
对测试集与训练集的距离进行排序,选取距离最小的前四(k=4)个
在测试集2,距离最小的四个(k=4)训练集数据中,属于类别1的有3个,属于类别2的有1个,所以测试集2的类别为1
在测试集11,距离最小的四个(k=4)训练集数据中,属于类别1的有1个,属于类别2的有3个,所以测试集11的类别为2
相信通过这个例子,大家能够很好的理解KNN算法的具体步骤了。
好了,话不多说,下面还是上代码。
KNN算法代码
"""
@Date 2022.5.14
@Author Harper
"""
import operator
import random
import numpy as np
import xlrd
from matplotlib import pyplot as plt
# 葡萄酒数据导入
file_location = r'E:\3.Python\4.实验代码\Case_2\Wine.xls'
file = xlrd.open_workbook(file_location) # excel中全部数据
sheet1 = file.sheet_by_index(0) # sheet1数据
data_wine = np.mat(
[[sheet1.cell_value(r, c) for c in range(1, sheet1.ncols)] for r in range(sheet1.nrows)]) # 样本特征数据,数组
data_labels = np.mat([sheet1.cell_value(r, 0) for r in range(sheet1.nrows)])
# 归一化数据
def MaxMinNorm(array):
max_cols = array.max(axis=0)
min_cols = array.min(axis=0)
data_rows = 178
data_cols = 13
t = np.zeros([data_rows, data_cols])
# 归一化公式num = (num0-min)/(max-min)
for c in range(data_cols):
for r in range(data_rows):
t[r, c] = (array[r, c] - min_cols[0, c]) / (max_cols[0, c] - min_cols[0, c])
return t
# 随机选取训练集 测试集索引 √
def GetRandomIndex():
"""
:return: tra 训练集索引
te 测试集索引
"""
def randomIndex(a, b, x):
# 索引范围为[a,b),随机选x个不重复
index = random.sample(range(a, b), x)
return index
tra = []
tra.extend(np.array(randomIndex(0, 59, 53)))
tra.extend(np.array(randomIndex(59, 130, 64)))
tra.extend(np.array(randomIndex(130, 178, 43)))
te = np.delete(np.arange(0, 178), tra)
# print(te)
return tra, te
# 高维距离计算
def Distance(veca, vecb, length):
"""
:param veca: 向量a
:param vecb: 向量b
:param length: 维度
:return: 欧氏距离
"""
dis = 0
for x in range(length):
dis += pow((veca[x] - vecb[x]), 2)
return np.sqrt(dis)
# print('-' * 30 + 'Distance_test' + '-' * 30)
# print(Distance(data[0], data[1], 13))
def GetNeighbors(dataSet, train_index, test_index):
"""
:param dataSet: 数据集
:param train_index: 训练集索引
:param test_index: 测试集索引(单个数)
:return: 距离
"""
distance = []
for x in range(len(train_index)):
dis = Distance(dataSet[test_index], dataSet[train_index[x]], 13)
distance.append((train_index[x], dis)) # 添加对应索引 距离
# operator.itemgetter(1) 获取对象的第1个域的值
# key = operator.itemgetter(1) 就是以距离进行排序
distance.sort(key=operator.itemgetter(1))
print(distance)
return distance
def GetClassify(dis, k):
neighbors = []
for x in range(k):
print(dis[x][0])
if dis[x][0] < 59: # 1类
neighbors.append(1)
elif dis[x][0] < 130: # 2类
neighbors.append(2)
elif dis[x][0] < 178:
neighbors.append(3)
print(neighbors)
# 选众数
classify = np.argmax(np.bincount(neighbors))
return classify
# print('=' * 60)
# print(test[0])
# print('=' * 60)
# cla = GetClassify(data, train, test[0], 34)
# print(cla)
# Get_class(neighbor)
def GetAccuracy(real, predict):
yes = 0
for i in range(len(real)):
if real[i] == predict[i]:
yes += 1
return (yes / float(len(real))) * 100.0
def draw(dataSet, test_index, Predict_class, accuracy, k):
pre1_x, pre2_x, pre3_x, pre1_y, pre2_y, pre3_y = list(), list(), list(), list(), list(), list()
real1_x, real1_y, real2_x, real2_y, real3_x, real3_y = list(), list(), list(), list(), list(), list()
for i in range(len(test_index)):
if i < 6:
real1_x.append(dataSet[test_index[i], 0])
real1_y.append(dataSet[test_index[i], 1])
elif 6 <= i < 13:
real2_x.append(dataSet[test_index[i], 0])
# print(f"2_y:{dataSet[test_index[i], 1]}")
real2_y.append(dataSet[test_index[i], 1])
elif 13 <= i:
real3_x.append(dataSet[test_index[i], 0])
# print(f"3_y:{dataSet[test_index[i], 1]}")
real3_y.append(dataSet[test_index[i], 1])
if Predict_class[i] == 1:
pre1_x.append(dataSet[test_index[i], 0])
pre1_y.append(dataSet[test_index[i], 1])
elif Predict_class[i] == 2:
pre2_x.append(dataSet[test_index[i], 0])
pre2_y.append(dataSet[test_index[i], 1])
elif Predict_class[i] == 3:
pre3_x.append(dataSet[test_index[i], 0])
pre3_y.append(dataSet[test_index[i], 1])
# # test 7 28 31 34 38 39
# # [14.06, 13.87, 13.58, 13.51, 13.07, 14.22]
print('=' * 30 + '1' + '=' * 28)
print(real1_y)
print(real1_x)
print('=' * 30 + '2' + '=' * 28)
print(real2_y)
print(real2_x)
print('=' * 30 + '3' + '=' * 28)
print(real3_y)
print(real3_x)
plt.figure(1) # 样本
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.title(f'样本分类')
plt.scatter(real1_x[:], real1_y[:], color='red', marker='+', label='1')
plt.scatter(real2_x[:], real2_y[:], color='blue', marker='+', label='2')
plt.scatter(real3_x[:], real3_y[:], color='green', marker='+', label='3')
plt.legend(loc=2) # 左上角
plt.figure(2) # 样本
plt.title(f"KNN分类\n"
f"准确率:{format(accuracy,'.2f')}%,K={k}")
plt.scatter(pre1_x[:], pre1_y[:], color='red', marker='+', label='1')
plt.scatter(pre2_x[:], pre2_y[:], color='blue', marker='+', label='2')
plt.scatter(pre3_x[:], pre3_y[:], color='green', marker='+', label='3')
plt.legend(loc=2)
plt.show()
def main():
# 归一化数据
predict_class = []
real_class = []
K = 6
data = MaxMinNorm(data_wine)
train, test = GetRandomIndex()
print(test)
for i in test:
distance = GetNeighbors(data, train, i)
predict_class.append(GetClassify(distance, K))
if i < 59:
real_class.append(1)
elif 59 <= i < 130:
real_class.append(2)
elif 130 <= i < 178:
real_class.append(3)
print('=' * 30 + 'real_class' + '=' * 30)
print(real_class)
print('=' * 30 + 'predict_class' + '=' * 28)
print(predict_class)
acc = GetAccuracy(real_class, predict_class)
print(f"Accuracy:{acc}%")
draw(data_wine, test, predict_class, acc, K)
main()
# data = MaxMinNorm(data_wine)
# train, test = GetRandomIndex()
# print(test)
# print(data[test[:6],0])
# draw(data,test,1)
结果
真实分类
KNN分类
以上就是全部内容,希望对你有帮助。
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· CSnakes vs Python.NET:高效嵌入与灵活互通的跨语言方案对比
· DeepSeek “源神”启动!「GitHub 热点速览」
· 我与微信审核的“相爱相杀”看个人小程序副业
· Plotly.NET 一个为 .NET 打造的强大开源交互式图表库
· 上周热点回顾(2.17-2.23)