SAGAN 4/21-4/23 踩坑日记

    这几天为了毕业设计项目,忙的不行。因为本人毕设是关于医学图像数据增强算法方面的研究,近期主要是阅读并复现了一篇论文。论文的名称是SAGAN。下面,我将论文和代码进行简要的介绍,主要说的是其中的坑点!

Self-Attention Generative Adversarial Networks Han Zhang, Ian J. Goodfellow Published 30 March 2017 (Citations 1895)

    近些年来,随着人工智能,深度学习的不断发展,一大批一大批的网络框架也随之兴起,比如说大名鼎鼎的CNNRNN系列。但是大多数的进展往往是围绕着判别模型 Discriminator 而言,对于生成模型 Generator 的网络效果却并没有那么好。

    2014年,生成对抗网络之父 Ian 提出了一个全新的网络模型框架,令人耳目一新,而这个网络模型框架就是生成对抗网络 GAN。它和一般的网络架构不同,该模型同时含有一个生成器Generator,和一个判别器Discriminator,两个模型相互拮抗、共同发展。

    虽然说GAN网络结构提出,并且取得一定的效果,但是他的弊端也很严重,主要体现在以下方面。

  1. 神经网络难以训练,难以收敛;
  2. 网络训练易出现mode collapse问题,模式崩溃;
  3. 网络的衡量指标难以定义;

首先,是网络难以收敛的问题,该问题很好理解。在平时的网络调参训练过程中,仅仅是判别器的调参都较为难调,何况在这里是存在两个网络生成器G,和判别器D呢?
第二,mode collapse问题,指的是Generator 生成的图片过于单一,多样性差
第三,生成图片的质量往往具有主观性,很难被及其所衡量。

当然本文的重点不是介绍上面这几种问题,本是在于对SAGAN的介绍

SAGAN 论文介绍与踩坑

     首先SAGAN论文是为了解决在原本图像生成任务中,尤其是对于含有大量结构信息、几何形状图像的生成。在原本的图像生成办法中,往往会存在一些问题,比如说生成的狗狗只有三条腿、甚至还可能少了一条等等。Ian认为,这是因为,像素和像素之间关联可能太小,导致不协调的问题。因此引入了大名鼎鼎的 self attention机制,来增强图像之间的关联程度。

     同时,为了提高 GAN 训练过程的稳定性,文中重点使用了两种方法,TTURSpectral Normalization。 下面,我对该内容进行介绍。

self-attention

     假设,大家对 self attention 机制已经有了一定的了解,我对 self-attention内容进行简要的回顾。举一个简单的 multi-head attention 多头注意力, 如图1所示。 注意力机制的思想,就是给定你n个 query 查询,然后你自身还有 mkey-value 键值对,然后给出 query 对应的数值。
     当然,一个很简单的做法。将 QueryKey 映射到同一维度,然后计算 某个 query 和所有的 key 之间的相似度,然后加权到 value上。这就是 attention,关键点就是在于权重。 multi-head 指的就是你有很多个attention,然后将他们 concat 起来就好了。看上去很像是我们 CNN中卷积的通道数,这里一个 head 就是一个通道。


图1

    那么,我们的 SAGAN 中的 Attention 是什么样子呢?借用论文中的一张图。我们先不考虑Batch_size 大小。
    输入一张特征图, 形状为(c, h, w),将其作为 query, key, value 进行 1X1 conv 映射,得到 f(x), g(x), h(x), 之后 f(x) 转置后,和 g(x) 进行矩阵乘,得到 softmax 操作 query 和 key 的权重,然后和 value 进行矩阵乘法,得到加权的 value_sum,也就是我们 query 之后的结果,最后再次进行 conv 1X1 操作即可。按照图来看,可以看出的进行什么计算,但是感觉他这样运算毫无道理,下面我来一点一点的讲解。

    想一想,我们作者使用 attention 的目的是什么,在于寻找像素和像素的关联,也就是说按道理,我们 query 应该是的正常形状为 (pix_num, c),pix_num表示的是像素的数量=width*height, c是通道数。 同样因为是self-attention key,value的大小形状和 query一样。
    那么我们下一步,直接计算 query和key的关联度,得到未进行softmax前的attention权重即可。那很简单,query(pix_num, c) 和 \(\rm Key^T\) (c, pix_num)进行矩阵乘法,得到矩阵\(\rm S, S_{i, j}\)表示\(\rm query_i 和value_j\)对的点积,将其作为query和value的相似度,也就是权重,然后进行softmax操作。 这里的话,你可能有点因为,为什么图中是 f(x)或者是说query 进行tranpose了呢?这不是说的有些矛盾了么? 用过 Pytorch 的话,你会明白,原始的 query形状为(c, pix_num),也就是说,咱们上文中的 query和key,已经是转置后的了。我认为这是原文的一大坑点,因为他并没有明确的给出,我们的attention对象应该是像素,虽然你可以get到,应该是pixel
    softmax操作,操作对象是最后一维,因为这代表某一个 query,他所有 key的权重和为1。剩下的就是简单的将权重加到 value 上,然后最后 Residual 操作了,从图中就很好理解了。


图2

    SAGAN论文,对上图给出了数学公式,不过说实话,这个公式也有些令人头疼。(因为我懒的用 LaTeX写了,直接给截图)


    首先要吐槽的一点\(x_i\)为什么是列向量,第一次看公式的我直接头都大了,这里的\(x_i\)和编程习惯完全不同,他指的是数学上列向量优先的列向量。
    然后,计算到 \(S_{i,j}\)我可以理解,毕竟是我刚刚解释过的 query和key的加权,但是不幸的是 softmax 之后,\(\beta\) 居然转置了。。。这一部分公式写的很古怪,甚至我事后写完代码也觉着很古怪,我觉着借助代码理解起来会好很多。

TTUR

    从实践的角度上,GAN训练过程中,经过正则化的判别器Discriminator需要多次更新,当Generator更新一次的时候。本文使用的方法是,提高 Discriminator 的Learning Rate,以达到类似的效果。而且相比更新一次G,更新K次D,TTUR的计算量更小,实践起来也很简单,对优化器的参数Learning Rate进行调整即可

Spectral Normalization

     在GAN的经典论中 WGAN 中,提出了 Wasserstein Distance 从损失函数的方面减缓 GAN训练不稳定的问题。

\(W(P_r, P_g)=\sup\limits_{Lip~1} E_{x-p_r}[f(x)]-E_{x-p_g}[f(x)] \dots\dots\dots(1)\)
其中,Lipschitz约束简单而言就是:要求在整个f(x)的定义域内有\(\frac{||f(x)-f(x')||}{||x-x'||}\le {\rm M}\)

但是,但是拉普拉斯平滑问题解决的方法却并不高明,对权重进行 clip 操作,使得容易发生梯度消失的现象。
对此WGAN-GP,提出了 Grapdient Penalty 的解决方法

本文选取的是另一种解决方法,来自于 SNGAN 这篇论文。
首先,我们考虑一层卷积、或者是线性层,加上激活函数ReLU,可以写为:
\(x_{i} = a_i(W_ix_{i-1}+b_i)\),输入为\(x_{i-1}\),输出为\(x_i\),激活函数为\(a_i\),权重偏置分别为\(W_i,b_i\)

此时,为了方便公式的推导,我们省略偏置\(b_i\),并且\(a_i\)函数可以看作是一个对角阵,对角的数值是一个变量,和他乘以的数字有关,倘若矩阵乘时候,乘以的数字为正,那么他为1,否则为0。因此公式变形为
\(x_{i} = D_{i, x_{i-1}}W_ix_{i-1}\)

那么,整个神经网络可以表示为
\(f(x) = D_{L, x_{L-1}}W_L\cdot D_{L-1, x_{L-2}}W_{L-1}\cdots D_{1, x_{0}}W_1x_{0}\)

让我们重新考虑一下WGAN提出的1-Lipschitz
Lipschitz约束是对f(x)的梯度提出的要求,下面我们利用上面的公式,对f(x)的梯度进行等价并放缩
\(||\nabla_{x}(f(x))||_2=||D_LW_L\cdots D_1W_1||_2 \le||D_N||_2||W_N||_2\cdots||D_1||_2||W_1||_2\)

这是直接参考的博客,说实话,只是模模糊糊的看了一些,主要是关注如何求解 \(\sigma(A)\)了。而且我认为这个证明有点奇怪,对这个证明有兴趣的话,可以直接参考论文。 SNGAN

首先,我们求解的是矩阵的最大特征值。这里使用的是迭代的方法,可以尝试这样理解。
对于n×n方阵A,在给定一个随机向量x(n×1),假设 A有k个特征向量,对应k个特征值
\((\lambda_1, v_1), (\lambda_2, v_2), (\lambda_3, v_3), (\lambda_k, v_k)\)
将x分解到特征向量的空间上
\(x=\tau_1 v_1+\tau_2 v_2+\cdots+\tau_k v_k\),
那么
\(Ax=\lambda_1 \tau_1 v_1+\lambda_2 \tau_2 v_2+\cdots+\lambda_k \tau_k v_k\),
\(A^2x=\lambda_1^2 \tau_1 v_1+\lambda_2^2 \tau_2 v_2+\cdots+\lambda_k^2 \tau_k v_k\),

\(A^mx=\lambda_1^m \tau_1 v_1+\lambda_2^m \tau_2 v_2+\cdots+\lambda_k^m \tau_k v_k\),
可以看出,当 m无穷大的时候, \(A^mx\)基本上是最大特征值对应特征向量的方向,只需要计算出 m+1次的结果,然后两者向量的长度做除法即可。

当然,如果你已经是正则化了,求解就更为简单了,
u是 \(m\to \infty\) 后正则化的向量,那么\(v = Au\)
\(u\cdot v\)即为最大特征值,因为向量点积公式,\(\vec a\vec b=|a||b|cos\theta, \theta=0, |u|=1\),直接只剩\(\lambda\)。因为这个涉及到我们的weight矩阵不是防止,代码求解的是 \(WW_T\)的特征值,凑合着用了。

SAGAN 代码简介

代码关键点

1.SelfAttention
主要是转置部分,是一个绕点。
这个是在于我提过的转置的问题,原本应该是
\(Query \times Key_T \times Value, Query\in \R^{n\times c}, Key\in \R^{n\times c}, Value\in \R^{n\times c}\)
\(n\)表示像素的\(width \times height\), \(c\)表示\(channels\),但是因为Pytorch C前面,所以说就导致了很绕的代码。
而且主语为什么后来的 weight 进行permute,而不是value进行permute,是为了减少后面的多次permute。
就相当于 \((A \times B_T)_T = B \times A_T\),这个是最绝的地方,感觉要是不说的话,真不知道在写什么。。。

        batch_size, channels, height, width = x.shape
        queries = self.conv_q(x).reshape(batch_size, -1, height * width)
        keys = self.conv_k(x).reshape(batch_size, -1, height * width)
        values = self.conv_v(x).reshape(batch_size, -1, height * width)
        weight = torch.softmax(torch.bmm(queries.permute(0, 2, 1), keys), dim=-1)

        # 正常 self attention 写法
        # y = torch.bmm(weight, values.permute(0, 2, 1)).permute(0, 2, 1)
        y = torch.bmm(values, weight.permute(0, 2, 1)).reshape(batch_size, -1, height, width)
        return x + self.gamma * y, weight

2.SpectralNorm
在于 self.blks 居然不被识别,而且需要 u, v的迭代进程,调用子 module 的forward函数
主要就是调用我们的迭代算法计算 \(\sigma(A)\),u,v进行更新。

3.hinge loss
何时进行 mean 操作,一定要在外面 mean() 操作。

# loss_d = torch.relu(1 - torch.mean(disc_real)) +\
#          torch.relu(1 + torch.mean(disc_fake))
loss_d = torch.relu(1 - disc_real).mean() + \
         torch.relu(1 + disc_fake).mean()

有趣的问题

1、为什么使用self-attention, conv 不能做吗?
    使用 self-attention 是为了做一些长范围依赖性的图像生成任务,主要在于图片有着较强的结构信息、几何形状信息等等,如果是信息若的背景图,使用self-attention效果可能变差。
这里给出self-attention中conv不能及的优点:
    首先要提的是,self-attention能看到全局的视野,但是对于conv而言,他需要很多层才能使特征野达到很大的范围,这些conv层可能存在这些问题。 比如说层次少了,可能表达能力不到,能不到全局信息。优化算法也不一定可以优化到这种底部。或者说是,学习到的信息泛化能力不够强,对于新图片不适用。
    但是将 conv 全部换成 self-attention 也不太现实,计算复杂度,显存也没有办法承受。 因此SAGAN 使用 self-attention 放在最后2层,并且使用 1X1 conv 做了通道维度的降低,减少计算量。

2、GAN常常被视为一种极不稳定,很容易受到超参数选择影响的网络,现在这些工作都在哪些方面?
    GAN的最新科研工作如下:
第一、在于网络结构框架的设计,如pix2pix、SRGAN,CycleGAN 使用的 Residual Block,SAGAN放入self-attention,又比如说PatchGAN
第二、损失函数的重新选定, f-GAN、WGAN、WGAN-GP, hinge损失,都是从损失函数的角度,并且可能加入L1-loss(pix2pix) 来捕获低频信息,Cycle_loss、Identy_loss(Cycle)等方面进行补充网络框架的改进,
第三、正则化方法,比如说使用 BatchNorm, InstanceNorm,LayerNorm,PixelNorm等等。
第四、启发式方法,SpectralNorm可能说的是其中的一种。

当然也有提出一些GAN的衡量指标的论文(往往是顺带着提出的),比如说 InceptionScore(IS), FID(Fréchet Inception Distance),这两个内容会在下一篇博客更新。

参考文献

谱归一化(Spectral Normalization)的理解

谱范数正则(Spectral Norm Regularization)的理解

posted @ 2022-04-23 21:07  lucky_light  阅读(317)  评论(0编辑  收藏  举报