对比学习损失函数NCE
对于特征相似还有一种理解视角, 就是互信息最大化, 也就是同一个物品不同视角下的特征之间的互信息应该最大化. 这一节将要推导的NCE和下一节将要推导的InfoNCE就是互信息的一种近似估计办法(也叫JSD估计), 为什么不直接计算下面公式展示的互信息了, 因为互信息的计算太过于复杂, 只能近似估计.
\[I(X,Y)=\sum_{x\in X}\sum_{y\in Y}p(x,y)\log\frac{p(x,y)}{p(x)p(y)}
\tag{Eq.1}
\]
假设我们有一系列交互的序列\(s=\{x_1,x_2,...,x_m\}\) 其中\(x_i\in V\) 是序列数据,我们对序列进行建模。根据链式规则我们可以计算出这个序列数据的概率如下所示:
\[p(x_1,x_2,...,x_m)=\prod_{i=1}^mp(x_i|x1,...,x_{i-1})=\prod_{i=1}^mp(x_i|c_i)
\tag{Eq.2}
\]
其中假定\(x_i\)会受到之前电机的序列\(x_{i-1},...,x_1\)的影响,并且把之前电机的序列以上下文\(c\)简化表示。为了公式简洁,我们将公式中的下标\(i\)去除,则\(p(x|c)\)可以通过如下贝叶斯计算公式获得,分子代表上下文\(c\)和物品\(x\)的联合概率公式,分母是一个配分函数。分母作为配分函数也有很高的计算复杂度的问题,在机器学习中一般会使用变分推断等方法来解决。我们希望可以通过神经网络来学习,我们用\(u_\theta\)代表神经网络,其中网络参数为\(\theta\),公式如下:
\[p(x|c)=\frac{p(x,c)}{p(c)}=\frac{p(x,c)}{\sum_xp(c|x)p(x)}\Rightarrow\frac{u_\theta(x,c)}{\sum_{x'}\in Vu_\theta(x',c)}
\tag{Eq.3}
\]
NCE为了解决配分函数难求的问题配分函数中的数据二分类, 分为data和noise, data类别$ D=1\(, 噪声类别\)D=0$, data就是用户交互过的物品, noise就是用户没有交互的负样本, 作为噪声一般假设是均匀分布(当然你可以假设它是一个其他分布也通过神经网络学习). 这样我们可以将配分函数的分布分成数据分布和噪声分布。
我们将数据分布记为\(\tilde p(x|c)\), 噪声分布为\(q(x)\), 噪声一般为均匀分布, 我们可以先验计算正样本有$ k_d$ 个, 负样本有$ k_n$ 个, 上述这些变量组合成为混合分布\(p(x│c)\). 这一段描述可以公式表示成下面四个先验公式:
\[p(D=1)=\frac{k_d}{k_d+k_n}
\]
\[p(D=0)=\frac{k_n}{k_d+k_n}
\]
\[\tilde p(x|c)=p(x|D=1, c)
\]
\[q(x) = p(x|D=0, c)
\]
根据上面四个先验公式, 我们可以求出后验概率, 即在给定上下文和物品的条件下, \(D=1\)的概率,公式如下:
\[\begin{aligned}
p(D=1|x,c) &=\frac{p(D=1)\tilde p(x|c)}{p(D=1)\tilde p(x|c)+p(D=0)q(x)} \\
&= \frac{\frac{k_d}{k_d+k_n}\tilde p(x|c)}{\frac{k_d}{k_d+k_n}\tilde p(x|c)+\frac{k_n}{k_d+k_n} q(x)} \\
&= \frac{\tilde p(x|c)}{\tilde p(x|c) + \frac{k_n}{k_d}q(x)}
\end{aligned}
\tag{Eq.4}
\]
同理,我们可以得到noise的计算概率,二分类的情况下和为1,公式如下:
\[p(D=0|x,c) = \frac{\frac{k_n}{k_d}q(x)}{\tilde p(x|c) + \frac{k_n}{k_d}q(x)}
\tag{Eq.5}
\]
接着我们整理两个狮子,设\(k=\frac{k_n}{k_d}\), 假设训练的模型为\(u_\theta(\cdot)\)函数用于学习数据分布\(\tilde p(x|c)\),则模型计算概率如下
\[p_\theta(D=0|x,c) = \frac{\frac{k_n}{k_d}q(x)}{\tilde p(x|c) + \frac{k_n}{k_d}q(x)}=\frac{kq(x)}{u_\theta(x,c)+kq(x)}
\tag{Eq.6}
\]
\[p_\theta(D=1|x,c) = \frac{\tilde p(x|c)}{\tilde p(x|c) + \frac{k_n}{k_d}q(x)}=\frac{u_\theta(x,c)}{u_\theta(x,c)+kq(x)}
\tag{Eq.7}
\]
至此我们得到了后验概率的最终表示, 又由于我们将配分函数分成两类data和noise, 作为二分类问题我们可以使用交叉熵作为目标函数, 加上负号也就是我们的NCE 损失函数,公式如下:
\[\begin{aligned}
\mathcal J_{NCE}=\sum_{t=1}^{k_d+k_n}\left[ D_t\log p_\theta(D=1|x_t,c_t)+(1-D_t)\log p_\theta(D=0|x_t,c_t)\right]
\end{aligned}
\tag{Eq.8}
\]
NCE Loss性质
那么二分类能否胜任计算配分函数的任务呢? 当\(k=\frac{k_n}{k_d}\)足够大的时候就可以, 即负样本足够多的时候就可以. 推导如下我们可以整理(Eq.8)的目标函数:
\[\begin{aligned}
\mathcal J_{NCE}&=\sum_{t=1}^{k_d+k_n}\left[ D_t\log p_\theta(D=1|x_t,c_t)+(1-D_t)\log p_\theta(D=0|x_t,c_t)\right] \\
&=\sum_{t=1}^{k_d}\log p_\theta(D=1|x_t,c_t)+\sum_{t=1}^{k_n}\log p_\theta(D=0|x_t,c_t)\\
&=\sum_{t=1}^{k_d}\log \frac{u_\theta(x_t,c_t)}{u_\theta(x_t,c_t)+kq(x_t)}+\sum_{t=1}^{kn}\log \frac{kq(x_t)}{u_\theta(x_t,c_t)+kq(x_t)} \\
&\xrightarrow{/k_d}\mathbb E_{x_t\sim\tilde p(x|c)}\log\frac{u_\theta(x_t,c_t)}{u_\theta(x_t,c_t)+kq(x_t)}+k\mathbb E_{x_t\sim q(x)}\log \frac{kq(x_t)}{u_\theta(x_t,c_t)+kq(x_t)}
\end{aligned}
\tag{Eq.9}
\]
我们计算\(\mathcal J_{NCE}\)对网络参数\(\theta\)的梯度,计算结果如下:
\[\frac{\partial\mathcal J_{NCE}}{\partial\theta}=\sum_{x}\left[\frac{kq(x)}{p_\theta(x|c)+kq(x)}\left(\tilde p\left(x|c\right)-p_\theta\left(x|c\right) \frac{\partial}{\partial\theta}\log u_\theta(x,c)) \right)\right]
\tag{Eq.10}
\]
当负样本足够多,k足够大,趋于无穷时的极限如下
\[\lim\limits_{k\rightarrow\infty}\frac{\partial\mathcal J_{NCE}}{\partial\theta}=\sum_x\left[\left(\tilde p\left(x|c\right)-p_\theta\left(x|c\right)\right)\frac{\partial}{\partial\theta}\log u_\theta(x,c)\right]
\tag{Eq.11}
\]
我们可以计算(Eq.2)和(Eq.3)的原始问题的梯度大小,通过极大似然估计可以计算出这个概率,结果如下:
\[\mathcal J_{MLE}=\arg\max_\theta\sum_{x\sim\tilde p(x|c)}\log\frac{u_\theta(x,c)}{\sum_{x'\in V}u_\theta(x',c)}
\tag{Eq.12}
\]
则原始问题对网络参数\(\theta\)的梯度计算结果如下
\[\frac{\partial\mathcal J_{MLE}}{\partial\theta}=\sum_{x}\left[\left(\tilde p\left(x|c\right)-p_\theta\left(x|c\right) \frac{\partial}{\partial\theta}\log u_\theta(x,c)) \right)\right]
\tag{Eq.13}
\]
我们可以对比(Eq.10)和(Eq.13)两个结果,当\(k\rightarrow\infty\)时,NCE Loss与原始问题的极大似然估计结果一致,所以也说明了负样本数量在对比学习中的重要性。
More
NCE的升级版是InfoNCE,将二分类思想推广到多类别的情况,具体推荐阅读InfoNCE损失函数