DiffuSeq: Sequence to Sequence Text Generation with Diffusion Models
概
本文提出了一种用于 Seq2Seq 的不需要 classifier 引导的扩散模型, 且是在连续空间上讨论的.
虽然方法看起来很简单, 但是感觉很容易 work 和推广.
符号说明
- \(\mathbf{z}_0 \sim q(\mathbf{z})\), a real-world data distribution;
- \(\mathbf{z}_T \sim \mathcal{N}(\bm{0}, \mathbf{I})\), Gaussian noise;
- \(q(\mathbf{z}_t | \mathbf{z}_{t-1}) = \mathcal{N}(\mathbf{z}_t; \sqrt{1 - \beta_t} \mathbf{z}_{t-1}, \beta_t \mathbf{I}), \: t \in [1,2, \ldots, T]\);
- \(f_{\theta}\), a diffusion model;
- \(\mathbf{w}^x = [w_1^x, \ldots, w_m^x]\), m-length soure sequence (离散的);
- \(\mathbf{w}^y = [w_1^y, \ldots, w_n^y]\), n-length soure sequence (离散的).
流程
-
首先利用获取词的 embeddings:
\[\mathbf{z}_0 = \text{Emb}(\mathbf{w}) = [\text{Emb}(w_1), \text{Emb}(w_2), \ldots], \]这一步实际上是相当于构建从离散空间到连续空间的一个映射:
\[q_{\phi}(\mathbf{z}_0|\mathbf{w}) = \delta_{\text{Emb}(\mathbf{w})}(\mathbf{z}_0). \] -
因为整个流程设计两个部分: source \(x\), target \(y\), 不妨令
\[\mathbf{x}_0 = \text{Emb}(\mathbf{w}^x) = [\text{Emb}(w_1^x), \text{Emb}(w_2^x), \ldots], \\ \mathbf{y}_0 = \text{Emb}(\mathbf{w}^y) = [\text{Emb}(w_1^y), \text{Emb}(w_2^y), \ldots]. \]于是
\[\mathbf{z}_0 = \mathbf{x}_0 \oplus \mathbf{y}_0. \]类似的之后的 \(\mathbf{z}_t\) 均可以分为 source 和 target 两部分, 即
\[\mathbf{z}_t = \mathbf{x}_t \oplus \mathbf{y}_t. \] -
前向过程: 如上图所示:
- 根据 \(q_{\phi}(\mathbf{z}_0|\mathbf{w})\) 得到 \(\mathbf{z}_0\) (这一步实际上是确定的);
- 此时我们依旧在连续空间中了, 故我们可以使用一般的高斯分布来加噪, 即:\[\mathbf{z}_t' \sim q(\mathbf{z}_t | \mathbf{z}_{t-1}) = \mathcal{N}(\mathbf{z}_t; \sqrt{1 - \beta_t} \mathbf{z}_{t-1}, \beta_t \mathbf{I}). \]但是特别地, 我们只对 target 部分加噪:\[\mathbf{z}_t = \mathbf{x}_{0} \oplus \mathbf{y}_t'. \]
-
反向过程: 同样如上图所示:
-
从标准的高斯分布中采样 \(\mathbf{z}_T'\), 并令
\[\mathbf{z}_T = \mathbf{x}_0 \oplus \mathbf{y}_T'. \] -
根据如下分布进行反向传递:
\[\mathbf{z}_{t-1}' \sim \mathcal{N}(\mathbf{z}_{t-1}; \mu_{\theta}(\mathbf{z}, t), \sigma_{\theta}(\mathbf{z}_t, t)), \\ \mathbf{z}_{t-1} = \mathbf{x}_0 \oplus \mathbf{y}_{t-1}', \\ t \ge 2. \]
-
-
最后的损失为如下:
-
需要注意的是, 其中 \(q_{\phi}(\mathbf{z}_0|\mathbf{w}^{x \oplus y})\) 本身是一个确定的过程, 所以是不提供导数的, 可以省略. 整体的推导其实普通的 VLB 没什么差别, \(\mathcal{L}_{round}\) 也只是原来的损失一部分, 只是被作者单拎了出来. 不过也有道理, 因为但看它, 其实就是希望训练一个分类网络, 将 \(\mathbf{z}_0\) 映射回词.
-
不过作者最后用的也不是上面的损失, 而是一个简化的版本 (即把原先的系数给去掉后的结果):
\[\mathcal{L}_{\text{VLB}} = [\sum_{t=2}^T \|\mathbf{z}_0 - f_{\theta}(\mathbf{z}, t)\|^2 + \|\text{Emb}(\mathbf{w}^{x \oplus y}) - f_{\theta}(\mathbf{z}_1, 1)\|^2 - \log p_{\theta} (\mathbf{w}^{x \oplus y} | \mathbf{z}_0)] \\ \Rightarrow [\sum_{t=2}^T \|\mathbf{y}_0 - \tilde{f}_{\theta}(\mathbf{z}, t)\|^2 + \|\text{Emb}(\mathbf{w}^{y}) - \tilde{f}_{\theta}(\mathbf{z}_1, 1)\|^2 - \log p_{\theta} (\mathbf{w}^{x \oplus y} | \mathbf{z}_0)]. \] -
\(f, \tilde{f}\) 就是对 \(\mathbf{z}_t, \mathbf{y}_t\) 的直接拟合, 是另一种损失的写法. 具体看 here