理论-GAN2-训练GAN时所遇到的问题及可能的解决方法

问题1,模式坍塌(Mode collapse )

  • 对模式崩溃产生原因的猜想:

    • GAN的学习目标是映射关系G:x➡y,这种单一域之间的对应关系是高度约束不足的,无法为分类器和判别其的训练提供足够的信息输入。
    • 在这种情况下所优化得到的G可以将域X转换为与Y分布相同的域Y',但是并不能确保单独的输入和输出样本x和y是以一种有意义的方式配对的——无限多种映射G(由训练过程的随机性产生)针对单独的输入x可能产生无限多种y(对G的约束依旧不足,只是保证了分布域上的一致,而且每一个分布域都是由多子域(比如多个类别)所组成的,如果只是将x转换为Y中的其中一个子域(此时生成的y全部都是同一个类别),这样依旧可以使得损失函数降到最低,从而使得G和D达到局部最优并使训练停止,这也是模式崩溃产生的原因)。
    • 模式崩溃问题导致很难孤立地优化对抗性目标,经常发生所有输入都映射到相同输出并且优化无法取得进展的现象。(或许将每一个类别都单独挑出来分别生成独有的G是一种根治的方案,不过这样的G也不要想着有什么泛化能力了)。
  • spectral collapse(谱崩溃)& spectral regularization(频谱正则化)

    • 文献:Spectral regularization for combating mode collapse in GANs

    • 代码:https://github.com/max-liu-112/SRGANs-Spectral-Regularization-GANs-

    • 概念及定义:

      • 谱崩溃:当模式崩溃发生时,权重矩阵的奇异值急剧下降的现象称为谱崩溃,作者发现模式崩溃和频谱崩溃并存的现象普遍存在,而本文通过解决谱崩溃来解决模式崩溃问题。

      • 谱归一化权重矩阵(spectral nornalized weight matrix)\(\bar{W}_{SN}(W)\);当模型没有模式崩溃发生时,\(\bar{W}_{SN}(W)\)中大部分值接近1;而当模式崩溃发生时,\(\bar{W}_{SN}(W)\)中的值会急剧下降。(作者在文中做了一些实验,说明了这一现象,但没有从理论层面证明为什么会发生这一现象)

        \[\bar{W}_{SN}(W):=\frac{W}{\sigma(W)} \]

        其中\(\sigma(W)\)是D中权重矩阵的[谱范数](https://mathworld.wolfram.com/SpectralNorm.html#:~:text=Spectral Norm. The natural norm induced by the,root of the maximum eigenvalue of %2C i.e.%2C)(WolframMathWorld-一个神奇的网站),相当于权重矩阵的最大奇异值。

      • 权重矩阵\(W\)的奇异值分解(singular value decomposition)(原文中的公式)

        \[W=U\cdot\sum\cdot{V^T} \]

        \[U^TU=I \]

        \[V^TV=I \]

        \[\sum=[\begin{matrix}D&0\\0&0\end{matrix}] \]

        其中\(D=diag{\{\sigma_1,\sigma_2,\cdots,\sigma_r\}}\)

    • 解决方案:频谱正则化通过补偿频谱分布避免频谱崩溃,从而对D的权重矩阵施加约束(核心思想:防止D的权重矩阵W集中到一个方向上)。有两种频谱正则化方案,

      • 频谱正则化(spectral regularization)

        • 静态补偿(static compensation):需要手动确定超参数\(i\),不易于应用。

          \[\Delta{D}=\left[\begin{matrix}\sigma_1-\sigma_1 & 0 & \cdots & \cdots & \cdots & 0 \\ 0 & \sigma_1-\sigma_2 & \cdots & \cdots & \cdots & 0 \\ \vdots & \cdots & \ddots & \cdots & \cdots & 0\\ \vdots & \cdots & \cdots & \sigma_1-\sigma_i & \cdots & 0\\ \vdots & \cdots & \cdots & \cdots & \ddots & 0\\ 0 & \cdots & \cdots & \cdots & \cdots & 0 \end{matrix}\right] \]

        • 动态补偿(dynamic compensation):没有需要手动确定的超参数,相比于静态补偿使用起来更方便。

          \[\Delta{D^T}=\left[\begin{matrix}0 & 0 & \cdots & 0 \\ 0 & \gamma_2^T\cdot{\sigma_1^T-\sigma_2^T} & \cdots & 0 \\ \vdots & \cdots & \ddots & 0\\ 0 & 0 & \cdots & \gamma_r^T\cdot{\sigma_1^T-\sigma_r^T} \end{matrix}\right] \]

          \(\Delta{D^T}\)是第\(T\)次迭代的补偿矩阵,\(\gamma_j^T\)是第\(j\)个补偿系数:

          \[\gamma_j^T=max(\frac{\sigma_j^1}{\sigma_1^1},\cdots,\frac{\sigma_j^t}{\sigma_1^t},\cdots,\frac{\sigma_j^T}{\sigma_1^T}),t=0,1,\cdots,T \]

          \(\sigma_j^t\)是第\(t\)次迭代的第\(j\)个奇异值。

      • 频谱正则化的实现

        \[\Delta{W}=U\cdot{[\begin{matrix}\Delta{D}&0\\0&0\end{matrix}]}\cdot{V^T} \]

        \[\bar{W}_{SR}(W)=\frac{W+\Delta{W}}{\sigma(W)}=\bar{W}_{SN}(W)+\frac{\Delta{W}}{\sigma(W)} \]

  • implicit variational learning(隐式变分学习)

    • 文献:VEEGAN: Reducing Mode Collapse in GANs using Implicit Variational Learning

    • 代码:https://github.com/akashgit/VEEGAN/blob/master/VEEGAN_2D_RING.ipynb

    • 概念及定义:

      • 隐式变分原理(implicit variational principle):

        VEEGAN引入了一个额外的重构网络(reconstructor network),将真实数据映射到高斯随机噪声,通过联合训练训练生成器和重建器网络鼓励重建器网络不仅将数据分布映射到高斯分布,而且还近似地反转生成器的动作。

      • 如何理解使用隐式变分原理可以防止模式崩溃?

        • 观察上图:中部\(p(x)\)是由两个高斯分布叠加而成的真实分布;底部\(p_0(z)\sim{N(0,1)}\)是生成器\(G_\gamma\)的输入;顶部是将重构网\(F_\theta\)用于生成数据和真实数据的结果;由底部到中部的箭头是生成器\(G_\gamma\)的动作;由中部到顶部的绿色箭头是重构生成数据的动作,紫色箭头是重构真实数据的动作。在图中,生成器都只是捕获了\(p(x)\)中其中一个高斯分布,图a与图b的区别在于重构网络不同。
        • 图a中\(F_\theta\)\(G_\gamma\)的逆,由于生成数据只包含真实数据的部分分布,\(F_\theta\)对真实数据中分布被丢失的那部分数据的处理结果不定,这也意味着其重构结果大概率与\(p_0(z)\)不匹配,这种不匹配可以作为模式崩溃的指标。
        • 图b中\(F_\theta\)成功将真实数据重构回\(p_0(z)\),此时如果\(G_\gamma\)发生模式崩溃,\(F_\theta\)并不会将生成数据重构回\(p_0(z)\)(毕竟真实数据分布与生成数据分布存在差异),由此产生的惩罚信息提供了强大的\(G_\gamma\),\(F_\theta\)学习信息。
      • 文中提到了一个模式崩溃发生原因的猜想:目标函数提供的关于生成器网络参数\(\gamma\)的唯一信息是由鉴别器网络\(D_\omega\)介导的。(An intuition behind why mode collapse occurs is that the only information that the objective function provides about γ is mediated by the discriminator network Dω)

      • 重构网络本质上是依据重构数据的差异反应生成数据和真实数据的差异,那为什么不直接度量生成数据和真实数据的分布差异呢?为什么必须要借助重构网络呢?

    • 解决方案:

      • 重构损失

        \[\min_{\gamma,\theta}O_{entropy}(\gamma,\theta)=E[||z-F_{\theta}(G_\gamma(z))||_2^2]+H(Z,F_\theta(X))~~~~~~\tag{1} \]

        前半部分保证\(F_\theta\)\(G_\gamma\)的逆;后半部分保证对于真实数据,\(F_\theta\)的重构结果依旧是与\(p_0(z)\)相同的分布,使用交叉熵进行计算。

      • 为了便于计算,将重构损失进行如下转换:

        重构网络\(F_\theta(x)\)对应于分布\(p_{\theta}(z|x)\),样本集合\(X\sim{p(x)}\)的平均重构数据为

        \[p_{\theta}(z)=\int{p_{\theta}(z|x)p(x)dx}~~~~~~\tag{2} \]

        根据交叉熵公式以及2式,\(H(Z,F_{\theta}(X))\)可写作

        \[H(Z,F_{\theta}(X))=-\int{p_{0}(z)logp_{\theta}(z)dz}=-\int{p_0(z)}log\int{p(x)p_{\theta}(z|x)}dxdz~~~~~~\tag{3} \]

        \(p_\theta(z)=p_0(z)\)时交叉熵最小,为了使上式可计算(毕竟\(p(x)\)未知),引入变分分布\(q_\gamma(x|z)\)和Jensen不等式,有(推导看原文):

        \[-logp_\theta(z)=-log\int{p_\theta(z|x)p(x)\frac{q_\gamma(x|z)}{q_\gamma(x|z)}}dx\leq{\int{q_\gamma(x|z)log\frac{p_\theta(z|x)}{q_\gamma(x|z)}}}dx~~~~~~\tag{4} \]

        \[-\int{p_0(z)}logp_\theta(z)\leq{KL[q_\gamma(x|z)p_0(z)||p_\theta(z|x)p(x)]}-E[logp_0(z)]~~~~~~\tag{5} \]

        这里的\(q_\gamma(x|z)\)对应生成器,\(p_{\theta}(z|x)\)对应重构器,由1和5式可将优化目标转化为(此为优化目标的上界):

        \[O_{entropy}(\gamma,\theta)=E[||z-F_{\theta}(G_\gamma(z))||_2^2]+KL[q_\gamma(x|z)p_0(z)||p_\theta(z|x)p(x)]-E[logp_0(z)]~~~~~~\tag{6} \]

        6式还是无法计算,因为\(q_\gamma(x|z)\)对应生成器,\(p_{\theta}(z|x)\)对应重构器,都是隐式表示,分布未知;样本数据分布\(p(x)\)也是未知,这里假设训练所得判别器\(D_\omega(x,z)\)满足

        \[D_\omega(Z,X)=log\frac{q_\gamma(x|z)p_0(z)}{p_\theta(z|x)p(x)}\tag{7} \]

        并有

        \[\hat{O}(\omega,\gamma,\theta)=\frac{1}{N}\sum_{i=1}^{N}D_{\omega}(z^i,x^i_g)+\frac{1}{N}\sum_{i=1}^{N}d(z^i,x_g^i) \tag{8} \]

        其中\((z^i,x_g^i)\sim{p_0(x)q_\gamma(x|z)}\),优化目标最终化为:

        \[O_{LR}(\omega,\gamma,\theta)=-E_\gamma[log(\sigma(D_{\omega}(z,x)))]-E_{\theta}[log(1-\sigma({D_{\omega}(z,x)}))] \tag{9} \]

        训练伪代码:生成器\(\gamma\),重构器\(\theta\),判别器\(\omega\)

posted @ 2021-07-05 15:30  tensor_zhang  阅读(1388)  评论(0编辑  收藏  举报