编程实现线性判别分析,并给出西瓜数据集3.0α上的运行结果
1.题目理解
将西瓜数据集的样例投影到一条直线上,使得好瓜、坏瓜各自的投影点尽可能接近,好瓜与坏瓜之间的投影点尽可能远离。
2.算法原理
3.算法设计
① 根据LDA原理求解得到w,结合数据集得到LDA直线;
② 将每个样本映射到LDA直线上,观察分析结果。
4.关键代码
1 # 加载数据集 2 dataset = np.loadtxt('C:/Users/86185/PycharmProjects/ML1/watermelon_3a.csv', delimiter=",") 3 4 # 分离属性值和标签 5 X = dataset[:,1:3] 6 y = dataset[:,3] 7 u = [] 8 for i in range(2): 9 u.append(np.mean(X[y==i],axis=0)) 10 11 m,n = np.shape(X) 12 Sw = np.zeros((n,n)) 13 for i in range(m): 14 x_temp = X[i].reshape(n, 1) # 行向量变为列向量 15 if y[i]==0: u_temp = u[0].reshape(n, 1) 16 if y[i]==1: u_temp = u[1].reshape(n, 1) 17 Sw +=np.dot(x_temp-u_temp, (x_temp-u_temp).T) 18 19 Sw = np.mat(Sw) 20 # print(Sw) 21 Sw_inv = np.linalg.inv (Sw) 22 # print(Sw_inv) 23 w = np.dot(Sw_inv, (u[0]-u[1]).reshape(n,1)) 24 print(w)
先根据公式求得w
1 def GetPoint(point0, w): 2 k0 = w[1, 0]/w[0, 0] 3 k1 = w[0, 0]/w[1, 0] 4 x0 = point0[0] 5 y0 = point0[1] 6 x1 = (k0 * x0 - y0) / (k0 + k1) 7 y1 = -k1 * x1 8 return x0, x1, y0, y1 9 10 f1 = plt.figure('first') 11 plt.xlim( -0.2, 1 ) # 设定坐标轴的范围 12 plt.ylim( -0.2, 0.6 ) 13 14 15 x = np.arange(-1, 3) 16 yy = -(w[0,0]/w[1,0])*x
做LDA直线yy;GetPoint()函数用来计算点到直线yy的投影
5.结果展示
根据运行结果显示,没有很明确的将好瓜与坏瓜区分开来,好瓜与坏瓜的投影点不够远离,坏瓜与坏瓜之间的投影点不够聚集。
· 阿里巴巴 QwQ-32B真的超越了 DeepSeek R-1吗?
· 10年+ .NET Coder 心语 ── 封装的思维:从隐藏、稳定开始理解其本质意义
· 【设计模式】告别冗长if-else语句:使用策略模式优化代码结构
· 字符编码:从基础到乱码解决
· 提示词工程——AI应用必不可少的技术