SAGAN 4/21-4/23 踩坑日记
正文
这几天为了毕业设计项目,忙的不行。因为本人毕设是关于医学图像数据增强算法方面的研究,近期主要是阅读并复现了一篇论文。论文的名称是SAGAN。下面,我将论文和代码进行简要的介绍,主要说的是其中的坑点!
Self-Attention Generative Adversarial Networks Han Zhang, Ian J. Goodfellow Published 30 March 2017 (Citations 1895)
近些年来,随着人工智能,深度学习的不断发展,一大批一大批的网络框架也随之兴起,比如说大名鼎鼎的CNN
、RNN
系列。但是大多数的进展往往是围绕着判别模型 Discriminator
而言,对于生成模型 Generator
的网络效果却并没有那么好。
2014年,生成对抗网络之父 Ian 提出了一个全新的网络模型框架,令人耳目一新,而这个网络模型框架就是生成对抗网络 GAN
。它和一般的网络架构不同,该模型同时含有一个生成器Generator
,和一个判别器Discriminator
,两个模型相互拮抗、共同发展。
虽然说GAN
网络结构提出,并且取得一定的效果,但是他的弊端也很严重,主要体现在以下方面。
- 神经网络难以训练,难以收敛;
- 网络训练易出现
mode collapse
问题,模式崩溃; - 网络的衡量指标难以定义;
首先,是网络难以收敛的问题,该问题很好理解。在平时的网络调参训练过程中,仅仅是判别器的调参都较为难调,何况在这里是存在两个网络生成器G
,和判别器D
呢?
第二,mode collapse
问题,指的是Generator 生成的图片过于单一,多样性差
第三,生成图片的质量往往具有主观性,很难被及其所衡量。
当然本文的重点不是介绍上面这几种问题,本是在于对SAGAN
的介绍
SAGAN 论文介绍与踩坑
首先SAGAN
论文是为了解决在原本图像生成任务中,尤其是对于含有大量结构信息、几何形状图像的生成。在原本的图像生成办法中,往往会存在一些问题,比如说生成的狗狗只有三条腿、甚至还可能少了一条等等。Ian认为,这是因为,像素和像素之间关联可能太小,导致不协调的问题。因此引入了大名鼎鼎的 self attention
机制,来增强图像之间的关联程度。
同时,为了提高 GAN
训练过程的稳定性,文中重点使用了两种方法,TTUR
和 Spectral Normalization
。 下面,我对该内容进行介绍。
self-attention
假设,大家对 self attention
机制已经有了一定的了解,我对 self-attention
内容进行简要的回顾。举一个简单的 multi-head attention
多头注意力, 如图1所示。 注意力机制的思想,就是给定你n
个 query 查询,然后你自身还有 m
个 key-value
键值对,然后给出 query
对应的数值。
当然,一个很简单的做法。将 Query
和 Key
映射到同一维度,然后计算 某个 query
和所有的 key
之间的相似度,然后加权到 value
上。这就是 attention
,关键点就是在于权重。 multi-head
指的就是你有很多个attention
,然后将他们 concat
起来就好了。看上去很像是我们 CNN
中卷积的通道数,这里一个 head
就是一个通道。

图1
那么,我们的 SAGAN
中的 Attention
是什么样子呢?借用论文中的一张图。我们先不考虑Batch_size 大小。
输入一张特征图, 形状为(c, h, w),将其作为 query, key, value 进行 1X1 conv 映射,得到 f(x), g(x), h(x), 之后 f(x) 转置后,和 g(x) 进行矩阵乘,得到 softmax 操作 query 和 key 的权重,然后和 value 进行矩阵乘法,得到加权的 value_sum,也就是我们 query 之后的结果,最后再次进行 conv 1X1 操作即可。按照图来看,可以看出的进行什么计算,但是感觉他这样运算毫无道理,下面我来一点一点的讲解。
想一想,我们作者使用 attention
的目的是什么,在于寻找像素和像素的关联,也就是说按道理,我们 query 应该是的正常形状为 (pix_num, c),pix_num表示的是像素的数量=width*height, c是通道数。 同样因为是self-attention
key,value的大小形状和 query一样。
那么我们下一步,直接计算 query和key的关联度,得到未进行softmax前的attention权重即可。那很简单,query(pix_num, c) 和 (c, pix_num)进行矩阵乘法,得到矩阵表示对的点积,将其作为query和value的相似度,也就是权重,然后进行softmax操作。 这里的话,你可能有点因为,为什么图中是 f(x)或者是说query 进行tranpose了呢?这不是说的有些矛盾了么? 用过 Pytorch 的话,你会明白,原始的 query形状为(c, pix_num),也就是说,咱们上文中的 query和key,已经是转置后的了。我认为这是原文的一大坑点,因为他并没有明确的给出,我们的attention对象应该是像素,虽然你可以get到,应该是pixel
softmax操作,操作对象是最后一维,因为这代表某一个 query,他所有 key的权重和为1。剩下的就是简单的将权重加到 value 上,然后最后 Residual 操作了,从图中就很好理解了。

图2
SAGAN论文,对上图给出了数学公式,不过说实话,这个公式也有些令人头疼。(因为我懒的用 LaTeX写了,直接给截图)


首先要吐槽的一点为什么是列向量,第一次看公式的我直接头都大了,这里的和编程习惯完全不同,他指的是数学上列向量优先的列向量。
然后,计算到 我可以理解,毕竟是我刚刚解释过的 query和key的加权,但是不幸的是 softmax 之后, 居然转置了。。。这一部分公式写的很古怪,甚至我事后写完代码也觉着很古怪,我觉着借助代码理解起来会好很多。
TTUR
从实践的角度上,GAN训练过程中,经过正则化的判别器Discriminator需要多次更新,当Generator更新一次的时候。本文使用的方法是,提高 Discriminator 的Learning Rate,以达到类似的效果。而且相比更新一次G,更新K次D,TTUR的计算量更小,实践起来也很简单,对优化器的参数Learning Rate进行调整即可
Spectral Normalization
在GAN的经典论中 WGAN 中,提出了 Wasserstein Distance 从损失函数的方面减缓 GAN训练不稳定的问题。
其中,Lipschitz约束简单而言就是:要求在整个f(x)的定义域内有
但是,但是拉普拉斯平滑问题解决的方法却并不高明,对权重进行 clip 操作,使得容易发生梯度消失的现象。
对此WGAN-GP,提出了 Grapdient Penalty 的解决方法
本文选取的是另一种解决方法,来自于 SNGAN 这篇论文。
首先,我们考虑一层卷积、或者是线性层,加上激活函数ReLU,可以写为:
,输入为,输出为,激活函数为,权重偏置分别为
此时,为了方便公式的推导,我们省略偏置,并且函数可以看作是一个对角阵,对角的数值是一个变量,和他乘以的数字有关,倘若矩阵乘时候,乘以的数字为正,那么他为1,否则为0。因此公式变形为
那么,整个神经网络可以表示为
让我们重新考虑一下WGAN提出的1-Lipschitz
Lipschitz约束是对f(x)的梯度提出的要求,下面我们利用上面的公式,对f(x)的梯度进行等价并放缩
这是直接参考的博客,说实话,只是模模糊糊的看了一些,主要是关注如何求解 了。而且我认为这个证明有点奇怪,对这个证明有兴趣的话,可以直接参考论文。 SNGAN
首先,我们求解的是矩阵的最大特征值。这里使用的是迭代的方法,可以尝试这样理解。
对于n×n方阵A,在给定一个随机向量x(n×1),假设 A有k个特征向量,对应k个特征值
将x分解到特征向量的空间上
,
那么
,
,
,
可以看出,当 m无穷大的时候, 基本上是最大特征值对应特征向量的方向,只需要计算出 m+1次的结果,然后两者向量的长度做除法即可。
当然,如果你已经是正则化了,求解就更为简单了,
u是 后正则化的向量,那么
即为最大特征值,因为向量点积公式,,直接只剩。因为这个涉及到我们的weight矩阵不是防止,代码求解的是 的特征值,凑合着用了。
SAGAN 代码简介
代码关键点
1.SelfAttention
主要是转置部分,是一个绕点。
这个是在于我提过的转置的问题,原本应该是
表示像素的, 表示,但是因为Pytorch C前面,所以说就导致了很绕的代码。
而且主语为什么后来的 weight 进行permute,而不是value进行permute,是为了减少后面的多次permute。
就相当于 ,这个是最绝的地方,感觉要是不说的话,真不知道在写什么。。。
batch_size, channels, height, width = x.shape queries = self.conv_q(x).reshape(batch_size, -1, height * width) keys = self.conv_k(x).reshape(batch_size, -1, height * width) values = self.conv_v(x).reshape(batch_size, -1, height * width) weight = torch.softmax(torch.bmm(queries.permute(0, 2, 1), keys), dim=-1) # 正常 self attention 写法 # y = torch.bmm(weight, values.permute(0, 2, 1)).permute(0, 2, 1) y = torch.bmm(values, weight.permute(0, 2, 1)).reshape(batch_size, -1, height, width) return x + self.gamma * y, weight
2.SpectralNorm
在于 self.blks 居然不被识别,而且需要 u, v的迭代进程,调用子 module 的forward函数
主要就是调用我们的迭代算法计算 ,u,v进行更新。
3.hinge loss
何时进行 mean 操作,一定要在外面 mean() 操作。
# loss_d = torch.relu(1 - torch.mean(disc_real)) +\ # torch.relu(1 + torch.mean(disc_fake)) loss_d = torch.relu(1 - disc_real).mean() + \ torch.relu(1 + disc_fake).mean()
有趣的问题
1、为什么使用self-attention, conv 不能做吗?
使用 self-attention 是为了做一些长范围依赖性的图像生成任务,主要在于图片有着较强的结构信息、几何形状信息等等,如果是信息若的背景图,使用self-attention效果可能变差。
这里给出self-attention中conv不能及的优点:
首先要提的是,self-attention能看到全局的视野,但是对于conv而言,他需要很多层才能使特征野达到很大的范围,这些conv层可能存在这些问题。 比如说层次少了,可能表达能力不到,能不到全局信息。优化算法也不一定可以优化到这种底部。或者说是,学习到的信息泛化能力不够强,对于新图片不适用。
但是将 conv 全部换成 self-attention 也不太现实,计算复杂度,显存也没有办法承受。 因此SAGAN 使用 self-attention 放在最后2层,并且使用 1X1 conv 做了通道维度的降低,减少计算量。
2、GAN常常被视为一种极不稳定,很容易受到超参数选择影响的网络,现在这些工作都在哪些方面?
GAN的最新科研工作如下:
第一、在于网络结构框架的设计,如pix2pix、SRGAN,CycleGAN 使用的 Residual Block,SAGAN放入self-attention,又比如说PatchGAN
第二、损失函数的重新选定, f-GAN、WGAN、WGAN-GP, hinge损失,都是从损失函数的角度,并且可能加入L1-loss(pix2pix) 来捕获低频信息,Cycle_loss、Identy_loss(Cycle)等方面进行补充网络框架的改进,
第三、正则化方法,比如说使用 BatchNorm, InstanceNorm,LayerNorm,PixelNorm等等。
第四、启发式方法,SpectralNorm可能说的是其中的一种。
当然也有提出一些GAN的衡量指标的论文(往往是顺带着提出的),比如说 InceptionScore(IS), FID(Fréchet Inception Distance),这两个内容会在下一篇博客更新。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!