[源码梳理][机器学习]PointMAE项目源码梳理和修改思路

Point_MAE项目源码中类的功能梳理。笔者要对它大刀阔斧,但发现自己作为调包侠,对torch的底层很多细节不熟悉,所以细细梳理一下这个项目。

链接

Encoder

  • 输入点云(这里为patch),返回该点云的全局特征
  • 简单的PointNet代码
  • first_conv和second_conv将三维坐标逐步投影到encoder_dim维
  • 通过池化,拼接再投影再池化,得到点云的全局特征,返回

Group

输入点,返回center和neighborhood,即patch

MLP

一个两层的多链接层,用于transformer

Attention

  • 输入token序列,映射到三倍维度的空间,拆下来三个同维度的就是qkv。(这个操作有意思,既有共享参数,又有独立参数的参数量)
  • q乘k的转置,得到attention score。这里也就是attention_map了(为了处理不同大小的点云,这里需要加上padding mask的操作,详见上一篇博客)
  • softmax得到attentionmap,然后乘v,过一个fc,drop一下,返回。
  • mask map得从Block类输入

Block

  • transformer模块,输入x,输出transformer之后的x,维度不变

  • mask map得从TransformerEncoder和Decoder输入

TransformerEncoder

编码器,过depth个block

TransformerDecoder

解码器,过block,然后norm,过输出头

MaskTransformer

mask掉一些token,剩下的输入到TransformerEncoder,位置信息的处理略

PointMAE

主类,就是论文的pipeline了。先用MaskTransformer,然后用TransformerDecoder。

都得提供mask map。

PointTransformer

用于下游任务的类,encoder接一个head。略

posted @ 2023-08-15 15:41  溡沭  阅读(335)  评论(0编辑  收藏  举报