DiffuSeq: Sequence to Sequence Text Generation with Diffusion Models

Gong S., Li M., Feng J., Wu Z. and Kong L. DiffuSeq: Sequence to sequence text generation with diffusion models. In International Conference on Learning Representations (ICLR), 2023

本文提出了一种用于 Seq2Seq 的不需要 classifier 引导的扩散模型, 且是在连续空间上讨论的.
虽然方法看起来很简单, 但是感觉很容易 work 和推广.

符号说明

  • \(\mathbf{z}_0 \sim q(\mathbf{z})\), a real-world data distribution;
  • \(\mathbf{z}_T \sim \mathcal{N}(\bm{0}, \mathbf{I})\), Gaussian noise;
  • \(q(\mathbf{z}_t | \mathbf{z}_{t-1}) = \mathcal{N}(\mathbf{z}_t; \sqrt{1 - \beta_t} \mathbf{z}_{t-1}, \beta_t \mathbf{I}), \: t \in [1,2, \ldots, T]\);
  • \(f_{\theta}\), a diffusion model;
  • \(\mathbf{w}^x = [w_1^x, \ldots, w_m^x]\), m-length soure sequence (离散的);
  • \(\mathbf{w}^y = [w_1^y, \ldots, w_n^y]\), n-length soure sequence (离散的).

流程

  • 首先利用获取词的 embeddings:

    \[\mathbf{z}_0 = \text{Emb}(\mathbf{w}) = [\text{Emb}(w_1), \text{Emb}(w_2), \ldots], \]

    这一步实际上是相当于构建从离散空间到连续空间的一个映射:

    \[q_{\phi}(\mathbf{z}_0|\mathbf{w}) = \delta_{\text{Emb}(\mathbf{w})}(\mathbf{z}_0). \]

  • 因为整个流程设计两个部分: source \(x\), target \(y\), 不妨令

    \[\mathbf{x}_0 = \text{Emb}(\mathbf{w}^x) = [\text{Emb}(w_1^x), \text{Emb}(w_2^x), \ldots], \\ \mathbf{y}_0 = \text{Emb}(\mathbf{w}^y) = [\text{Emb}(w_1^y), \text{Emb}(w_2^y), \ldots]. \]

    于是

    \[\mathbf{z}_0 = \mathbf{x}_0 \oplus \mathbf{y}_0. \]

    类似的之后的 \(\mathbf{z}_t\) 均可以分为 source 和 target 两部分, 即

    \[\mathbf{z}_t = \mathbf{x}_t \oplus \mathbf{y}_t. \]

  • 前向过程: 如上图所示:

    1. 根据 \(q_{\phi}(\mathbf{z}_0|\mathbf{w})\) 得到 \(\mathbf{z}_0\) (这一步实际上是确定的);
    2. 此时我们依旧在连续空间中了, 故我们可以使用一般的高斯分布来加噪, 即:

      \[\mathbf{z}_t' \sim q(\mathbf{z}_t | \mathbf{z}_{t-1}) = \mathcal{N}(\mathbf{z}_t; \sqrt{1 - \beta_t} \mathbf{z}_{t-1}, \beta_t \mathbf{I}). \]

      但是特别地, 我们只对 target 部分加噪:

      \[\mathbf{z}_t = \mathbf{x}_{0} \oplus \mathbf{y}_t'. \]

  • 反向过程: 同样如上图所示:

    1. 从标准的高斯分布中采样 \(\mathbf{z}_T'\), 并令

      \[\mathbf{z}_T = \mathbf{x}_0 \oplus \mathbf{y}_T'. \]

    2. 根据如下分布进行反向传递:

      \[\mathbf{z}_{t-1}' \sim \mathcal{N}(\mathbf{z}_{t-1}; \mu_{\theta}(\mathbf{z}, t), \sigma_{\theta}(\mathbf{z}_t, t)), \\ \mathbf{z}_{t-1} = \mathbf{x}_0 \oplus \mathbf{y}_{t-1}', \\ t \ge 2. \]

  • 最后的损失为如下:

  • 需要注意的是, 其中 \(q_{\phi}(\mathbf{z}_0|\mathbf{w}^{x \oplus y})\) 本身是一个确定的过程, 所以是不提供导数的, 可以省略. 整体的推导其实普通的 VLB 没什么差别, \(\mathcal{L}_{round}\) 也只是原来的损失一部分, 只是被作者单拎了出来. 不过也有道理, 因为但看它, 其实就是希望训练一个分类网络, 将 \(\mathbf{z}_0\) 映射回词.

  • 不过作者最后用的也不是上面的损失, 而是一个简化的版本 (即把原先的系数给去掉后的结果):

    \[\mathcal{L}_{\text{VLB}} = [\sum_{t=2}^T \|\mathbf{z}_0 - f_{\theta}(\mathbf{z}, t)\|^2 + \|\text{Emb}(\mathbf{w}^{x \oplus y}) - f_{\theta}(\mathbf{z}_1, 1)\|^2 - \log p_{\theta} (\mathbf{w}^{x \oplus y} | \mathbf{z}_0)] \\ \Rightarrow [\sum_{t=2}^T \|\mathbf{y}_0 - \tilde{f}_{\theta}(\mathbf{z}, t)\|^2 + \|\text{Emb}(\mathbf{w}^{y}) - \tilde{f}_{\theta}(\mathbf{z}_1, 1)\|^2 - \log p_{\theta} (\mathbf{w}^{x \oplus y} | \mathbf{z}_0)]. \]

  • \(f, \tilde{f}\) 就是对 \(\mathbf{z}_t, \mathbf{y}_t\) 的直接拟合, 是另一种损失的写法. 具体看 here

代码

official

posted @ 2023-03-04 11:39  馒头and花卷  阅读(373)  评论(0编辑  收藏  举报