线性判别分析(fisher)

线性判别分析

线性判别分析中有降维,把数据都投影到同一条线上,然后在直线上取一个阈值,将直线分成两条射线,每一条代表一个分类。会损失一些数据信息,但如果这些信息是一些干扰信息,丢失也未尝不是好事。

线性判别分析之后的结果是一个向量,其他的不行吗?

主要指导思想(目标):类内小,类间大。

公式推导

我们得到的是向量,为了方便计算损失,不妨设||ww||=1,每一个数据XXi看作一个向量。那么XXiww是每个数据在ww方向上的投影。与ww的其中一个平面是划分平面。

两个不同类别分别命名为C1C2,用μμ,μμC1, μμC2分别代表全部数据,C1数据,C2数据的均值,用ΣΣ,ΣΣC1, ΣΣC2分别代表全部数据,C1数据,C2数据的协方差矩阵。
μ~σ~2表示投影的均值和方差。

μμ=1N1NXXiμμC1=1NC11NC1XXC1iμμC2=1NC21NC2XXC1i ΣΣ=1N1N(XXiμμ)(XXiμμ)TΣΣC1=1NC11NC1(XXC1iμμC1)(XXC1iμμC1)TΣΣC2=1NC21NC2(XXC2iμμC2)(XXC2iμμC2)T

μ~=1N1NXXiθθμ~C1=1NC11NC1XXC1iθθμ~C2=1NC21NC2XXC2iθθσ~2=1N1N(XXiθθμ~)2σ~C12=1NC11NC1(XXC1iθθμ~C1)2σ~C22=1NC21NC2(XXC2iθθμ~C2)2

类间:(μ~C1μ~C2)2

类内:σ~C12+σ~C22

目标函数:J(θθ)=(μ~C1μ~C2)2σ~C12+σ~C22

J(θθ)=(μ~C1μ~C2)2σ~C12+σ~C22=(μ~C1μ~C2)2=(1NC11NC1XXC1iθθ1NC21NC2XXC2iθθ)2=((μμC1μμC2)θθ)2=θθT(μμC1μμC2)T(μμC1μμC2)θθσ~C12=1NC11NC1(XXC1iθθμ~C1)2=1NC11NC1(XXC1iθθ1NC11NC1XXC1iθθ)2=1NC11NC1((XXC1i1NC11NC1XXC1i)θθ)2=1NC11NC1((XXC1iμμC1)θθ)2=1NC11NC1θθT(XXC1iμμC1)T(XXC1iμμC1)θθ=θθT(1NC11NC1(XXC1iμμC1)T(XXC1iμμC1))θθ=θθTΣΣC1θθσ~C22=θθTΣΣC2θθ=θθTΣΣC1θθ+θθTΣΣC2θθ=θθT(ΣΣC1+ΣΣC2)θθ

J(θθ)=θθT(μμC1μμC2)T(μμC1μμC2)θθθθT(ΣΣC1+ΣΣC2)θθ

Sb=(μμC1μμC2)T(μμC1μμC2),Sw=ΣΣC1+ΣΣC2

Sb就是类内方差

Sw就是类间方差

此时J(θθ)=θθTSbθθθθTSwθθ

求导

J(θθ)θθ=θθTSSbθθθθTSSwθθθθ=(θθTSSbθθ(θθTSSwθθ)1)θθ=(θθTSSbθθ)θθ(θθTSSwθθ)1+θθTSSbθθ((θθTSSwθθ)1)θθ=2θθTSSb(θθTSSwθθ)1+θθTSSbθθ(1(θθTSSwθθ)2)(2θθTSSw)

令导数等于零

00=2SSbθθ(θθTSSwθθ)1+θθTSSbθθ(1(θθTSSwθθ)2)(2SSwθθ)2SSbθθ(θθTSSwθθ)1=θθTSSbθθ(1(θθTSSwθθ)2)(2SSwθθ)SSbθθ(θθTSSwθθ)=(θθTSSbθθ)SSwθθ(θθTSSbθθ)SSwθθ=SSbθθ(θθTSSwθθ)θθ=SSw1θθTSSwθθθθTSSbθθSSbθθθθ=SSw1θθTSSwθθθθTSSbθθ(μμC1μμC2)T(μμC1μμC2)θθ


θθTSSwθθθθTSSbθθ,(μμC1μμC2)θθ是一个数,不影响θθ的方向


θθSSw1(μμC1μμC2)T

ifSSwII

θθ(μμC1μμC2)T

求任意一个点的投影

projθθ(x)=xTθθ

求阈值

threshold=NC1μ~C1+NC2μ~C1NC1+NC2=NC1μμC1θθ+NC2μμC1θθNC1+NC2

依赖

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

人工数据集

n = 100
X = np.random.multivariate_normal((1, 1), [[0.64, 0], [0, 0.64]], size = int(n/2))
X = np.insert(X, 50, np.random.multivariate_normal((3, 3), [[0.64, 0], [0,0.64]], size = int(n/2)),0)
#X = np.insert(X, 0, 1, 1)
m = X.shape[1]
y = np.array([1]*50+[-1]*50).reshape(-1,1)
plt.scatter(X[:50, -2], X[:50, -1])
plt.scatter(X[50:, -2], X[50:, -1], c = "#ff4400")
<matplotlib.collections.PathCollection at 0x7f2b50e680d0>

image

X1 = X[(y==1).reshape(-1)]
X0 = X[(y==-1).reshape(-1)]
n1 = np.array([[X1.shape[0]]])
n0 = np.array([[X0.shape[0]]])
mu1 = X1.mean(axis = 0).reshape(-1,1)
mu0 = X0.mean(axis = 0).reshape(-1,1)
Sigma1 = np.cov(X1.T)
Sigma0 = np.cov(X0.T)
theta = (Sigma1 + Sigma0) @ (mu1 - mu0)
threshold = (n1*mu1 + n0*mu0).T@theta/(n1 + n0)
def getForecast(x):
return x.T @ theta
threshold
array([[-10.45793931]])

预测

print(f'{ 1 if getForecast(np.array([[1],[1]])) > threshold else 0}')
1

分界展示

plt.scatter(X[:50, -2], X[:50, -1])
plt.scatter(X[50:, -2], X[50:, -1], c = "#ff4400")
for i in np.arange(-1,5,0.02):
for j in np.arange(-1,5,0.02):
if abs(getForecast(np.array([[i],[j]])) - threshold) <0.01:
plt.scatter(i,j,c="#000000")

image

posted @   孑然520  阅读(371)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· AI 智能体引爆开源社区「GitHub 热点速览」
· C#/.NET/.NET Core技术前沿周刊 | 第 29 期(2025年3.1-3.9)
· 从HTTP原因短语缺失研究HTTP/2和HTTP/3的设计差异
点击右上角即可分享
微信分享提示