[论文阅读] Vector-quantized Image Modeling with Improved VQGAN

Pre

title: Vector-quantized Image Modeling with Improved VQGAN
accepted: ICLR 2022
paper: https://arxiv.org/abs/2110.04627
code: https://github.com/thuanz123/enhancing-transformers (unofficial)
ref: https://zhuanlan.zhihu.com/p/611689477

关键词:quantization, ViT
阅读理由:看起来比VQGAN厉害很多

Idea

将VQGAN quantizer的CNN换成了ViT,加了俩trick:Factorized codes和l2-normalized codes

Motivation&Solution

  1. (m) VQGAN的codebook使用率低下,影响阶段1的重建跟阶段2的多样性

  2. (m) Image quantizer质量不好会导致信息丢失

  3. (s) 去掉vqgan里面的top-k和top-p采样,只用temperature=1.0的采样

  4. (s) 使用Factorized codes、l2-normalized codes 以及 logit-laplace 损失

Background

跟vqgan一样有两个阶段:

  1. Image Quantization。给定输入图片 256x256 ViT-VQGAN将其编码为32x32的离散codes(8倍下采样),而codebook的大小为8192
  2. Vector-quantized Image Modeling (VIM)。训练transformer来自回归地预测32x32=1024个token,若是 class-conditioned 图片生成,跟vqgan一样把类别id的token放在图片token前面(输入模型)。加分类头是为了评估无监督学习的质量。

跟vqgan的差别:

  1. 阶段1的CNN换成ViT,因此解码器先将预测的每个token转换回8x8的图片patch,再将其拼成一整副图片
  2. 阶段1先对图片进行随机增强再转为token
  3. ViT-VQGAN编解码器的输入输出都不用激活函数(但也提到logit-laplace要解码器加一个sigmoid激活,可以有更好的重建质量又不会有明显的网格伪影)
  4. 增大了codebook size,而且不用top-k和top-p采样
  5. 使用Factorized codes和l2-normalized codes

Method(Model)

Overview

图1 image generation 和image understanding 任务上 ViT-VQGAN (left) 和 Vector-quantized Image Modeling (right) 的overview

VQGAN with vision transformers

表1 ViT-VQGAN 比 CNN-VQGAN 取得了更好的速度-质量权衡,进一步加速了阶段2的训练,吞吐量使用相同的128个CloudTPV4设备进行了基准测试

看表1感觉 ViT-VQGAN的small-small版本就很不错,注意到他这个设计似乎需要解码器更重一些
阶段1 vector-quantization 的目标函数:

\[\begin{equation} L_{VQ}=\|sg[\ z_{e}(x)]-e\|_2^{2}+\beta\|\ z_{e}(x)-sg[e]\|_2^{2} \end{equation} \]

其中sg跟vqgan一样,还是一个停止梯度的运算符,$$\beta$$在实验中设置0.25

CODEBOOK LEARNING

Factorized codes
encoder输出先投影到低维空间(768维 -> 32/8维),再在低维空间上查找最近的code,然后将找到的code投影回高维空间。实验证明这样能提高codebook利用率并且改进了重建质量

l2-normalized codes
codebook的向量从正态分布初始化,然后将encoder输出向量z_e跟codebook里的向量e都做l2归一化,这样相当于将向量映射到球体,计算z_e和e的欧氏距离变为计算它们的余弦相似度,提高了训练稳定性和重建质量

VIT-VQGAN training losses

  1. logits-laplace loss:可视为标准化的L1损失,它假设像素级噪声符合拉普拉斯分布,对码本使用率有贡献
  2. L2 loss:假设噪声符合高斯分布,对FID有帮助
  3. adversarial loss:使用了StyleGAN的判别器架构
  4. perceptual loss:基于VGG,但VGG是用有监督分类loss预训练的,监督信息会泄露到stage2,影响它那分类头的精度,因此无监督学习的都没用,而无条件和class-conditioned 图片合成有用,也对FID有帮助

通过超参数搜索(hyper-parameter sweep)来确定各损失项的权重,
最终联合损失:

\[\begin{equation} L = L_{VQ} + 0.1 L_{Adv} + 0.1 L_{Perceptual} + 0.1 L_{Logit-laplace} + 1.0 L_{2} \end{equation} \]

表2 Transformer架构 - 阶段1的ViT-VQGAN和阶段2的VIM

先将可学习的离散token ids嵌入,再加上2d位置嵌入,两个嵌入的维度都跟 model dim 一样。在整个序列上使用带因果注意力的堆叠transformer块,所有残差、激活和注意力输出都加上0.1的dropout,最后一层输出加上层归一化。
阶段2的VIM参数量远大于阶段1的ViT-VQGAN,比较大的参数量级。

vector-quantized image modeling

image synthesis
跟VQGAN几乎一样,只是加长序列,而且把CNN换成了Transformer

unsupervised learning
类似Image GPT,把其中一层的token特征序列取平均,然后用一个可学习的softmax层将其投影为class logits,不同的是这里只取其中一个block的输出而非将不同block输出拼在一起。对线性探测最有帮助的特征通常在Transformer中间处的块取到

Experiment

Training Detail

模型如表2所示,分了三种规格:

  • ViT-VQGAN-SS(最小的): small encoder/small decoder
  • ViT-VQGAN-BB: base encoder/base decoder
  • ViT-VQGAN-SL(最大的): small encoder/large decoder
    ViT-VQGAN-SL 的编解码器尺寸不对等,说是考虑到阶段2的训练只需要编码器

Dataset

数据集:CelebA-HQ,FFHQ,ImageNet

Results

image quantization:输入图片分辨率256x256,取batch size为128,分布于128个CloudTPUv4上,一共训练500,000个step。Adam优化器,warm up,cosine schedule

表3 验证集上重建的FID,*表示用Gumbel-Softmax重参数化训练,**表示用multi-scale hierarchical codebook训练

表3可看出比起VQGAN,在不使用Gumbel-Softmax和multi-scale hierarchical codebook的情况下,通过加大codebook大小ViT-VQGAN可取得更好的FID

表4 ViT-VQGAN的消融实验。codebook usage计算的是整个测试集上以256为1个batch的code平均使用百分比

后面几行都是跟第一行对比,可以看出基于StyleGAN的判别器就是比PatchGAN好不少,latent dim是指投影的低维空间的维度,16或8比较合适,同时第二个创新点的L2归一化对结果影响也很大

image synthesis:输入图片分辨率256x256,取batch size为1024,一共训练450,000个step。Adam优化器(参数有变),warm up,cosine schedule。然后为了节约显存又用了Adafactor,将第一矩量化为int8并分解第二矩

表5 与无条件图片合成方法对比FID

表6 分辨率256x256在ImageNet上做class-conditional图片合成的FID。acceptance rate是基于ResNet-101分类器的拒绝采样

从表5表6来看,效果都远好于VQGAN,表6用的还是ViT-VQGAN-SS

unsupervised learning:超参跟无条件图片合成一样,用ViT-VQGAN-SS,在Transformer某一个块上用平均池化取得特征,实验表明,中间层(15/36 for large, 10/24 for base)的特征有更好的分类精度

表7 ImageNet上不同无监督学习方法的线性探测精度,DALLE dVAE的图片quantizer用了额外的数据训练而成,而VIM-Large未使用dropout

表7中作者将模型分为两组:判别式预训练模型和生成式预训练模型,本文的方法VIM with ViT-VQGAN比其他生成式模型都好,而且参数量还更小,同时也能取得跟BYOL、DINO等判别式模型相近的性能

放点附录里的样本:

Conclusion

论文无总结,却有ETHICS部分

Critique

根据论文看,效果比vqgan好了不少,但参数量也是多了好多,感觉没法替换vqgan的使用,更别说代码还没开源,不知道非官方代码能否复现论文的效果

Unknown

posted @ 2024-12-05 13:02  NoNoe  阅读(41)  评论(0编辑  收藏  举报