[论文阅读] Vector-quantized Image Modeling with Improved VQGAN
Pre
title: Vector-quantized Image Modeling with Improved VQGAN
accepted: ICLR 2022
paper: https://arxiv.org/abs/2110.04627
code: https://github.com/thuanz123/enhancing-transformers (unofficial)
ref: https://zhuanlan.zhihu.com/p/611689477
关键词:quantization, ViT
阅读理由:看起来比VQGAN厉害很多
Idea
将VQGAN quantizer的CNN换成了ViT,加了俩trick:Factorized codes和l2-normalized codes
Motivation&Solution
-
(m) VQGAN的codebook使用率低下,影响阶段1的重建跟阶段2的多样性
-
(m) Image quantizer质量不好会导致信息丢失
-
(s) 去掉vqgan里面的top-k和top-p采样,只用temperature=1.0的采样
-
(s) 使用Factorized codes、l2-normalized codes 以及 logit-laplace 损失
Background
跟vqgan一样有两个阶段:
- Image Quantization。给定输入图片 256x256 ViT-VQGAN将其编码为32x32的离散codes(8倍下采样),而codebook的大小为8192
- Vector-quantized Image Modeling (VIM)。训练transformer来自回归地预测32x32=1024个token,若是 class-conditioned 图片生成,跟vqgan一样把类别id的token放在图片token前面(输入模型)。加分类头是为了评估无监督学习的质量。
跟vqgan的差别:
- 阶段1的CNN换成ViT,因此解码器先将预测的每个token转换回8x8的图片patch,再将其拼成一整副图片
- 阶段1先对图片进行随机增强再转为token
- ViT-VQGAN编解码器的输入输出都不用激活函数(但也提到logit-laplace要解码器加一个sigmoid激活,可以有更好的重建质量又不会有明显的网格伪影)
- 增大了codebook size,而且不用top-k和top-p采样
- 使用Factorized codes和l2-normalized codes
Method(Model)
Overview
图1 image generation 和image understanding 任务上 ViT-VQGAN (left) 和 Vector-quantized Image Modeling (right) 的overview
VQGAN with vision transformers
表1 ViT-VQGAN 比 CNN-VQGAN 取得了更好的速度-质量权衡,进一步加速了阶段2的训练,吞吐量使用相同的128个CloudTPV4设备进行了基准测试
看表1感觉 ViT-VQGAN的small-small版本就很不错,注意到他这个设计似乎需要解码器更重一些
阶段1 vector-quantization 的目标函数:
其中sg跟vqgan一样,还是一个停止梯度的运算符,$$\beta$$在实验中设置0.25
CODEBOOK LEARNING
Factorized codes
encoder输出先投影到低维空间(768维 -> 32/8维),再在低维空间上查找最近的code,然后将找到的code投影回高维空间。实验证明这样能提高codebook利用率并且改进了重建质量
l2-normalized codes
codebook的向量从正态分布初始化,然后将encoder输出向量z_e跟codebook里的向量e都做l2归一化,这样相当于将向量映射到球体,计算z_e和e的欧氏距离变为计算它们的余弦相似度,提高了训练稳定性和重建质量
VIT-VQGAN training losses
- logits-laplace loss:可视为标准化的L1损失,它假设像素级噪声符合拉普拉斯分布,对码本使用率有贡献
- L2 loss:假设噪声符合高斯分布,对FID有帮助
- adversarial loss:使用了StyleGAN的判别器架构
- perceptual loss:基于VGG,但VGG是用有监督分类loss预训练的,监督信息会泄露到stage2,影响它那分类头的精度,因此无监督学习的都没用,而无条件和class-conditioned 图片合成有用,也对FID有帮助
通过超参数搜索(hyper-parameter sweep)来确定各损失项的权重,
最终联合损失:
表2 Transformer架构 - 阶段1的ViT-VQGAN和阶段2的VIM
先将可学习的离散token ids嵌入,再加上2d位置嵌入,两个嵌入的维度都跟 model dim 一样。在整个序列上使用带因果注意力的堆叠transformer块,所有残差、激活和注意力输出都加上0.1的dropout,最后一层输出加上层归一化。
阶段2的VIM参数量远大于阶段1的ViT-VQGAN,比较大的参数量级。
vector-quantized image modeling
image synthesis
跟VQGAN几乎一样,只是加长序列,而且把CNN换成了Transformer
unsupervised learning
类似Image GPT,把其中一层的token特征序列取平均,然后用一个可学习的softmax层将其投影为class logits,不同的是这里只取其中一个block的输出而非将不同block输出拼在一起。对线性探测最有帮助的特征通常在Transformer中间处的块取到
Experiment
Training Detail
模型如表2所示,分了三种规格:
- ViT-VQGAN-SS(最小的): small encoder/small decoder
- ViT-VQGAN-BB: base encoder/base decoder
- ViT-VQGAN-SL(最大的): small encoder/large decoder
ViT-VQGAN-SL 的编解码器尺寸不对等,说是考虑到阶段2的训练只需要编码器
Dataset
数据集:CelebA-HQ,FFHQ,ImageNet
Results
image quantization:输入图片分辨率256x256,取batch size为128,分布于128个CloudTPUv4上,一共训练500,000个step。Adam优化器,warm up,cosine schedule
表3 验证集上重建的FID,*表示用Gumbel-Softmax重参数化训练,**表示用multi-scale hierarchical codebook训练
表3可看出比起VQGAN,在不使用Gumbel-Softmax和multi-scale hierarchical codebook的情况下,通过加大codebook大小ViT-VQGAN可取得更好的FID
表4 ViT-VQGAN的消融实验。codebook usage计算的是整个测试集上以256为1个batch的code平均使用百分比
后面几行都是跟第一行对比,可以看出基于StyleGAN的判别器就是比PatchGAN好不少,latent dim是指投影的低维空间的维度,16或8比较合适,同时第二个创新点的L2归一化对结果影响也很大
image synthesis:输入图片分辨率256x256,取batch size为1024,一共训练450,000个step。Adam优化器(参数有变),warm up,cosine schedule。然后为了节约显存又用了Adafactor,将第一矩量化为int8并分解第二矩
表5 与无条件图片合成方法对比FID
表6 分辨率256x256在ImageNet上做class-conditional图片合成的FID。acceptance rate是基于ResNet-101分类器的拒绝采样
从表5表6来看,效果都远好于VQGAN,表6用的还是ViT-VQGAN-SS
unsupervised learning:超参跟无条件图片合成一样,用ViT-VQGAN-SS,在Transformer某一个块上用平均池化取得特征,实验表明,中间层(15/36 for large, 10/24 for base)的特征有更好的分类精度
表7 ImageNet上不同无监督学习方法的线性探测精度,DALLE dVAE的图片quantizer用了额外的数据训练而成,而VIM-Large未使用dropout
表7中作者将模型分为两组:判别式预训练模型和生成式预训练模型,本文的方法VIM with ViT-VQGAN比其他生成式模型都好,而且参数量还更小,同时也能取得跟BYOL、DINO等判别式模型相近的性能
放点附录里的样本:
Conclusion
论文无总结,却有ETHICS部分
Critique
根据论文看,效果比vqgan好了不少,但参数量也是多了好多,感觉没法替换vqgan的使用,更别说代码还没开源,不知道非官方代码能否复现论文的效果
Unknown
无