[Paper Reading] VQ-GAN: Taming Transformers for High-Resolution Image Synthesis

名称

link
[VQ-GAN](Taming Transformers for High-Resolution Image Synthesis)
时间:CVPR2021 oral 21.06
机构:Heidelberg Collaboratory for Image Processing, IWR, Heidelberg University, Germany
代码可参考Stable Diffusion中实现:https://github.com/CompVis/stable-diffusion/tree/main/configs/autoencoder

TL;DR

Transformer优势在于能较好地长距离建模sequence数据,而CNN优势是天生对局部位置关系具有归纳偏差。本文结合两者特征,利用CNN建立context-rich vocabulary的codebook,利用transformer建立高分辨率(long range relations)构成。本文提出VQ-GAN结合Transformer与CNN两者,可以利用语义信息合成高分辨率图像。

GAN基础知识(前置基础知识)

GAN全称Generative Adversarial Network,包含两个功能模块分别是generator G用来生成图像(从noise vector生成图像),discriminator D用来判断生成图像的真实性(二分类网络判断输入图像是真实还是生成,从而拉进两类数据分布距离)。

网络结构

有点像反过来的AE,即Decoder在前面先生成图像,再用Encoder来判断图像的真实性。

实际两个模块是在一个网络里面,只不过训练过程是交替训练的,即训练其中一个模块时另一个模块fix住。


为什么必须要D?训练G的过程为什么fix D?

  1. 生成器仅能生成图像,无法知道真实数据的分布,而D可以方便地将图像映射到高维特征空间,从而比较方便衡量数据集之间的分布,并给G传递梯度告诉其靠近真实数据分布的优化方向。
  2. G在生成图像过程中是没有办法知道自己要生成的图像的全局信息,而D一开始就可以获取到图像全局信息,比较容易判断图像质量与改进方向。
    如果训练G过程不fix D,那么D就会通过自身参数更新来过拟合图像真假判断的任务,不利于生成器的训练。

为什么必须要G?
D本质上还是一个判别任务,而非生成任务,而我们的目标是生成内容,所以必须要有生成器。

VQ-GAN Method

AE的训练

类似于之前VQ-VAE的训练过程,即AE + embedding space的聚类。

整体Loss

其中,\(L_{rec}\)为重建Loss,这里使用的是LPIPS(Learned Perceptual Image Patch Similarity),这种衡量图像间相似度的方法,使用预训练的深度神经网络(如 VGG、AlexNet)提取图像特征,并在特征空间计算相似度。

Transformer

  1. latent feature仍然保持spatial信息,每个cell上embedding特征在codebook里面都有唯一编号,这使得一个latent feature可以映射为一个sequence序列,再使用seq2seq的Transformer对latent feature进行特征融合。优势是可以在低分辨率的latent space进行特征融合,使得即使高分辨率图像的生成也能充分融合long-range的context。
  2. 方便conditional信息融合:transformer另一个优势是conditional信息也可编码为codebook编号,再利用Transformer auto-regressive来生成图像。

实际生成过程并没有完全使用所有前序code编码,而是使用一个sliding windows中编码生成某个coding embedding参见下图。个人理解是利用了CNN的位置归纳偏置简化了计算。

CodeReading

以SD官方代码中autoencoder_kl_32x32x4.yaml来说明VAE的训练实现。

AutoencoderKL

AE的模型结构参考这三个函数:

  • encoder是一个ResNet结构,用来将原图下采样为yaml中配置的特征图latent feature
  • quant_conv与post_quant_conv作为latent feature前后处理的普通conv理解即可
  • posterior是根据每个样本生成的统计分布或者从分布中采样的一个DiagonalGaussianDistribution类的实例
    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

AE的训练过程参考training_step,其中optimizer_idx从pytorch_lighting中获取,即下面配置了几个optimizer,则会轮流输入对应的optimizer_idx分别用loss更新对应的optimizer,所以会依次更新AE的参数(即generator),以及Discriminator的参数

    def training_step(self, batch, batch_idx, optimizer_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")

            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

    def configure_optimizers(self):
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

DiagonalGaussianDistribution

  • 无参数,每次iter动态生成,以quant_conv出的feature作为初始化,分解为mean与var
  • sample(): 根据mean与var采样出的feature
  • kl(): 根据mean与var计算出KL散度
  • nll(): Negative Log Likelihood,用来根据mean与var计算高斯分布的似然
  • mode(): "众数",高斯分面的均值、中位数、众数为同一个值,所以这里直接return mean

LPIPSWithDiscriminator

训练VAE Loss Function

  • 初始化: 包含了带可学参数的NLayerDiscriminator的定义
  • forward:
    • 重建loss:包含重生成与原图之间的L1相似度与LPIPS相似度
    • 分布loss:posteriors中的高斯分布与标准正态分布的距离
  • optimizer_idx == 0时,discriminator仅用来给VAE生成图打分
  • optimizer_idx == 1时,用discriminator分别对真实与生成图打分,学到真实/生成样本的判别能力

总结

  • 可以本节分析的SD中的VAE实际并没有使用VQ(vector quantized)的方式训练,而是直接连续采样
  • VAE是作为GAN中的Generator来进行训练的

Experiment


总结与发散

对于本文motivation上的理解:
利用CNN的位置归纳偏置建立context rich codebook(因为CNN提取的特征能够保持spatial信息,所以codebook也具有spatial信息,并且每个cell包含了一定context语义信息)
利用Transformer长时序能力生成生高分辨率图像(Transformer本身就有长时序优势,在这里面Transformer应用于低分辨率的latent feature生成,再Decode为高分辨图,减轻了计算量与生成难度)
整体理解,比较类似于VQ-VAE,不同点是:
1.重建后的图像是用adversarial loss监督,而VQ-VAE是直接使用重建Loss;
2.直接将latent feature每个embedding映射为编码,通过transformer seq2seq模型生成编码序列来生成图像,VQ-VAE直接训练了一个PixelCNN生成特征。

相关链接

link
https://zhuanlan.zhihu.com/p/515214329
https://speech.ee.ntu.edu.tw/~hylee/ml/2023-spring.php

资料查询

折叠Title FromChatGPT(提示词:XXX)
posted @ 2024-04-01 23:08  fariver  阅读(605)  评论(0)    收藏  举报