vit的线性映射过程

Vision Transformer 线性映射

Vision Transformer (ViT): 线性映射

1. 展平图像块

假设输入的图像块大小为 P × P 像素,并且图像有 C 个通道(对于RGB图像,通常 C = 3)。

每个图像块被展平成一个向量,向量的维度为 P × P × C 。

例如,对于一个16x16像素的RGB图像块,展平后的向量长度为 16 × 16 × 3 = 768 。

2. 线性映射的目的

线性映射的目标是将这个展平后的向量映射到一个新的空间,该空间的维度通常与Transformer模型的隐藏层维度一致,记作 D 。

常见的 D 的选择有768或1024等。

3. 线性映射的实现方式

线性映射可以理解为一个简单的全连接层。注意这个全连接层是需要训练的

对于每个展平后的图像块向量 x(长度为 P × P × C ),线性映射通过矩阵乘法来完成:

z = x · W + b

  • x 是输入的展平向量,维度为 P × P × C 。
  • W 是线性映射的权重矩阵,维度为 (P × P × C) × D 。
  • b 是偏置向量,维度为 D 。
  • z 是输出的嵌入向量,维度为 D 。

4. 高效实现

由于每个图像块的向量是独立处理的,因此可以使用矩阵运算批量处理所有图像块,利用并行计算加速训练和推理过程。

在实际实现中,这个线性映射通常通过深度学习框架中的全连接层(如PyTorch中的 nn.Linear 或 TensorFlow中的 Dense)来实现。

5. 结果

每个图像块被转换为一个维度为 D 的嵌入向量,这些嵌入向量与位置编码相加后,作为Transformer的输入。

6. 简化公式

这里是简化后的公式:

  • 输入展平向量:x (长度 P × P × C )
  • 输出嵌入向量:z (长度 D )
  • 权重矩阵:W (维度 (P × P × C) × D )
  • 线性映射公式:z = x · W + b

通过这个线性映射,ViT能够将原始图像块转换为具有更高表达能力的向量表示,使得Transformer能够有效地处理和学习图像数据的特征。

posted @ 2024-08-09 22:46  海_纳百川  阅读(43)  评论(0编辑  收藏  举报
本站总访问量