高斯判别分析GDA推导与代码实现

高斯判别分析GDA推导与代码实现

生成学习

处理分类问题,我们可以使用逻辑回归、Softmax。这两种方法都属于“判别学习”,也就是给定 (x(i),y(i)),我们学习 P(y|x),并对于给定的 x,计算 argmaxy{P(y|x)}

GDA属于另一种方法——生成学习。在判别学习中,我们并不关注 x 本身的分布,而在生成学习中,我们基于一些事实假设 x 的分布,例如在GDA中,我们假设对于相同的 yx 符合高斯分布。然后基于假设,学习 P(x|y) 的参数(高斯分布则学习均值 μ 和协方差 Σ),以及 P(y) 的参数(例如两点分布的 ϕ),从而得到联合概率 P(x,y)。在预测时,我们仍然是找到最大的 P(y|x),但是是用贝叶斯公式计算:

P(y|x)=P(x,y)P(x)

高斯判别分析

作出的假设

首先假设 Y 服从两点分布 Bernolli(ϕ),于是可以写作 P(y)=ϕy(1ϕ)1y

(尽管可以直接用协方差定义多元变量的高斯分布,但是这里采用另一种方法,在特殊的情况下得到等式而不需要完全理解协方差矩阵。)

然后假设给定 y 后,x每一个分量 xi 都服从高斯分布且相互独立。不妨设 y=0,即 xi|y=0N(μi,σi)

P(xi|y=0)=12πσiexp((xiμi)22σi2)

xi 相互独立,可以得到:

P(x|y=0)=i=1n12πσiexp((xiμi)22σi2)=1(2π)n2σ1σnexp(12i=1n(xiμi)2σi2)

我们定义矩阵 Σ=(σ12σn2),可以将上式化简:

P(x|y=0)=1(2π)n2|Σ|12exp(12(xμ)Σ1(xμ)T)

同样,我们也假设 x|y=1 符合高斯分布 N(μ,Σ),需要注意的是,两个分布都采用了同一个 Σ(存疑,不知道目的)。

但是实际上GDA并没有假设 Σ

最大似然估计

在判别学习中,我们以 P(y(i)xi) 为似然函数,而在GDA这一类生成学习中,我们以联合概率 P(x(i),y(i)) 为似然函数。也即最大化如下对数似然函数:

L(ϕ,μ,μ,Σ)=lni=1mP(x(i),y(i);ϕ,μ,μ,Σ)=lni=1mP(x(i)y(i);μ,μ,Σ)P(y(i);ϕ)=i=1mlnP(x(i)y(i);μ,μ,Σ)+i=1mlnP(y(i);ϕ)

不同于逻辑回归,我们可以直接用导数为 0 求解参数。

计算 ϕ

Lϕ=ϕi=1mlnP(y(i);ϕ)=ϕi=1my(i)lnϕ+(1y(i))ln(1ϕ)=i=1my(i)ϕ1y(i)1ϕ

令上式为 0,得 ϕ=y(i)m

计算 μ,μ

一些无关紧要的常数用 C 代替,比如密度函数中的系数。

μL=μi=1m(1y(i))lnP(x(i)y=0;μ,Σ)+y(i)lnP(x(i)y=1;μ,Σ)

其中

μjlnP(x(i)y=0;μ,Σ)=μj(C12ln|Σ|12(x(i)μ)Σ1(x(i)μ)T)=xj(i)μj

μlnP(x(i)y=0;μ,Σ)=x(i)μ,进而

μL=i=1m(1y(i))(x(i)μ)

令上式为 0,可得 μ=(1y(i))x(i)1y(i),同理可得 μ=y(i)x(i)y(i)

计算 Σ

下面用到了两个 算符的性质:

  • Σ|Σ|=|Σ|Σ1
  • ΣΣ1=Σ2

ΣL=i=1m12|Σ|Σ|Σ|+12Σ((x(i)μ)Σ1(x(i)μ)T)=i=1m12Σ1+12(x(i)μ)T(x(i)μ)Σ2

同样地令上式为 0,计算得

Σ=1mi=1m(x(i)μ)T(x(i)μ)

注意到,按照我们的假设,Σ 应该是一个对角矩阵(如果了解协方差矩阵的话,由 xi 相互独立,可以推出 Σ 应该是对角矩阵,对角元就是每个变量的方差),但是这里非常显然不总是对角矩阵。

(以下仅是我个人的猜测)在GDA的实现中,我们并没有关注 xi 相互独立的性质,而是直接学习了一个普遍的协方差矩阵。实际上这种定义的限制更低,更符合现实情况(现实中的变量之间存在联系比较普遍)。

代码实现

直接计算这几个参数的代码不需要解释:

Copym, n = xs.shape
# Calculate mu
self.mu = np.zeros((2, n))
classes_size = np.zeros(2)
for i in range(m):
    self.mu[ys[i]] += xs[i]
    classes_size[ys[i]] += 1
self.mu /= np.transpose([classes_size])
# Calculate Sigma
self.sigma = np.zeros((n, n))
for i in range(m):
    temp_array = xs[i] - self.mu[ys[i]]
    self.sigma += np.dot(temp_array.reshape(n, 1), temp_array.reshape(1, n))
self.sigma /= m
# Calculate phi
self.phi = np.sum(ys) / m

最后我们发现计算概率需要计算 Σ1,然而我们算出来的可能是一个奇异矩阵。这时候可以对 Σ 进行微扰——将对角元加上一个较小的偏差。

最后我们对 y=0,1 分别计算出 P(x|y)

P(x|y;μ,Σ)=1(2π)n2|Σ|12exp((xμ)Σ1(xμ)T2)

但是实际上比较 P(x|y=0),P(x|y=1) 不需要真正计算出概率,而是比较两者不同的地方,也就是后面的 exp

(奇异矩阵的处理参考 https://blog.csdn.net/qq_30091945/article/details/81508055 ,对其中的 Gaussian 函数有修改)

最后是实现的一个类:

Copyclass GaussianDiscriminantAnalysis:
    def __init__(self):
        self.mu = None
        self.phi = None
        self.sigma = None

    def fit(self, xs, ys, **others):
        m, n = xs.shape
        # Calculate mu
        self.mu = np.zeros((2, n))
        classes_size = np.zeros(2)
        for i in range(m):
            self.mu[ys[i]] += xs[i]
            classes_size[ys[i]] += 1
        self.mu /= np.transpose([classes_size])
        # Calculate Sigma
        self.sigma = np.zeros((n, n))
        for i in range(m):
            temp_array = xs[i] - self.mu[ys[i]]
            self.sigma += np.dot(temp_array.reshape(n, 1), temp_array.reshape(1, n))
        self.sigma /= m
        # Calculate phi
        self.phi = np.sum(ys) / m

    def evaluate(self, x, mean, cov):
        dim = np.shape(cov)[0]
        # cov的行列式为零时的措施
        cov_inv = np.linalg.inv(cov + np.eye(dim) * 0.001)
        xdiff = (x - mean).reshape((1, dim))
        # 概率密度
        prob = np.exp(-0.5 * xdiff.dot(cov_inv).dot(xdiff.T))[0][0]
        return prob

    def predict(self, xs):
        predict = []
        for x in xs:
            evaluate_0 = self.evaluate(x, self.mu[0], self.sigma) * (1 - self.phi)
            evaluate_1 = self.evaluate(x, self.mu[1], self.sigma) * self.phi
            if evaluate_0 > evaluate_1:
                predict.append(0)
            else:
                predict.append(1)
        return predict
posted @   Lucky_Glass  阅读(170)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· AI 智能体引爆开源社区「GitHub 热点速览」
· 从HTTP原因短语缺失研究HTTP/2和HTTP/3的设计差异
· 三行代码完成国际化适配,妙~啊~
历史上的今天:
2021-02-27 「SOL」谢特(LOJ)
TOP BOTTOM
点击右上角即可分享
微信分享提示