Transformer 中的 attention

Transformer 中的 attention

转自Transformer中的attention,看完不懂扇我脸

大火的transformer 本质就是:

*使用attention机制的seq2seq。*

所以它的核心就是attention机制,今天就讲attention。直奔代码VIT-pytorch:

https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py

中的

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

看吧!就是这么简单。今天就彻底搞懂这个东西。

先记住attention的这么几个点:

  • attention和CNN、RNN、FC、GCN等都是一个级别的东西,用来提取特征;既然是特征提取,一定有权重(W+B)存在。
  • attention的优点:可以像CNN一样并行运算 + 像RNN一样通过一层就拥有全局资讯。有一个东西也可以做到,那就是FC,但是FC有个弱点:对输入尺寸有限制,说白了不好适应可变输入数据,这对于序列无疑是非常不友好的。
  • pooling也可以实现,但是它是无参的过程。例如点云数据,就可以用pooling来处理,当然也有一些网络是pooling is all your need。
  • 可以像CNN一样并行运算 ,其实CNN运算也是通过im2col或winograd等转化为矩阵运算的。
  • RNN不能并行,所以通常它处理的数据有“时序”这个特点。既然是“时序”,那么就不是同一个时刻完成的,所以不能并行化。

综上所述: attention优点 = CNN并行+RNN全局资讯+对输入尺寸(时序长度维度上)没有限制。

如果你能创造一个拥有上面三点优点的东西出来,你也可以引领潮流。

然后回到代码,再熟悉这么几个设置:

  • batch维度:大家利用同样的权重和操作提取特征,可以理解为for循环式,相互之间没有信息交互;
  • multi head维度:同batch类似,不过是利用的不同权重和相同操作提取特征,最后concate一起使用;
  • FC层:是作用在每一个特征上,类似CNN中的1X1,可以叫“pointwise”,和序列长度没有关系;因为序列中所有的特征经过的是同一个FC。

下面看这个图,看完不懂的可以扇自己了:

attention的顺序是:

  1. 你有长度为n(序列)的序列,每个元素都是一个特征,每个特征都是一个向量;
  2. 每个向量都经过FC1,FC2,FC3获取到q,k,v三个向量(长度自己定),记住,不同特征用的是同一个FC1,FC2,FC3。可以说对于一个head,就一组FC1,FC2,FC3。
  3. 特征1的q1和所有特征的k 进行点乘,获取一串值,注意:和自己的k也进行点乘;点乘向量变标量,表示相似性。多个K可不就是一串标量。
  4. 3中的那一串值进行softmax操作,作为权重 对所有v加权求和,获得特征1输出;
  5. 其他所有的特征和特征1的操作一样,注意所有特征是一块并行计算的;
  6. 最后获取的和输入一样长度的特征序列再经过FC进行长度(特征维度)调整,也可以不要;

对了,softmax之前不要忘记 除以 qkv长度开方进行scaled,其实就是标准化操作(我觉得可以理解为各种N(BN,GN,LN等))。

就是这么简单,你学会了吗?

posted @ 2022-05-08 11:49  梁君牧  阅读(109)  评论(0编辑  收藏  举报