SAGAN 4/21-4/23 踩坑日记

 

正文

    这几天为了毕业设计项目,忙的不行。因为本人毕设是关于医学图像数据增强算法方面的研究,近期主要是阅读并复现了一篇论文。论文的名称是SAGAN。下面,我将论文和代码进行简要的介绍,主要说的是其中的坑点!

Self-Attention Generative Adversarial Networks Han Zhang, Ian J. Goodfellow Published 30 March 2017 (Citations 1895)

    近些年来,随着人工智能,深度学习的不断发展,一大批一大批的网络框架也随之兴起,比如说大名鼎鼎的CNNRNN系列。但是大多数的进展往往是围绕着判别模型 Discriminator 而言,对于生成模型 Generator 的网络效果却并没有那么好。

    2014年,生成对抗网络之父 Ian 提出了一个全新的网络模型框架,令人耳目一新,而这个网络模型框架就是生成对抗网络 GAN。它和一般的网络架构不同,该模型同时含有一个生成器Generator,和一个判别器Discriminator,两个模型相互拮抗、共同发展。

    虽然说GAN网络结构提出,并且取得一定的效果,但是他的弊端也很严重,主要体现在以下方面。

  1. 神经网络难以训练,难以收敛;
  2. 网络训练易出现mode collapse问题,模式崩溃;
  3. 网络的衡量指标难以定义;

首先,是网络难以收敛的问题,该问题很好理解。在平时的网络调参训练过程中,仅仅是判别器的调参都较为难调,何况在这里是存在两个网络生成器G,和判别器D呢?
第二,mode collapse问题,指的是Generator 生成的图片过于单一,多样性差
第三,生成图片的质量往往具有主观性,很难被及其所衡量。

当然本文的重点不是介绍上面这几种问题,本是在于对SAGAN的介绍

SAGAN 论文介绍与踩坑

     首先SAGAN论文是为了解决在原本图像生成任务中,尤其是对于含有大量结构信息、几何形状图像的生成。在原本的图像生成办法中,往往会存在一些问题,比如说生成的狗狗只有三条腿、甚至还可能少了一条等等。Ian认为,这是因为,像素和像素之间关联可能太小,导致不协调的问题。因此引入了大名鼎鼎的 self attention机制,来增强图像之间的关联程度。

     同时,为了提高 GAN 训练过程的稳定性,文中重点使用了两种方法,TTURSpectral Normalization。 下面,我对该内容进行介绍。

self-attention

     假设,大家对 self attention 机制已经有了一定的了解,我对 self-attention内容进行简要的回顾。举一个简单的 multi-head attention 多头注意力, 如图1所示。 注意力机制的思想,就是给定你n个 query 查询,然后你自身还有 mkey-value 键值对,然后给出 query 对应的数值。
     当然,一个很简单的做法。将 QueryKey 映射到同一维度,然后计算 某个 query 和所有的 key 之间的相似度,然后加权到 value上。这就是 attention,关键点就是在于权重。 multi-head 指的就是你有很多个attention,然后将他们 concat 起来就好了。看上去很像是我们 CNN中卷积的通道数,这里一个 head 就是一个通道。


图1

    那么,我们的 SAGAN 中的 Attention 是什么样子呢?借用论文中的一张图。我们先不考虑Batch_size 大小。
    输入一张特征图, 形状为(c, h, w),将其作为 query, key, value 进行 1X1 conv 映射,得到 f(x), g(x), h(x), 之后 f(x) 转置后,和 g(x) 进行矩阵乘,得到 softmax 操作 query 和 key 的权重,然后和 value 进行矩阵乘法,得到加权的 value_sum,也就是我们 query 之后的结果,最后再次进行 conv 1X1 操作即可。按照图来看,可以看出的进行什么计算,但是感觉他这样运算毫无道理,下面我来一点一点的讲解。

    想一想,我们作者使用 attention 的目的是什么,在于寻找像素和像素的关联,也就是说按道理,我们 query 应该是的正常形状为 (pix_num, c),pix_num表示的是像素的数量=width*height, c是通道数。 同样因为是self-attention key,value的大小形状和 query一样。
    那么我们下一步,直接计算 query和key的关联度,得到未进行softmax前的attention权重即可。那很简单,query(pix_num, c) 和 KeyT (c, pix_num)进行矩阵乘法,得到矩阵S,Si,j表示queryivaluej对的点积,将其作为query和value的相似度,也就是权重,然后进行softmax操作。 这里的话,你可能有点因为,为什么图中是 f(x)或者是说query 进行tranpose了呢?这不是说的有些矛盾了么? 用过 Pytorch 的话,你会明白,原始的 query形状为(c, pix_num),也就是说,咱们上文中的 query和key,已经是转置后的了。我认为这是原文的一大坑点,因为他并没有明确的给出,我们的attention对象应该是像素,虽然你可以get到,应该是pixel
    softmax操作,操作对象是最后一维,因为这代表某一个 query,他所有 key的权重和为1。剩下的就是简单的将权重加到 value 上,然后最后 Residual 操作了,从图中就很好理解了。


图2

    SAGAN论文,对上图给出了数学公式,不过说实话,这个公式也有些令人头疼。(因为我懒的用 LaTeX写了,直接给截图)


    首先要吐槽的一点xi为什么是列向量,第一次看公式的我直接头都大了,这里的xi和编程习惯完全不同,他指的是数学上列向量优先的列向量。
    然后,计算到 Si,j我可以理解,毕竟是我刚刚解释过的 query和key的加权,但是不幸的是 softmax 之后,β 居然转置了。。。这一部分公式写的很古怪,甚至我事后写完代码也觉着很古怪,我觉着借助代码理解起来会好很多。

TTUR

    从实践的角度上,GAN训练过程中,经过正则化的判别器Discriminator需要多次更新,当Generator更新一次的时候。本文使用的方法是,提高 Discriminator 的Learning Rate,以达到类似的效果。而且相比更新一次G,更新K次D,TTUR的计算量更小,实践起来也很简单,对优化器的参数Learning Rate进行调整即可

Spectral Normalization

     在GAN的经典论中 WGAN 中,提出了 Wasserstein Distance 从损失函数的方面减缓 GAN训练不稳定的问题。

W(Pr,Pg)=supLip 1Expr[f(x)]Expg[f(x)](1)
其中,Lipschitz约束简单而言就是:要求在整个f(x)的定义域内有||f(x)f(x)||||xx||M

但是,但是拉普拉斯平滑问题解决的方法却并不高明,对权重进行 clip 操作,使得容易发生梯度消失的现象。
对此WGAN-GP,提出了 Grapdient Penalty 的解决方法

本文选取的是另一种解决方法,来自于 SNGAN 这篇论文。
首先,我们考虑一层卷积、或者是线性层,加上激活函数ReLU,可以写为:
xi=ai(Wixi1+bi),输入为xi1,输出为xi,激活函数为ai,权重偏置分别为Wi,bi

此时,为了方便公式的推导,我们省略偏置bi,并且ai函数可以看作是一个对角阵,对角的数值是一个变量,和他乘以的数字有关,倘若矩阵乘时候,乘以的数字为正,那么他为1,否则为0。因此公式变形为
xi=Di,xi1Wixi1

那么,整个神经网络可以表示为
f(x)=DL,xL1WLDL1,xL2WL1D1,x0W1x0

让我们重新考虑一下WGAN提出的1-Lipschitz
Lipschitz约束是对f(x)的梯度提出的要求,下面我们利用上面的公式,对f(x)的梯度进行等价并放缩
||x(f(x))||2=||DLWLD1W1||2||DN||2||WN||2||D1||2||W1||2

这是直接参考的博客,说实话,只是模模糊糊的看了一些,主要是关注如何求解 σ(A)了。而且我认为这个证明有点奇怪,对这个证明有兴趣的话,可以直接参考论文。 SNGAN

首先,我们求解的是矩阵的最大特征值。这里使用的是迭代的方法,可以尝试这样理解。
对于n×n方阵A,在给定一个随机向量x(n×1),假设 A有k个特征向量,对应k个特征值
(λ1,v1),(λ2,v2),(λ3,v3),(λk,vk)
将x分解到特征向量的空间上
x=τ1v1+τ2v2++τkvk,
那么
Ax=λ1τ1v1+λ2τ2v2++λkτkvk,
A2x=λ12τ1v1+λ22τ2v2++λk2τkvk,

Amx=λ1mτ1v1+λ2mτ2v2++λkmτkvk,
可以看出,当 m无穷大的时候, Amx基本上是最大特征值对应特征向量的方向,只需要计算出 m+1次的结果,然后两者向量的长度做除法即可。

当然,如果你已经是正则化了,求解就更为简单了,
u是 m 后正则化的向量,那么v=Au
uv即为最大特征值,因为向量点积公式,ab=|a||b|cosθ,θ=0,|u|=1,直接只剩λ。因为这个涉及到我们的weight矩阵不是防止,代码求解的是 WWT的特征值,凑合着用了。

SAGAN 代码简介

代码关键点

1.SelfAttention
主要是转置部分,是一个绕点。
这个是在于我提过的转置的问题,原本应该是
Query×KeyT×Value,QueryRn×c,KeyRn×c,ValueRn×c
n表示像素的width×height, c表示channels,但是因为Pytorch C前面,所以说就导致了很绕的代码。
而且主语为什么后来的 weight 进行permute,而不是value进行permute,是为了减少后面的多次permute。
就相当于 (A×BT)T=B×AT,这个是最绝的地方,感觉要是不说的话,真不知道在写什么。。。

batch_size, channels, height, width = x.shape
queries = self.conv_q(x).reshape(batch_size, -1, height * width)
keys = self.conv_k(x).reshape(batch_size, -1, height * width)
values = self.conv_v(x).reshape(batch_size, -1, height * width)
weight = torch.softmax(torch.bmm(queries.permute(0, 2, 1), keys), dim=-1)
# 正常 self attention 写法
# y = torch.bmm(weight, values.permute(0, 2, 1)).permute(0, 2, 1)
y = torch.bmm(values, weight.permute(0, 2, 1)).reshape(batch_size, -1, height, width)
return x + self.gamma * y, weight

2.SpectralNorm
在于 self.blks 居然不被识别,而且需要 u, v的迭代进程,调用子 module 的forward函数
主要就是调用我们的迭代算法计算 σ(A),u,v进行更新。

3.hinge loss
何时进行 mean 操作,一定要在外面 mean() 操作。

# loss_d = torch.relu(1 - torch.mean(disc_real)) +\
# torch.relu(1 + torch.mean(disc_fake))
loss_d = torch.relu(1 - disc_real).mean() + \
torch.relu(1 + disc_fake).mean()

有趣的问题

1、为什么使用self-attention, conv 不能做吗?
    使用 self-attention 是为了做一些长范围依赖性的图像生成任务,主要在于图片有着较强的结构信息、几何形状信息等等,如果是信息若的背景图,使用self-attention效果可能变差。
这里给出self-attention中conv不能及的优点:
    首先要提的是,self-attention能看到全局的视野,但是对于conv而言,他需要很多层才能使特征野达到很大的范围,这些conv层可能存在这些问题。 比如说层次少了,可能表达能力不到,能不到全局信息。优化算法也不一定可以优化到这种底部。或者说是,学习到的信息泛化能力不够强,对于新图片不适用。
    但是将 conv 全部换成 self-attention 也不太现实,计算复杂度,显存也没有办法承受。 因此SAGAN 使用 self-attention 放在最后2层,并且使用 1X1 conv 做了通道维度的降低,减少计算量。

2、GAN常常被视为一种极不稳定,很容易受到超参数选择影响的网络,现在这些工作都在哪些方面?
    GAN的最新科研工作如下:
第一、在于网络结构框架的设计,如pix2pix、SRGAN,CycleGAN 使用的 Residual Block,SAGAN放入self-attention,又比如说PatchGAN
第二、损失函数的重新选定, f-GAN、WGAN、WGAN-GP, hinge损失,都是从损失函数的角度,并且可能加入L1-loss(pix2pix) 来捕获低频信息,Cycle_loss、Identy_loss(Cycle)等方面进行补充网络框架的改进,
第三、正则化方法,比如说使用 BatchNorm, InstanceNorm,LayerNorm,PixelNorm等等。
第四、启发式方法,SpectralNorm可能说的是其中的一种。

当然也有提出一些GAN的衡量指标的论文(往往是顺带着提出的),比如说 InceptionScore(IS), FID(Fréchet Inception Distance),这两个内容会在下一篇博客更新。

参考文献

谱归一化(Spectral Normalization)的理解

谱范数正则(Spectral Norm Regularization)的理解

posted @   lucky_light  阅读(422)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
点击右上角即可分享
微信分享提示