Cache-Augmented Inbatch Importance Resampling for Training Recommender Retriever

Chen J., Lian D., Li Y., Wang B., Zheng K. and Chen E. Cache-augmented inbatch importance resampling for training recommender retriever. In Advances in Neural Information Processing Systems (NIPS), 2022.

作者通过 batch 内的一个重采样来逼近真是的一个分布, 加上在历史'重要'样本中进行重采样来强调那些 hard 的负样本.

符号说明

  • \(\mathcal{I}, |\mathcal{I}| = N\), item 的集合;
  • \(\{\bm{q}_i \in \mathbb{R}^{d_u}\}_{i=1}^M\), queries;
  • \(\{\bm{e}_i \in \mathbb{R}^{d_i}\}_{i=1}^N\), items;
  • \(\phi_{Q}: \mathbb{R}^{d_u} \rightarrow \mathbb{R}^{d_k}\), 将 query 映射为 k 维特征;
  • \(\phi_{I}: \mathbb{R}^{d_i} \rightarrow \mathbb{R}^{d_k}\), 将 item 映射为 k 维特征;
  • \(s(u, i) = \langle \phi_Q(\bm{q}_u), \phi_I(\bm{e}_i) \rangle\), score;

启发

  1. 有了 \(s(u, i), i \in \mathcal{I}\), 我们可以估计用户 \(u\) 在一堆 \(\mathcal{I}\) 中选择 item \(i\) 的概率:

    \[P(i|u) = \frac{\exp(s(u, i))}{\sum_{j \in \mathcal{I}} \exp(s(u, j))}; \]

  2. 并通过如下损失进行训练:

    \[\tag{1} \mathcal{L}_{\text{softmax}} (\mathcal{D}, \Theta) = -\frac{1}{|\mathcal{D}|} \sum_{(u, i) \in \mathcal{D}} \log P(i|u); \]

  3. 但是处于计算复杂度的限制, 我们通常会选择如下的一个方案:

    \[\tag{2} \mathcal{L}_{\text{sampled\_softmax}} (\mathcal{D}, \Theta) = -\frac{1}{|\mathcal{D}|} \sum_{(u, i) \in \mathcal{D}} \log \frac{\exp(s'(u, i))}{\sum_{j \in \mathcal{S_u}} \exp(s'(u, j))} \]

    其中 \(\mathcal{S}\) 是一个采样的子集, \(s'(u, i) := s(u,i) - \log p(i|u)\) 是经过校正后的 score, 至于为什么校正, 请看 here.

本文方法

所以本文就是讨论如何使得 (2) 逼近 (1) 甚至做的更好.

BIR (inbatch importance resampling)

算法如下:

  1. 获得一个当前的 batch \(B\), 计算如下的权重

    \[w(i|u) = \frac{\exp(s(u, i) - \log pop(i))}{\sum_{j \in B} \exp(s(u, j) - \log pop(j))}; \]

  2. \(\{w(i|u): i \in B\}\) 中重新采样得到集合 \(\mathcal{R}_u\), 然后通过如下损失进行训练:

    \[\tag{3} \mathcal{L}_{\text{BIR}} (B, \Theta) = -\frac{1}{|B|} \sum_{(u, i) \in B} \log \frac{\exp(s(u, i))}{\sum_{j \in \mathcal{R_u}} \exp(s(u, j))}. \]

作者说当 \(|B| \rightarrow +\infty\) 的时候, (3) 等价于 (1) 的一个 mini-batch, 即

\[P(i \in \mathcal{R}_u) \approx P(i|u), \]

我感觉这个证明应该是错的. 举个反例, 假设 \(B = \mathcal{D}\), 此时

\[P(i \in \mathcal{R}_u) = 1 - [1 - w(i|u)]^{|\mathcal{D}|}. \]

有可能是我误会作者的证明或者采样方式了 (感觉作者写这篇文章比较赶, 里面有很多符号的错误), 如果有谁知道麻烦告诉我一声.

XIR (Cache-Augmented Resampling)

  • 作者认为, 如果一个样本被频繁采样, 那么它应该是很重要的才对, 所以这里作者保留了这样一部分样本: \(\mathcal{C}\);

  • 每次优化的时候, 分别从 \(\mathcal{C}\)\(\mathcal{R}_u\) 中采样并优化, 如下所示:

    \[\tag{4} \mathcal{L}_{\chi\text{IR}} (\mathcal{D}, \Theta) = -\lambda \sum_{(u, i) \in B} \log \frac{\exp(s(u, i))}{\sum_{j \in \mathcal{\mathcal{K}_u}} \exp(s(u, j))} -(1 - \lambda)\sum_{(u, i) \in B} \log \frac{\exp(s(u, i))}{\sum_{j \in \mathcal{R_u}} \exp(s(u, j))}. \]

注: (3) (4) 中作者都是不带 \(-\log\) 的, 应该是笔误吧.

posted @ 2022-09-21 15:19  馒头and花卷  阅读(116)  评论(0编辑  收藏  举报