Structured Denoising Diffusion Models in Discrete State-Spaces

Austin J., Johnson D. D., Ho J., Tarlow D. and van den Berg R. Structured denoising diffusion models in discrete state-spaces. In Advances in Neural Information Processing Systems (NIPS), 2021.

DPM 在离散空间上的探讨.

符号说明

  • forward process: \(q(\bm{x}_{1:T}|\bm{x}_0) = \prod_{t=1}^T q(\bm{x}_t|\bm{x}_{t-1}), \bm{x}_0 \sim q(\bm{x}_0)\);
  • reverse process: \(p_{\theta}(\bm{x}_{0:T}) = p(\bm{x}_T) \prod_{t=1}^T p_{\theta} (\bm{x}_{t-1}|\bm{x}_t)\).
  • 这里向量 \(\bm{x}\) 都表示行向量.

Motivation

  • DDPM 已经在图像生成领域取得了巨大成功, 它分为上述的前向和后向过程, 并通过如下损失进行优化:

    \[\tag{1} \begin{array}{l} L_{\text{vb}} = \mathbb{E}_{q(\bm{x}_0)} \Bigg [ \underbrace{D_{KL}[q(\bm{x}_T|\bm{x}_0 \| p(\bm{x}_T))]}_{L_T} + \sum_{t=2}^T \underbrace{\mathbb{E}_{q(\bm{x}_t|\bm{x}_0)} [D_{KL} [q(\bm{x}_{t-1}|\bm{x}_t, \bm{x}_0)\| p_{\theta}(\bm{x}_{t-1}|\bm{x}_t)]]}_{L_{t-1}} \\ \quad \quad \quad \quad \quad \quad \underbrace{-\mathbb{E}_{q(\bm{x}_1|\bm{x}_0)}[\log p_{\theta}(\bm{x}_0| \bm{x}_1)]}. \Bigg] \end{array} \]

  • 之前的扩散模型主要研究连续空间的情形 (故而通常采用比较好采样的高斯分布), 那么如何把 DPM 推广到离散空间中呢.

  • 首先, 合适的前向扩散过程需要满足:

    1. \(q(\bm{x}_t|\bm{x}_0)\) 对任意的时间戳 \(t\) 的采样都是容易的;
    2. \(q(\bm{x}_{t-1}|\bm{x}_t, \bm{x}_0)\) 最好是有容易求解的显示表达, 这样有利于求解 \(L_{t-1}\).

基于转移概率矩阵的 D3PM

  • 假设离散空间中有 \(K\) 个元素, 记为 \(1,2, \ldots, K\), 则对于任意的时间戳 \(t\) 的状态 \(x_t\) 都可以取遍这些元素. 我们知道, DPM 通常假设前向以及反向过程都是满足马氏性的, 而马氏性通常可以用转移概率矩阵 \(Q\) 来描述.

  • \(Q_t \in \mathbb{R}^{K \times K}\), \([Q_t]_{ij} = q(x_t = j | x_{t-1} = i)\), 则

    \[q(x_t = j|x_{t-1} = i) = [\underbrace{0, 0, \cdots, 0}_{i-1}, 1, 0, \cdots, 0] Q_t [\underbrace{0, 0, \cdots, 0}_{j-1}, 1, 0, \cdots, 0]^T. \]

  • 不妨假设 \(\bm{x}_t \in \mathbb{R}^{1 \times K}\) 为一 one-hot 向量 (注意和之前的定义区分), 即 \([\underbrace{0, 0, \cdots, 0}_{j-1}, 1, 0, \cdots, 0]\), 则,

    \[q(\bm{x}_t| \bm{x}_{t-1}) = \bm{x}_{t-1} Q_t \bm{x}_t^T \sim \text{Cat}(\bm{x}_t; \bm{p} = \bm{x}_{t-1}Q_t). \]

  • 进而注意到

    \[\begin{array}{ll} q(\bm{x}_t|\bm{x}_0) &= \sum_{\bm{x}_{1:t-1}} \prod_{k=1}^{t} q(\bm{x}_k|\bm{x}_{k=1}) \\ &= \sum_{\bm{x}_{1:t-1}} \prod_{k=1}^{t} \bm{x}_{k-1}Q_k \bm{x}_{k}^T \\ &= \sum_{\bm{x}_{1:t-1}} \bm{x}_0 Q_1 \bm{x}_1^T \cdots \bm{x}_{k-1}Q_k \bm{x}_{k}^T \cdots \bm{x}_{t-1} Q_t \bm{x}_t^T\\ &= \bm{x}_0 Q_1 (\sum_{\bm{x}_{1}} \bm{x}_1^T \bm{x}_1) \cdots (\sum_{\bm{x}_{k-1}} \bm{x}_{k-1}^T \bm{x}_{k-1}) Q_k (\sum_{\bm{x}_k}\bm{x}_{k}^T \bm{x}_k) \cdots (\sum_{\bm{x}_{t-1}}\bm{x}_{t-1}^T \bm{x}_{t-1}) Q_t \bm{x}_t^T\\ &= \bm{x}_0 Q_1 I Q_2 \cdots I Q_k I \cdots I Q_t \bm{x}_t^T\\ &= \bm{x}_0 \underbrace{Q_1 Q_2 \cdots Q_k \cdots Q_t}_{\bar{Q}_t} \bm{x}_t^T\\ &= \bm{x}_0 \bar{Q}_t \bm{x}_t^T \sim \text{Cat}(\bm{x}_t; \bm{p} = \bm{x}_{0}\bar{Q}_t). \end{array} \]

  • 以及

    \[\begin{array}{ll} q(\bm{x}_{t-1}|\bm{x}_t, \bm{x}_0) &= \frac{q(\bm{x}_t|\bm{x}_{t-1}, \bm{x}_0) q(\bm{x}_{t-1}|\bm{x}_0)}{q(\bm{x}_t|\bm{x}_0)} \\ &= \frac{q(\bm{x}_t|\bm{x}_{t-1}) q(\bm{x}_{t-1}|\bm{x}_0)}{q(\bm{x}_t|\bm{x}_0)} \\ &= \frac{\bm{x}_{t-1}Q_t \bm{x}_t^T \cdot \bm{x}_0 \bar{Q}_{t-1} \bm{x}_{t-1}^T}{\bm{x}_0 \bar{Q}_t \bm{x}_{t}^T} \\ &= \frac{(\bm{x}_{t}Q_t^T \bm{x}_{t-1}^T) \cdot (\bm{x}_0 \bar{Q}_{t-1} \bm{x}_{t-1}^T)}{\bm{x}_0 \bar{Q}_t \bm{x}_{t}^T} \\ &= \frac{(\bm{x}_{t}Q_t^T) \odot (\bm{x}_0 \bar{Q}_{t-1}) (\bm{x}_{t-1}^T)}{\bm{x}_0 \bar{Q}_t \bm{x}_{t}^T} \: \leftarrow (?) \\ &\sim \text{Cat}(\bm{x}_{t-1}; \bm{p} = \frac{(\bm{x}_{t}Q_t^T) \odot (\bm{x}_0 \bar{Q}_{t-1})}{\bm{x}_0 \bar{Q}_t \bm{x}_{t}^T} ). \end{array} \]

    需要注意的是, (?) 处成立完全是因为 \(\bm{x}_t\) 是 one-hot 的向量, 否则一般情况下是不成立的.

  • 现在, 我们以及搭建好了基于转移概率矩阵的一个整体框架, 下面我们就需要讨论如何设计 \(Q_t\) 使得 \(\bar{Q}_t\) 是易求得. 此外, \(Q_t\) 的行和应当为 \(1\), 且 \(\bar{Q}_t, t \rightarrow +\infty\) 应当是收敛的.

转移概率矩阵的设计

  • \(Q\) 是 double stochastic 的时候, 即 行和, 列和 均为1, 此时倘若 \(Q_t = Q\), 则该分布必收敛到均匀分布:

    \[[\bm{\pi}Q]_j = \sum_{i=1}^K \frac{1}{K} [Q]_{ij} = \frac{1}{K} = \pi_j. \]

接下来, 我们具体介绍几种设计.

Uniform diffusion

  • 构造:

    \[[Q_t]_{ij} = \left \{ \begin{array}{ll} 1 - \frac{K - 1}{K} \beta_t & \text{ if } j = i, \\ \frac{1}{K} \beta_t & \text{ if } j \not = i. \end{array} \right . \]

  • 用人话说就是, 知道了上一个状态 \(x_{t-1}\), \(x_t\) 仍为 \(x_{t-1}\) 的概率为 \(1 - \frac{K-1}{K} \beta_t\), 否则它等概率地成为其它状态.

  • 矩阵表示为

    \[Q_t = (1 - \beta_t) I + \beta_t \bm{1}^T \bm{1} / K. \]

  • 容易发现:

    \[\bar{Q}_t = Q_{t}Q_{t-1} \cdots Q_1 = \prod_{k=1}^t (1 - \beta_k) I + (1 - \prod_{k=1}^t (1 - \beta_k)) \bm{1}^T \bm{1} / K. \]

Diffusion with an absorbing state

  • 构造

    \[[Q_t]_{ij} = \left \{ \begin{array}{ll} 1 & \text{ if } j=i=m \\ 1 - \beta_t & \text{ if } j = i \not= m, \\ \beta_t & \text{ if } j = m, i \not= m. \end{array} \right . \]

    这里 \(m\) 是一个特殊的吸收态, 比如语言模型里的 [mask].

  • 用人话说就是,

    1. 如果 \(x_{t-1}\) 已经是 \(m\) 了, 那么 \(x_t\) 必为 \(m\);
    2. 如果 \(x_{t-1}\) 不为 \(m\), 则 \(x_t\)\(\beta_t\) 的概率成为 \(m\), 否则仍为 \(x_{t-1}\).
  • 矩阵表示为

    \[Q_t = (1 - \beta_t) I + \beta_t \bm{1}^T \bm{e}_m, \]

    这里 \(\bm{e}_m\) 表示第 \(m\) 个元素为 1, 其余均为 0 的行向量.

  • 显然, 这不是一个 double stochastic 的转移矩阵.

  • 容易发现:

    \[\bar{Q}_t = Q_{t}Q_{t-1} \cdots Q_1 = \prod_{k=1}^t (1 - \beta_k) I + (1 - \prod_{k=1}^t (1 - \beta_k)) \bm{1}^T \bm{e}_m. \]

注: 作者还介绍了如何把高斯离散化, 以及利用 embedding 的相似度来刻画转移矩阵的方法, 这里不作介绍了.

注: 作者对于 noise schedule 的一些见解非常有趣, 但是说实话, 我只能看懂一半, 写在这里就不懂装懂了.

代码

[official]

posted @ 2022-12-14 19:54  馒头and花卷  阅读(1302)  评论(0编辑  收藏  举报