ML实战:手动实现LVQ算法
- 本次数据集使用sklearn的make_blobs函数随机生成
代码实现
LVQ类
import numpy as np
np.set_printoptions(suppress=True)
class LVQ:
def __init__(self,x,y,q,qtag):
'''
:param x: 数据集
:param y: 数据标签
:param q: 簇个数
:param qtag: 簇标签
:param p: 原型向量
'''
self.x=x
self.y=y
n=len(self.x[0])
self.p=np.zeros((q,n))
self.qtag=qtag
self.q=q
for i in range(q):
index=np.argwhere(y==qtag[i]).reshape(1,-1)[0]
self.p[i]=x[np.random.choice(index)]
def find_pi(self,j):
#对于样本x选择最近的pi
temp = self.x[j] - self.p
dis = np.linalg.norm(temp, axis=1, keepdims=True)
return np.argmin(dis,axis=0)[0]
def single_iter(self,alpha):
#单次迭代函数
j=np.random.randint(0,len(self.x))
i=self.find_pi(j)
if self.qtag[i]==self.y[j]:
self.p[i]=self.p[i]+alpha*(self.x[j]-self.p[i])
else:
self.p[i]=self.p[i]-alpha*(self.x[j]-self.p[i])
def fit(self,alpha=0.3,iter_count=500):
#参数拟合
for i in range(iter_count):
self.single_iter(alpha)
def predict(self):
#输出原型向量
return self.p
主函数
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import numpy as np
from LVQ_class import LVQ
import sys
from Kmeas_class import Kmeans
np.set_printoptions(suppress=True)
#数据集生成
x, y = make_blobs(n_samples=500, n_features=2, centers=4, random_state=np.random.randint(0,30))
color=['yellow','gray','blue','red']
#真实数据集可视化
plt.figure(1)
for i in range(4):
plt.scatter(x[y==i, 0], x[y==i,1],marker='o',s=8,c=color[i])
plt.title('Real Data')
plt.savefig('E:\python\ml\ml by myself\LVQ\LVQ_real_myslef.png')
#使用kmeans算法,输出预测簇
kmeans=Kmeans(x,4)
y_kmeans_predict=kmeans.fit()
#kmeans预测结果可视化
for i in range(4):
plt.scatter(x[y_kmeans_predict == i, 0], x[y_kmeans_predict == i, 1], marker='o', s=8, c=color[i])
plt.title('Keams Predict Result')
plt.savefig('E:\python\ml\ml by myself\LVQ\LVQ_kmeans_predict.png')
#LVQ模型建立与训练
lvq=LVQ(x,y,4,[0,1,2,3])
lvq.fit(iter_count=4000)
y_lvq_predict=lvq.predict()
#LVQ算法结果可视化
plt.figure(1)
for i in range(4):
plt.scatter(x[y==i, 0], x[y==i,1],marker='o',s=8,c=color[i])
for p in y_lvq_predict:
plt.scatter(p[0], p[1],marker='o',s=50,c='green')
plt.title('LVQ Predict Result')
plt.savefig('E:\python\ml\ml by myself\LVQ\LVQ_predict_myslef.png')
sys.exit(0)
结果
真实数据集
![]()
Kmeans预测结果
![]()
LVQ预测结果(绿色的为原型向量)
![]()