LLM大模型: Segment Anything Model原理详解
meta在2023.4.5又发了image sematic segmentation的文章,名字就叫Segment Anything;学术圈有个潜规则:title越简单,事情越大,比如7年前的那篇 attention is all you need,直接提升了nlp的层次!这次的Segment Anything同样也很简单,这次又有哪些break through innovation?
1、(1)论文刚开始,给出了模型的交互方式:点、框、mask描边、text都能作为prompt,然后和image一起输入,经过model的处理后,输出就是valid mask了!怎么样,是不是很符合人的使用习惯?
另一个靓点:所谓的data engine,先人工标注少量的高质量数据集,用来训练"粗糙"的SAM;然后用粗糙的SAM做语义分割,期间配个人工检查标记漏掉的mask,用以完善数据集,再继续迭代训练SAM;如此往复,不停迭代,直到得到高质量的数据集和准确率高的model,整个过程的思路是不是和loss+back proportion很相似啊!
2、 SAM网络结构如下:
- 原始的image,通过encoder后转成image embedding;为了和transformer架构兼容,这里推荐使用vit提取image embedding
- 常见的promt:
- point、box:这两种prompt都和位置相关,所以可以用positional encoder编码,论文原话:We represent points and boxes by positional encodings [95] summed with learned embeddings for each prompt type
- text是文本,最常见的就是BERT编码了,论文用的CLIP,论文原话:free-form text with an off-the-shelf text encoder from CLIP [82]
- 最麻烦的就是mask prompt了:这个是用户手动的描边,可能不精准,只是个大概的范围,需要SAM进一步精确描边。这里的mask promt本质也是个图片,所以这里用conv提取特征【其实conv也是attn,没本质区别】,转成embedding,论文原话:Dense prompts (i.e., masks) are embedded using convolutions and summed element-wise with the image embedding
- 四种promt经过encoder编码后,进入mask decoder,就生成了valid mask图片;
- 为了避免歧义,model默认输出三个mask segmentation,比如上图的剪刀有3个:剪刀全身、剪刀两个耳朵,剪刀的一个耳朵,每个mask segmentation都有一个socre,计算方法为IOU,score越高,mask segmentation越准确
所以整个网络结构最核心的就是mask encoder了,这个又是怎么处理image和promt的embedding的了?
3、(1)mask encoder的网络架构如下:
论文原始描述:Figure 14: Details of the lightweight mask decoder. A two-layer decoder updates both the image embedding and prompt tokens via cross-attention. Then the image embedding is upscaled, from which the updated output tokens are used to dynamically predict masks. (Not illustrated for figure clarity: At every attention layer, positional encodings are added to the image embedding, and the entire original prompt token (including position encoding) is re-added to the token queries and keys.)
- image embedding:原始的image通过vit后转成256 * 64 * 64这么大,64 * 64 可以看成是H' * W',256这种channel可以看成是embedding
- promt token: 就是point、box、text、mask等编码后的embedding,每个token也都转成256维的向量,和image embedding的dimension保持一致
- output token:从名字看,就知道是输出结果的token了。这部分的token embedding是要动态更新的,论文原话:Then the image embedding is upscaled, from which the updated output tokens are used to dynamically predict masks;instead of predicting a single mask, we use a small number of output tokens and predict multiple masks simultaneously. By default we predict three masks;
输出准备完毕,就是最关键的mask decoder环节了,直接涉及到输出的mask是否准确;从下网上看:
- self attn:这个容易理解,主要是output token和promt token内部做attn,核心是查看output和prompt是否接近,来确定output是否正确;
- token to image attn:把image的特征和tokens的特征做融合,这里是token主动和image做cross attn
- MLP:类似transformer block的FFN
- image to token attn:这里居然是image主动和token做cross attn,这也是SAM的创新点之一;
- To ensure the decoder has access to critical geometric information, the positional encodings are added to the image embedding whenever they participate in an attention layer:模型在mask decode过程中,能保留每个pixel(或特征位置)的position,避免decode过程中pixel的原始position信息丢失,从而帮助mask decoder更精确地生成spatial信息相关的mask,我个人觉得这个思路和renet接近
- Additionally, the entire original prompt tokens (including their positional encodings) are re-added to the updated tokens whenever they participate in an attention layer:无论在attn层中更新了多少次promt,mask decoder都会把最初的promt token 连同其position embedding一起重新注入。这使得模型在decode时既保留了prompt的postion和类型信息,也允许它在此基础上增加或调整新的信息
以上就是组接近transformer block的架构了,比较神奇的是这个block居然只有2个,所以SAM模型的参数也很小,才641M个,bin文件不到2.4GB(https://huggingface.co/facebook/sam-vit-huge 有详情)!这么小的参数,为啥效果这么好?我个人认为核心还是数据质量好,也就是data engine的思路好!经过2个block,image和tokens的信息互相融合后,
- 2x conv trans:这里做upsample ,核心目的是还原原始的image,才能在原始尺寸的image上做mask segmentation
- token to image attn:两部分的信息继续融合
- MLP:只有3层,核心是做channel维度适配;we pass the updated output token embedding to a small 3-layer MLP that outputs a vector matching the channel dimension of the upscaled image embedding
- masks:dot product per mask * output token per mask(经过MLP):这步骤和maskformer中的最后一个drop神似:classification loss * binary mask loss最终得到 K*H*W,也就是每个pixel的K个object分类的概率分布;Finally, we predict a mask with a spatially point-wise product between the upscaled image embedding and the MLP’s output;这里的到的是每个pixel属于N个query/mask的概率分布(或者说单个N包含哪些pixel,不包含哪些pixel;比如3*H*W,就是第3个mask包含哪些pixel,不包含哪些pixel),也就是N*H*W!
- 最后一个IOU:mask出来的object和标注ibject的交集,越大说明mask越准确
(2) loss:输出的masks对不对了? 怎么评价输出的masks好坏了?这里涉及到loss的选择了! masks输出的是N*H*W,这里有个问题:比如识别image的行人。这个数据集中有大量的行人图像,只有少数不是行人的图像,所以类别之间的极端不平衡:行人类别(正样本)远多于非行人类别(负样本),所以SAM这里采用的是focal loss,而不是cross entropy,这里的公式如下:
这里最重要的参数就是gamma了!以
- gamma = 0 时,focal loss退化为标准的交叉熵损失。模型可能会很快学会正确分类那些易分类的行人图像,但对于少数非行人图像,由于它们在训练集中的比例很低,模型可能无法给予足够的关注,导致对这些难分类样本的识别性能较差
- gamma > 0 时,focal loss会减少对那些已经分类正确的、易分类的行人样本的损失贡献,而增加对分类错误的、难分类的非行人样本的损失贡献。这样,模型被迫更多地关注那些难分类的样本,从而提高了对少数类别的识别能力
举例:比如gamma=2,对于一个已经被模型以高概率正确分类的行人图像(Pt接近1),那么(1-Pt)^2接近0,这会显著减少该样本的损失贡献。相反,对于一个被错误分类的非行人图像,此时Pt接近0,但是(1-Pt)^2会比较大,这会增加该样本的损失贡献,促使模型调整参数以更好地分类这类样本
上述loss是针对masks的,核心是某个pixel的类别对不对,类似于maskformer的binary mask loss!然而mask segmentation是由大量的pixel组成的,单个pixel分类正确了还不足以说明整个mask是正确的,咋办咧?还有另一个指标来评判maks整体的准确率有多高,就是IOU score(类似于maskformer的mask classification),图示如下:
核心就是predicted mask和ground truth mask之间的交集除以并集!score越高,说明predicted mask越接近ground truth mask!用hugging face的https://huggingface.co/facebook/sam-vit-huge训练好的model尝试,对车胎标记结果如下:
明显是mask2的效果最好,因为score最高的嘛!
(3)从SAM的网络结构上看,创新点有:
- image to token attn:之前都是token主动和image做cross attn,这里增加了image主动和token做cross attn
- 2x conv. trans做upsample,把image还原成原始的尺寸,利于mask segmentation的生成!
其他部分和maskformer基本一样,没啥本质区别了!
4、SAM效果好的另一个核心原因: data engine! 让人工标注大量数据的成本是很高的,怎么低成本地得到大量的优质标注数据了?论文原话:Our data engine has three stages: assisted-manual, semi-automatic, and fully automatic.
- In the first stage, SAM assists annotators in annotating masks, similar to a classic interactive segmentation setup.
- In the second stage, SAM can automatically generate masks for a subset of objects by prompting it with likely object locations and annotators focus on annotating the remaining objects, helping increase mask diversity.
- In the final stage, we prompt SAM with a regular grid of foreground points, yielding on average ∼100 high-quality masks per image.
概括一下这个所谓data engine迭代的原理其实很简单:
- 先用少量人工标注的高质量数据训练SAM。因为数据量少,所以此时SAM的质量很粗糙,不咋地
- 利用粗糙的SAM分割数据,但此时分割的结果质量肯定也不咋地,还是需要人工介入修正分割错误的数据,来提升准确率;注意,这步是修正,成本比人工从头开始标注低,这是关键点;这一步的目的是提升accuracy!
- 用修正好的增量数据继续训练模型,此时SAM的质量肯定比第一步的好很多。继续用SAM分割新image,再次用人工修正,找到漏掉的mask,这一步的目的是提升recall!
- 继续用上一步修正好的数据训练SAM,此时找到score较高的mask,继续进一步训练
- 重复上述的3、4两个步骤,就能得到越来越多的高质量标注数据啦!
用论文原话:Our final dataset, SA-1B, includes more than 1B masks from 11M licensed and privacy-preserving images:11M张image中得到了1B个高质量的mask;这么多的mask,如果全让人工标注,成本岂不是要上天了?
参考:
2、https://github.com/facebookresearch/segment-anything https://huggingface.co/facebook/sam-vit-huge
- SAM 2 code: https://github.com/facebookresearch/segment-anything-2
- SAM 2 demo: https://sam2.metademolab.com/
- SAM 2 paper: https://arxiv.org/abs/2408.00714
4、https://arxiv.org/abs/2304.02643 Segment Anything
6、常见的特征/信息融合方式:
- concat:将低层特征(基元、原始pixel等)和高层、大感受野的特征(全局context信息)在深度维度上进行拼接,拼接后的特征图包含了更多的上下文信息和细节信息
- Unet:低层特征、高层特征拼接,既有基元信息,又有全局context
-
googleNet的interception:不同的conv提取不同感受野、颗粒度的特征
- 金字塔池化 feather pyramid netword:feather map做不同尺寸的pooling,生成多尺度特征图,然后concat
- attention:transform,计算token之间的距离作为权重,用于更新token的value;不同net之间信息传递、融合
- 相加:resNet,保留原始特征,也提供各种方式提取后的特征【核心是有大量的subnet供路由router】
- 相乘:突出重要feather,削弱次要feather;比如对应位置的feather值相乘,强者越强,弱者越弱,马太效应;保留细节的同时增强特征表达
- 反卷积