LLM大模型: Maskformer/Mask2Former语义分割原理详解
1、自动驾驶、机器人、电商、监控等行业都涉及到image的sematic segmentation,传统的方式:per-pixel classification,每个像素点都要分类;如果进一步做 instance-level segmentation,可能还要改network architure后重新训练,很麻烦。FAIR在2021年10月份的时候发表了论文:Per-Pixel Classification is Not All You Need for Semantic Segmentation,单看标题就知道他们做Semantic Segmentation不再是Per-Pixel Classification,这帮人又是怎么做的了?原论文给的名称:per-mask classification,图示图下:
- 图的左边:per pixel做分类,一共有K类,那么每个pixel都要计算这K类的概率,一共有k*h*w个数值;
- 图的右边:per mask做分类,一共有K类,加上背景就是K+1类;既然是每个mask都做分类,那么mask有多少了?这里用N表示mask的数量!从图示看,每个mask都是image的一部分区域,然后对这部分区域做prediction,看看属于K+1中的哪类!比如图像中的第一个mask预测后就是building类, 第二个mask预测后就是sky类,最后一个mask啥也不是!
整个过程思路很简单,所以现在最核心的问题来了:这里的mask是怎么划分的?比如上图有100个mask,这100个mask之间是怎么划分地盘的? 换句话说,mask是怎么精准对image中不同object描边的?
2、原论文中网络架构如下图所示:
从颜色来看,就可以直观的分成三部分,用原论文的描述如下:
- pixel - level module:
- A backbone that extracts low resolution features from an image. 从原始image中提取低分辨率的feather F;
- A pixel decoder that gradually upsamples low-resolution features from the output of the backbone to generate high-resolution per-pixel embeddings. 然后通过decoder把低分辨率的feather转成高分辨率的per-pixel embeddings;这一步的结果就是:每个pixel都得到一个embedding representation,用来表示每个pixel的特征,所以这里得到的三维矩阵:C_epsilon * H * W
- transformer module:
- And finally a Transformer decoder that operates on image features to process object queries. transformer融合image的特征F和N个queries(就是上图的mask)
- segmentation module:
- 上路classification loss:经过transformer decoder后,再经过一个MLP进行空间转换,得到C_epsilon * N,也就是每个query/mask的embedding representation,用于描述每个query/mask的特征。然后经过softmax得到了每个query/mask属于哪个object class的概率,所以这里得到的就是 N * (K+1)的二维矩阵;
- 下路binary mask loss:上路得到了每个query/mask的C_epsilon * N特征,又从pixel decoder得到了每个pixel的特征 C_epsilon * H * W ,因为pixel肯定属于某个query/mask,所以把这两相乘,就到了 N * H* W;因为 H * W 代表的就是每个pixel,所以这个三维矩阵表示的是pixel属于某个query/mask的概率!
- 上面就是train过程,接下来就是inference了,也就是最右边的模块: N * (K + 1) 表示query/mask属于某个object类别的概率分布,N * H* W 表示pixel属于某个query/mask的概率分布,这两个相乘,把N去掉,不就得到了 K * H * W了么?H * W是每个pixel像素的位置,K是类别,这个 K * H * W 不就表示每个pixel的K类概率分布了么?是不是感觉这个思路很巧妙了:把N个query/mask作为中间变量latent,巧妙地通过两个矩阵相乘消掉N,得到每个pixel的K类概率分布,这个思路是不是和bayes很像啊!用数学表达式:P(K|H,W) = P(K|N) * P(N|H, W)
现在的问题来了:为什么不直接计算pixel的K类概率分布,而是要通过中间变量mask/query来中转?
- LLM的fine-tune都知道吧!其中一种方式是Lora微调,就是把大matrix分解成两个小matrix相乘,来减少计算量和存储空间,这里的N作用类似:如果K比较大,直接求P(K|H, W)可能计算量比较大,所以这里分成两个小矩阵相乘的形式减少计算量!
- 对于pixel的类别,如果在object内部,那么肯定就属于该object啦,这个很容易区分,但还是有些pos比较难区分:轮廓边缘,也就是准确描边:这上面的pixel肯定不是非黑即白的,需要有个概率分布描述所属类别比较合乎业务逻辑!
- query/mask N动态可调整,如果image的instance增加,可适当增加N来涵盖所有的instance;query/mask N在一定程度上可理解为instance的个数上限,每个query/mask对应一个instance,增加了模型的可解释性,也利于理解
- query/mask N在transformer中和image的feather做cross attention,可捕捉全局特征,理解局部之间的关系,比如transformer的decoder可能会识别出一个query与“行人”类别的关联,并注意到这个query在图像中的位置与另一个表示“车辆”的query有空间上的联系。这种全局的视角使得模型能够推断出,尽管行人的脚部在局部看起来与道路相似,但从整体上看,这个区域应该是属于“行人”的一部分,而不是道路。
3、 maskformer发布后仅仅过了大半年,mask2former接着发布了,又做了哪些改进了?老规矩,先上原论文的图:
原论文介绍如下:Mask2Former overview. Mask2Former adopts the same meta architecture as MaskFormer [14] with a backbone, a pixel decoder and a Transformer decoder. We propose a new Transformer decoder with masked attention instead of the standard cross-attention (Section 3.2.1). To deal with small objects, we propose an efficient way of utilizing high-resolution features from a pixel decoder by feeding one scale of the multi-scale feature to one Transformer decoder layer at a time (Section 3.2.2). In addition, we switch the order of self and cross-attention (i.e., our masked attention), make query features learnable, and remove dropout to make computation more effective (Section 3.2.3). Note that positional embeddings and predictions from intermediate Transformer decoder layers are omitted in this figure for readability. 从架构上看,mask2dormer和maskformer是一样的,没啥区别;最大的改进是使用了masked attention替代了标注你的cross attention。
(1)Compared to the cross-attention used in a standard Transformer decoder which attends to all locations in an image, our masked attention leads to faster convergence and improved performance。maskformer中Image feather F进入transformer的decoder后,和queries做cross atttention,用论文作者的数据是训练一张图片至少需要32GB的显存(One limitation of training universal architectures is the large memory consumption due to high-resolution mask prediction, making them less accessible than the more memory-friendly specialized architectures [6, 24]. For example, MaskFormer [14] can only fit a single image in a GPU with 32G memory),所以这里必须要改。在mask2former中,用的就是mask attention了!先来回顾一下N query/mask的作用:在1代中,N query/mask需要和image feather F做cross attenion,用来确认N query/mask的K+1类的概率分布。其实在一张image中,真正的object能有多少了?图片中很大一部分都是backgroud,这部分pixel参与cross attetion合理么?这里的思路就可以借鉴一下perceiver resampler 了:根据query做downsample,把query对应的object特征提取出来,其他的特征一概不要,这里首当其冲要去掉的就是backgroud啦!那么问题又来了:怎么精确地找到backgroud了?或者说怎么找到正确的mask来屏蔽backgroud了?
这种情况是监督学习,就是有大量标注好训练样本的,所以换个角度理解:训练过程中mask的位置肯定是不可能人为标注的,那就只能让模型自己通过Back proporgation学习了!标准的cross-attention是这么干的,如下图:
masked attention是这么干的:transformer decoder的每一层都要加上mask的值,这个值来自上一层的transformer decoder layer,最早的Mask0来自X0,然后经过transformer decoder layer逐层生成后续的mask;
(2)原论文提出的第二个改进:Second, we use multi-scale high-resolution features which help the model to segment small objects/regions; 这里的multi-scale是通过pixel-decoder体现的。pixel decoder的输入是image feather,经过decoder后生成多个不同scale的high-resolution来捕捉不同层级的信息。比如低分辨率捕捉image的整体全局信息,诸如大的框架、色调、object位置等宏观信息;高分辨率捕捉的是细节,比如毛发、皮肤、衣着、五官等,所以原文说的是high-resolution features which help the model to segment small objects/regions
(3)原论文提出的第三个改进:Third, we propose optimization improvements such as switching the order of self and cross-attention, making query features learnable, and removing dropout; all of which improve performance without additional compute;
Finally, we save 3× training memory without affecting the performance by calculating mask loss on few randomly sampled points. These improvements not only boost the model performance, but also make training significantly easier, making universal architectures more accessible to users with limited compute
参考:
1、https://arxiv.org/pdf/2112.01527 Masked-attention Mask Transformer for Universal Image Segmentation
2、https://github.com/facebookresearch/Mask2Former
3、https://arxiv.org/pdf/2107.06278 Per-Pixel Classification is Not All You Need for Semantic Segmentation
4、https://www.bilibili.com/video/BV1EA22YnEY1?spm_id_from=333.788.videopod.episodes&vd_source=241a5bcb1c13e6828e519dd1f78f35b2&p=2