EM(最大期望)算法推导、GMM的应用与代码实现
EM算法是一种迭代算法,用于含有隐变量的概率模型参数的极大似然估计。
1 使用EM算法的原因#
首先举李航老师《统计学习方法》中的例子来说明为什么要用EM算法估计含有隐变量的概率模型参数。
假设有三枚硬币,分别记作A, B, C。这些硬币正面出现的概率分别是π,p,q。进行如下掷硬币试验:先掷硬币A,根据其结果选出硬币B或C,正面选硬币B,反面边硬币C;然后掷选出的硬币,掷硬币的结果出现正面记作1,反面记作0;独立地重复n次试验,观测结果为{y1,y2,...,yn}。问三硬币出现正面的概率。
三硬币模型,也就是第二枚硬币为正面或反面的概率(y=1表示正面,y=0表示反面),或者说观测变量的概率,可以写作
P(y|π,p,q)=∑zP(y,z|π,p,q)=∑zP(y|z,π,p,q)P(z|π,p,q)=πpy(1−p)1−y+(1−π)qy(1−q)1−y
其中z表示硬币A的结果,也就是前面说的隐变量。为了求得参数π,p,q,我们通常会使用极大似然估计,即最大化似然函数
maxπ,p,qn∏i=1P(yi|π,p,q)=maxπ,p,qn∏i=1[πpyi(1−p)1−yi+(1−π)qyi(1−q)1−yi]=maxπ,p,qn∑i=1log[πpyi(1−p)1−yi+(1−π)qyi(1−q)1−yi]=maxπ,p,qL(π,p,q)
分别对π,p,q求偏导并等于0,求解方程组来估计这三个参数。但是,由于它是带有隐变量的,在计算最终的概率之前有一个分支选择的过程,导致这个log的内部是加和的形式,不但计算导数十分困难,待求解的方程组还不是线性方程组。当复杂度一高,解这种方程组几乎成为不可能的事。以下推导EM算法,它以迭代的方式来求解这些参数,它包含了一种“贪心”的思想。
2 算法导出与理解#
对于参数为θ且含有隐变量Z的概率模型,进行n次抽样。假设随机变量Y的观察值为Y={y1,y2,...,yn},隐变量Z的m个可能的取值为Z={z1,z2,...,zm}。
写出似然函数:
L(θ)=∑Y∈YlogP(Y|θ)=∑Y∈Ylog∑Z∈ZP(Y,Z|θ)
EM算法首先初始化参数θ=θ0,然后每一步迭代都会使似然函数增大,即L(θk+1)≥L(θk)。如何做到不断变大呢?考虑第k+1步迭代似然函数(这一步很重要!):
L(θ)=∑Y∈Ylog∑Z∈ZP(Y,Z|θ)=∑Y∈Ylog∑Z∈ZP(Z|Y,θk)P(Y,Z|θ)P(Z|Y,θk)
至于上式的第二个等式为什么取出P(Z|Y,θk)而不是别的,正向的原因我想不出来,马后炮原因在后面记录。
考虑其中的求和
∑Z∈ZP(Z|Y,θk)=1
且由于log函数是凹函数,因此由Jenson不等式得
L(θ)≥∑Y∈Y∑Z∈ZP(Z|Y,θk)logP(Y,Z|θ)P(Z|Y,θk)=B(θ,θk)
当θ=θk时,有
L(θk)≥B(θk,θk)=∑Y∈Y∑Z∈ZP(Z|Y,θk)logP(Y,Z|θk)P(Z|Y,θk)=∑Y∈Y∑Z∈ZP(Z|Y,θk)logP(Y|θk)=∑Y∈YlogP(Y|θk)=L(θk)
也就是在这时,(3)式取等,即L(θk)=B(θk,θk)。另取
θ∗=argmaxθB(θ,θk)
可得不等式
L(θ∗)≥B(θ∗,θk)≥B(θk,θk)=L(θk)
所以,我们只要优化(5)式,让θk+1=θ∗,即可保证每次迭代的非递减势头,有L(θk+1)≥L(θk)。而由于似然函数是概率乘积的对数,一定有L(θ)<0,所以迭代有上界并且会收敛。以下是《统计学习方法》中EM算法一次迭代的示意图:
进一步简化(5)式,去掉优化无关项:
θ∗=argmaxθB(θ,θk)=argmaxθ∑Y∈Y∑Z∈ZP(Z|Y,θk)logP(Y,Z|θ)P(Z|Y,θk)=argmaxθ∑Y∈Y∑Z∈ZP(Z|Y,θk)logP(Y,Z|θ)=argmaxθQ(θ,θk)
Q函数的对数内部没有像(1)式一样的和式,使用导数求极值的方程就与没有隐变量的方程类似了,容易求解。另外,Q函数还可以写成期望的形式(书上是不带Y的求和的,我觉得加上更严谨一些,也容易理解一些):
Q(θ,θk)=∑Y∈YEZ∈Z[logP(Y,Z|θ)|Y,θk]
综上,EM算法的流程为:
1. 设置θ0的初值。EM算法对初值是敏感的,不同初值迭代出来的结果可能不同。可以观察上面的示意图,如果θk在左边的峰值附近,EM最终就会迭代到左边的局部最优,无法发现右边更大的值。
2. 更新θk=argmaxθQ(θ,θk−1)。理解上来说,通常将这一步分为计算Q与极大化Q两步,即求期望E与求极大M,但在代码中并不会将它们分出来,因此这里浓缩为一步。另外,如果这个优化很难计算的话,因为有不等式的保证,可以直接取θk为某个ˆθ,只要有Q(ˆθ,θk−1)≥Q(θk−1,θk−1)即可。
3. 比较θk与θk−1的差异,比如求它们的差的二范数,若小于一定阈值就结束迭代,否则重复步骤2。
下面记录一下我对(1)式取出P(Z|Y,θk)而不取别的P的理解:
经过以上的推导,我认为这是为了给不等式取等创造条件。如果不能确定L(θk)与Q(θk,θk)能否取等,那么取Q的最大值Q(θ∗,θk)时,尽管有Q(θ∗,θk)≥Q(θk,θk),但并不能保证L(θ∗)≥L(θk),迭代的不减性质就就没了。
我这里暂且把它看做一种巧合,是研究EM算法的大佬,碰巧想用Jenson不等式来迭代而构造出来的一种做法。本人段位还太弱,无法正向理解其中的缘故,只能以这种方式来揣度大佬的思路了。知乎大佬发的EM算法九层理解(点击链接),我当前只能到第3层,有时间一定要拜读一下深度学习之父的著作。
3 高斯混合模型的应用#
3.1 迭代式推导#
假设高斯混合模型混合了m个高斯分布,参数为θ=(α1,θ1,α2,θ2,...,αm,θm),θi=(μi,σi)则整个概率密度为:
P(y|θ)=m∑i=1αiϕ(y|θi)=m∑i=1αi√2πσiexp(−(y−μi)22σ2i),wherem∑j=1αj=1
对混合分布抽样n次得到{y1,...,yn},则在第k+1次迭代,待优化式为:
maxθQ(θ,θk)=maxθ∑Y∈Y∑Z∈ZP(Z|Y,θk)logP(Y,Z|θ)=maxθ∑Y∈Y∑Z∈ZP(Z,Y|θk)P(Y|θk)logP(Y,Z|θ)=maxθn∑i=1m∑j=1αkjϕ(yi|θkj)m∑l=1αklϕ(yi|θkl)log[αjϕ(yi|θj)]=maxθn∑i=1m∑j=1αkjϕ(yi|θkj)m∑l=1αklϕ(yi|θkl)log[αj√2πσjexp(−(yi−μj)22σ2j)]=maxθm∑j=1n∑i=1αkjϕ(yi|θkj)m∑l=1αklϕ(yi|θkl)[logαj−logσj−(yi−μj)22σ2j]
3.1.1 计算α#
定义
cj=n∑i=1αkjϕ(yi|θkj)m∑l=1αklϕ(yi|θkl)
则对于α,优化式为
maxαm∑j=1cjlogαj
又因为m∑j=1αj=1,所以只需优化m−1个参数,上式变为:
maxα[c1c2⋯cm−1cm]⋅[logα1logα2⋮logαm−1log(1−α1−⋯−αm−1)]
对每个αj求导并等于0,得到线性方程组:
[c1+cmc1c1⋯c1c2c2+cmc2⋯c2c3c3c3+cm⋯c3⋮cm−1cm−1cm−1⋯cm−1+cm]⋅[α1α2α3⋮αm−1]=[c1c2c3⋮cm−1]
求解这个爪形线性方程组,得到
αj=cj∑mi=1ci
因为
m∑j=1cj=m∑j=1n∑i=1αkjϕ(yi|θkj)m∑l=1αklϕ(yi|θkl)=n∑i=1m∑j=1αkjϕ(yi|θkj)m∑l=1αklϕ(yi|θkl)=n∑i=11=n
解得
αj=cjn=1nn∑i=1αkjϕ(yi|θkj)m∑l=1αklϕ(yi|θkl)
3.1.2 计算σ与μ#
与α不同,它的方程组是所有αj之间联立的;而σ,μ的方程组则是σj与μj之间联立的。定义
pji=αkjϕ(yi|θkj)m∑l=1αklϕ(yi|θkl)
则对于σj,μj,优化式为
minσj,μjn∑i=1pji(logσj+(yi−μj)22σ2j)
对上式求导等于0,解得
μj=n∑i=1pjiyin∑i=1pji=n∑i=1pjiyicj=n∑i=1pjiyinαjσ2j=n∑i=1pji(yi−μj)2n∑i=1pji=n∑i=1pji(yi−μj)2cj=n∑i=1pji(yi−μj)2nαj
3.2 代码实现#
对于概率密度为P(x)=−2x+2,x∈(0,1)的随机变量,以下代码实现GMM对这一概率密度的的拟合。共10000个抽样,GMM混合了100个高斯分布。
#%%定义参数、函数、抽样 import numpy as np import matplotlib.pyplot as plt dis_num = 100 #用于拟合的分布数量 sample_num = 10000 #用于拟合的分布数量 alphas = np.random.rand(dis_num) alphas /= np.sum(alphas) mus = np.random.rand(dis_num) sigmas = np.random.rand(dis_num)**2#方差,不是标准差 samples = 1-(1-np.random.rand(sample_num))**0.5 #样本 C_pi = (2*np.pi)**0.5 dis_val = np.zeros([sample_num,dis_num]) #每个样本在每个分布成员上都有值,形成一个sample_num*dis_num的矩阵 pij = np.zeros([sample_num,dis_num]) #pij矩阵 def calc_dis_val(sample,alpha,mu,sigma,c_pi): return alpha*np.exp(-(sample[:,np.newaxis]-mu)**2/(2*sigma))/(c_pi*sigma**0.5) def calc_pij(dis_v): return dis_v / dis_v.sum(axis = 1)[:,np.newaxis] #%%优化 for i in range(1000): print(i) dis_val = calc_dis_val(samples,alphas,mus,sigmas,C_pi) pij = calc_pij(dis_val) nj = pij.sum(axis = 0) alphas_before = alphas alphas = nj / sample_num mus = (pij*samples[:,np.newaxis]).sum(axis=0)/nj sigmas = (pij*(samples[:,np.newaxis] - mus)**2 ).sum(axis=0)/nj a = np.linalg.norm(alphas_before - alphas) print(a) if a< 0.001: break #%%绘图 plt.rcParams['font.sans-serif']=['SimHei'] #用来正常显示中文标签 plt.rcParams['axes.unicode_minus']=False #用来正常显示负号 def get_dis_val(x,alpha,sigma,mu,c_pi): y = np.zeros([len(x)]) for a,s,m in zip(alpha,sigma,mu): y += a*np.exp(-(x-m)**2/(2*s))/(c_pi*s**0.5) return y def paint(alpha,sigma,mu,c_pi,samples): x = np.linspace(-1,2,500) y = get_dis_val(x,alpha,sigma,mu,c_pi) fig = plt.figure() ax = fig.add_subplot(111) ax.hist(samples,density = True,label = '抽样分布') ax.plot(x,y,label = "拟合的概率密度") ax.legend(loc = 'best') plt.show() paint(alphas,sigmas,mus,C_pi,samples)
以下是拟合结果图,有点像是核函数估计,但是完全不同:
4 EM算法的推广#
EM算法的推广是对EM算法的另一种解释,最终的结论是一样的,它可以使我们对EM算法的理解更加深入。它也解释了我在(1)式下方提出的疑问:为什么取出P(Z|Y,θk)而不是别的。
定义F函数,即所谓Free energy自由能(自由能具体是啥先不研究了):
F(˜P,θ)=E˜P(logP(Y,Z|θ))+H(˜P)=∑Z∈Z˜P(Z)logP(Y,Z|θ)−∑Z∈Z˜P(Z)log˜P(Z)
其中˜P是Z的某个概率分布(不一定是单独的分布,可能是在某个条件下的分布),E˜P表示分布˜P下的期望,H表示信息熵。
我们计算一下,对于固定的θ,什么样的˜P会使F(˜P,θ)最大。也就是找到一个函数˜Pθ,使F极大,写成优化的形式就是(这里是找函数而不是找参数哦,理解上可能要用到泛函分析变分法的内容):
max˜P∑Z∈Z˜P(Z)logP(Y,Z|θ)−∑Z∈Z˜P(Z)log˜P(Z)s.t.∑Z∈Z˜P(Z)=1
拉格朗日函数(拉格朗日对偶性,点击链接)为:
L=∑Z∈Z˜P(Z)logP(Y,Z|θ)−∑Z∈Z˜P(Z)log˜P(Z)+λ(1−∑Z∈Z˜P(Z))
因为每个˜P(Z)之间都是求和,没有其它其它诸如乘积的操作,所以可以直接令L对某个˜P(Z)求导等于0来计算极值:
∂L∂˜P(Z)=logP(Y,Z|θ)−log˜P(Z)−1−λ=0
于是可以推出:
P(Y,Z|θ)=e1+λ˜P(Z)
又由约束∑Z∈Z˜P(Z)=1:
P(Y|θ)=e1+λ
于是得到
˜Pθ(Z)=P(Z|Y,θ)
代回F(˜P,θ),得到
F(˜Pθ,θ)=∑Z∈ZP(Z|Y,θ)logP(Y,Z|θ)−∑Z∈ZP(Z|Y,θ)logP(Z|Y,θ)=∑Z∈ZP(Z|Y,θ)logP(Y,Z|θ)P(Z|Y,θ)=logP(Y|θ)
也就是说,对F关于˜P进行最大化后,F就是待求分布的对数似然;然后再关于θ最大化,也就算得了最终要估计的参数ˆθ。所以,EM算法也可以解释为F的极大-极大算法。优化结果(8)式也解释了我之前在(1)式下方的提问。
那么,怎么使用F函数进行估计呢?还是要用迭代来算,迭代方式是和前面介绍的一样的(懒得记录了,统计学习方法上直接看吧)。实际上,F函数的方法只是提供了EM算法的另一种解释,具体方法上并没有提升之处。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 【自荐】一款简洁、开源的在线白板工具 Drawnix