EM算法求解三枚硬币模型的详细推导
问题原型
假设有三枚硬币,记为A,B,C,这三枚硬币出现正面的概率分别是\(\pi\),\(p\)和\(q\)。在掷硬币实验过程中,先掷硬币A,如果其结果为正面,则选择硬币B,反面则选择C;然后掷选中的硬币,记录其出现的结果(正面记为1,反面记为0)。独立地重复\(n\)次实验,我们得到一组观测结果,比如说\(1,1,0,1,0,0,1,0,1,1\)。假设只能观测到最终所掷硬币(B或C)的结果,而不知道硬币A的结果(即不知道每次掷的是哪枚硬币),求解三枚硬币出现正面的概率,即求解\(\pi\),\(p\)和\(q\)。
使用EM算法求解
\(E\)步:假设\(\theta^{(t)}\)为第\(t\)次迭代参数\(\theta\)的估计值,在第\(t+1\)次迭代的\(E\)步,计算\(Q\)函数。记\(Y=(y_{1},\dots,y_{n})\)为观测序列,\(Z=(z_{1},\dots,z_{n})\)为隐变量序列,其中\(z_{i}\)表示第\(i\)次实验中硬币A的结果。由于各次实验相互独立,有\(P(Y,Z|\theta)=\prod_{i=1}^{n}P(y_{i},z_{i}|\theta)\),\(P(Z|Y,\theta^{(t)})=\prod_{i=1}^{n}P(z_{i}|y_{i},\theta^{(t)})\)。把对数中的连乘展开为求和,再利用\(\sum_{z_{j}}P(z_{j}|y_{j},\theta^{(t)})=1\)把与第\(i\)项无关的隐变量求和消去,得到(最后一行以\(Z\)表示单次实验的隐变量):
\[\begin{align}
Q(\theta,\theta^{(t)})&=\sum_{Z}{P(Z|Y,\theta^{(t)})\log{P(Y,Z|\theta)}}\\
&=\sum_{z_{1},\dots,z_{n}}\Big(\prod_{j=1}^{n}P(z_{j}|y_{j},\theta^{(t)})\Big)\sum_{i=1}^{n}\log{P(y_{i},z_{i}|\theta)}\\
&=\sum_{i=1}^{n}\sum_{z_{i}}{P(z_{i}|y_{i},\theta^{(t)})\log{P(y_{i},z_{i}|\theta)}}\\
&=\sum_{i=1}^{n}\{P(Z=0|y_{i},\theta^{(t)})\log{P(y_{i},Z=0|\theta)}+P(Z=1|y_{i},\theta^{(t)})\log{P(y_{i},Z=1|\theta)}\}
\end{align}
\]
上式中,\(Z=0\)表示硬币A掷出正面(从而掷硬币B),\(Z=1\)表示硬币A掷出反面(从而掷硬币C),因此\(P(y_{i},Z=0|\theta)=\pi p^{y_{i}}(1-p)^{1-y_{i}}\),\(P(y_{i},Z=1|\theta)=(1-\pi)q^{y_{i}}(1-q)^{1-y_{i}}\)。
下面需要求\(P(Z=0|y_{i},\theta^{(t)})\)和\(P(Z=1|y_{i},\theta^{(t)})\)
根据联合概率、边缘概率与条件概率的关系得到:
\[\begin{align}
P(Z|y_{i},\theta^{(t)})=\frac{P(Z,y_{i}|\theta^{(t)})}{P(y_{i}|\theta^{(t)})}
\end{align}
\]
另外,由全概率公式有:
\[\begin{align}
P(y_{i}|\theta^{(t)})=\pi^{(t)} (p^{(t)})^{y_{i}}(1-p^{(t)})^{1-y_{i}}+(1-\pi^{(t)})(q^{(t)})^{y_{i}}(1-q^{(t)})^{1-y_{i}}
\end{align}
\]
所以可以得到:
\[\begin{align}
P(Z=0|y_{i},\theta^{(t)})&=\frac{\pi^{(t)} (p^{(t)})^{y_{i}}(1-p^{(t)})^{1-y_{i}}}{\pi^{(t)} (p^{(t)})^{y_{i}}(1-p^{(t)})^{1-y_{i}}+(1-\pi^{(t)})(q^{(t)})^{y_{i}}(1-q^{(t)})^{1-y_{i}}}=\mu_{i}^{(t)}\\
P(Z=1|y_{i},\theta^{(t)})&=1-\mu_{i}^{(t)}
\end{align}
\]
代入到\(Q(\theta,\theta^{(t)})\)中,得到:
\[\begin{align}
Q(\theta,\theta^{(t)})&= \sum_{i=1}^{n}{\mu_{i}^{(t)}\log{P(y_{i},Z=0|\theta)}+(1-\mu_{i}^{(t)})\log{P(y_{i},Z=1|\theta)}}\\
&= \sum_{i=1}^{n}{\{\mu_{i}^{(t)}\log{\pi p^{y_{i}}(1-p)^{1-y_{i}}}+(1-\mu_{i}^{(t)})\log {(1-\pi)q^{y_{i}}(1-q)^{1-y_{i}}}\}}
\end{align}
\]
\(M\)步
在\(M\)步中,求使\(Q\)函数极大化的参数,作为第\(t+1\)次迭代的估计值:
\[\begin{align}
\theta^{(t+1)}=\mathop{\arg\max}_{\theta}Q(\theta, \theta^{(t)})
\end{align}
\]
分别求\(Q(\theta, \theta^{(t)})\)关于\(\pi\),\(p\),\(q\)的偏导数并令其为0,即可求出\(\pi^{(t+1)}\),\(p^{(t+1)}\)和\(q^{(t+1)}\):
\[\begin{align}
\frac{\partial Q(\theta,\theta^{(t)})}{\partial \pi}&=\sum_{i=1}^{n}{\{\mu_{i}^{(t)}\frac{p^{y_{i}}(1-p)^{1-y_{i}}}{\pi p^{y_{i}}(1-p)^{1-y_{i}}}-(1-\mu_{i}^{(t)})\frac{q^{y_{i}}(1-q)^{1-y_{i}}}{(1-\pi)q^{y_{i}}(1-q)^{1-y_{i}}}\}}\\
&=\sum_{i=1}^{n}{\{\frac{\mu_{i}^{(t)}}{\pi}-\frac{1-\mu_{i}^{(t)}}{1-\pi}\}}\\
&=\sum_{i=1}^{n}{\{\frac{\mu_{i}^{(t)}-\pi}{\pi(1-\pi)}\}}\\
&\mathop{=}^{let}0\\
&\Rightarrow \sum_{i=1}^{n}\frac{\mu_{i}^{(t)}}{\pi(1-\pi)}=\sum_{i=1}^{n}\frac{\pi}{\pi(1-\pi)}\\
&\Rightarrow \pi^{(t+1)}=\frac{1}{n}\sum_{i=1}^{n}\mu_{i}^{(t)}
\end{align}
\]
\[\begin{align}
\frac{\partial Q(\theta, \theta^{(t)})}{\partial p}&=\sum_{i=1}^{n}{\{\mu_{i}^{(t)}\frac{\pi y_{i}p^{y_{i}-1}(1-p)^{1-y_{i}}+\pi p^{y_{i}}(1-y_{i})(1-p)^{-y_{i}}(-1)}{\pi p^{y_{i}}(1-p)^{1-y_{i}}}\}}\\
&=\sum_{i=1}^{n}{\{\mu_{i}^{(t)}\frac{\pi p^{y_{i}}(1-p)^{1-y_{i}}[y_{i}p^{-1}-(1-y_{i})(1-p)^{-1}]}{\pi p^{y_{i}}(1-p)^{1-y_{i}}}\}}\\
&=\sum_{i=1}^{n}{\{\mu_{i}^{(t)}\frac{y_{i}-p}{p(1-p)}\}}\\
&\mathop{=}^{let}0\\
&\Rightarrow p^{(t+1)}=\frac{\sum_{i=1}^{n}y_{i}\mu_{i}^{(t)}}{\sum_{i=1}^{n}\mu_{i}^{(t)}}
\end{align}
\]
\[\begin{align}
\frac{\partial Q(\theta, \theta^{(t)})}{\partial q}&=\sum_{i=1}^{n}{\{(1-\mu_{i}^{(t)})[y_{i}q^{-1}-(1-y_{i})(1-q)^{-1}]\}}\\
&=\sum_{i=1}^{n}{\{(1-\mu_{i}^{(t)})\frac{y_{i}-q}{q(1-q)}\}}\\
&\mathop{=}^{let}0\\
&\Rightarrow q^{(t+1)}=\frac{\sum_{i=1}^{n}(1-\mu_{i}^{(t)})y_{i}}{\sum_{i=1}^{n}(1-\mu_{i}^{(t)})}
\end{align}
\]
代码实现
主函数
int main()
{
shared_ptr<ThreeCoin> em;
em.reset(new ThreeCoin());
em->input(vector<int>{ 1, 1, 0, 1, 0, 0, 1, 0, 1, 1 }, 0.4, 0.6, 0.7);
em->calc(1e-6, 2);
em->printProbability();
system("pause");
}
具体实现
/// Three-coin mixture model fitted with the EM algorithm.
///
/// Coin A (heads probability pi) decides which coin is tossed next:
/// heads -> coin B (heads probability p), tails -> coin C (heads probability q).
/// Only the B/C results are observed (non-zero = heads, 0 = tails).
///
/// Typical use: input() -> calc() -> printProbability().
struct ThreeCoin
{
	// Initial guesses for the heads probabilities of coins A, B and C.
	float init_pi = 0.f;
	float init_p = 0.f;
	float init_q = 0.f;
	// Estimates written by calc(); 0 until calc() has run.
	float final_pi = 0.f;
	float final_p = 0.f;
	float final_q = 0.f;
	// Number of observations used by calc() (set by input() to observes.size()).
	int n = 0;
	// Observed toss results.
	std::vector<int> observes;

	void input(const std::vector<int>& obs, float pi, float p, float q);
	void printProbability() const;
	void calc(float epsilon, int max_iter);
};

/// Stores a copy of the observations and the initial parameter guesses.
/// @param obs observed results (non-zero = heads, 0 = tails). Temporaries are accepted.
/// @param pi  initial heads probability of coin A.
/// @param p   initial heads probability of coin B.
/// @param q   initial heads probability of coin C.
void ThreeCoin::input(const std::vector<int>& obs, float pi, float p, float q)
{
	init_pi = pi;
	init_p = p;
	init_q = q;
	observes = obs;
	n = static_cast<int>(observes.size());
}

/// Prints the initial guesses followed by the estimates from calc(),
/// each in fixed notation with 5 decimals.
void ThreeCoin::printProbability() const
{
	std::cout << std::fixed << std::setprecision(5);
	std::cout << "initialized pi is: " << init_pi << std::endl;
	std::cout << "initialized p is: " << init_p << std::endl;
	std::cout << "initialized q is: " << init_q << std::endl;
	std::cout << "predicted pi is: " << final_pi << std::endl;
	std::cout << "predicted p is: " << final_p << std::endl;
	std::cout << "predicted q is: " << final_q << std::endl;
}

/// Runs EM from (init_pi, init_p, init_q) and stores the result in
/// final_pi / final_p / final_q.
///
/// E-step: for each observation y_i compute the posterior that coin B was tossed,
///   mu_i = pi*P_B(y_i) / (pi*P_B(y_i) + (1-pi)*P_C(y_i)).
/// M-step:
///   pi = sum(mu_i)/n,  p = sum(mu_i*y_i)/sum(mu_i),  q = sum((1-mu_i)*y_i)/sum(1-mu_i).
///
/// @param epsilon  stop after an iteration in which no parameter changed by
///                 epsilon or more.
/// @param max_iter maximum number of EM iterations. If it is <= 0, no iteration
///                 runs and the initial guesses are returned unchanged.
void ThreeCoin::calc(float epsilon, int max_iter)
{
	float cur_pi = init_pi;
	float cur_p = init_p;
	float cur_q = init_q;

	// With no observations every M-step ratio is 0/0; keep the initial guesses
	// rather than producing NaNs.
	if (n > 0)
	{
		for (int iter = 0; iter < max_iter; ++iter)
		{
			float sum_mu = 0.f;        // sum of mu_i
			float sum_mu_y = 0.f;      // sum of mu_i * y_i
			float sum_not_mu_y = 0.f;  // sum of (1 - mu_i) * y_i
			for (int i = 0; i < n; ++i)
			{
				// Normalize so the likelihood and the sufficient statistics
				// agree for observations other than 0/1.
				const int y = observes[i] != 0 ? 1 : 0;
				const float from_b = cur_pi * (y ? cur_p : 1.f - cur_p);
				const float from_c = (1.f - cur_pi) * (y ? cur_q : 1.f - cur_q);
				const float evidence = from_b + from_c;
				// If y_i is impossible under the current parameters, the posterior is
				// 0/0; fall back to the prior probability of choosing coin B.
				const float mu = evidence > 0.f ? from_b / evidence : cur_pi;
				sum_mu += mu;
				sum_mu_y += mu * y;
				sum_not_mu_y += (1.f - mu) * y;
			}
			const float sum_not_mu = n - sum_mu;

			// A coin that is (estimated to be) never tossed receives no data: its
			// probability is unidentifiable, so keep the previous estimate instead
			// of dividing by zero.
			const float next_p = sum_mu > 0.f ? sum_mu_y / sum_mu : cur_p;
			const float next_q = sum_not_mu > 0.f ? sum_not_mu_y / sum_not_mu : cur_q;
			const float next_pi = sum_mu / n;

			const float change = std::max(std::abs(cur_pi - next_pi),
				std::max(std::abs(cur_p - next_p), std::abs(cur_q - next_q)));
			cur_pi = next_pi;
			cur_p = next_p;
			cur_q = next_q;

			// Written as !(change >= epsilon) so that a NaN change also ends the loop.
			if (!(change >= epsilon))
				break;
		}
	}

	final_pi = cur_pi;
	final_p = cur_p;
	final_q = cur_q;
}