[源码梳理][机器学习]PointMAE项目源码梳理和修改思路
Point_MAE项目源码中类的功能梳理。笔者要对它大刀阔斧,但发现自己作为调包侠,对torch的底层很多细节不熟悉,所以细细梳理一下这个项目。
Encoder
- 输入点云(这里为patch),返回该点云的全局特征
- 简单的PointNet代码
- first_conv和second_conv将三维坐标逐步投影到encoder_dim维
- 通过池化,拼接再投影再池化,得到点云的全局特征,返回
Group
输入点,返回center和neighborhood,即patch
MLP
一个两层的多链接层,用于transformer
Attention
- 输入token序列,映射到三倍维度的空间,拆下来三个同维度的就是qkv。(这个操作有意思,既有共享参数,又有独立参数的参数量)
- q乘k的转置,得到attention score。这里也就是attention_map了(为了处理不同大小的点云,这里需要加上padding mask的操作,详见上一篇博客)
- softmax得到attentionmap,然后乘v,过一个fc,drop一下,返回。
- mask map得从Block类输入
Block
-
transformer模块,输入x,输出transformer之后的x,维度不变
-
mask map得从TransformerEncoder和Decoder输入
TransformerEncoder
编码器,过depth个block
TransformerDecoder
解码器,过block,然后norm,过输出头
MaskTransformer
mask掉一些token,剩下的输入到TransformerEncoder,位置信息的处理略
PointMAE
主类,就是论文的pipeline了。先用MaskTransformer,然后用TransformerDecoder。
都得提供mask map。
PointTransformer
用于下游任务的类,encoder接一个head。略