Transformer 中的 attention
Transformer 中的 attention
大火的transformer 本质就是:
*使用attention机制的seq2seq。*
所以它的核心就是attention机制,今天就讲attention。直奔代码VIT-pytorch:
https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
中的
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
看吧!就是这么简单。今天就彻底搞懂这个东西。
先记住attention的这么几个点:
- attention和CNN、RNN、FC、GCN等都是一个级别的东西,用来提取特征;既然是特征提取,一定有权重(W+B)存在。
- attention的优点:可以像CNN一样并行运算 + 像RNN一样通过一层就拥有全局资讯。有一个东西也可以做到,那就是FC,但是FC有个弱点:对输入尺寸有限制,说白了不好适应可变输入数据,这对于序列无疑是非常不友好的。
- pooling也可以实现,但是它是无参的过程。例如点云数据,就可以用pooling来处理,当然也有一些网络是pooling is all your need。
- 可以像CNN一样并行运算 ,其实CNN运算也是通过im2col或winograd等转化为矩阵运算的。
- RNN不能并行,所以通常它处理的数据有“时序”这个特点。既然是“时序”,那么就不是同一个时刻完成的,所以不能并行化。
综上所述: attention优点 = CNN并行+RNN全局资讯+对输入尺寸(时序长度维度上)没有限制。
如果你能创造一个拥有上面三点优点的东西出来,你也可以引领潮流。
然后回到代码,再熟悉这么几个设置:
- batch维度:大家利用同样的权重和操作提取特征,可以理解为for循环式,相互之间没有信息交互;
- multi head维度:同batch类似,不过是利用的不同权重和相同操作提取特征,最后concate一起使用;
- FC层:是作用在每一个特征上,类似CNN中的1X1,可以叫“pointwise”,和序列长度没有关系;因为序列中所有的特征经过的是同一个FC。
下面看这个图,看完不懂的可以扇自己了:
attention的顺序是:
- 你有长度为n(序列)的序列,每个元素都是一个特征,每个特征都是一个向量;
- 每个向量都经过FC1,FC2,FC3获取到q,k,v三个向量(长度自己定),记住,不同特征用的是同一个FC1,FC2,FC3。可以说对于一个head,就一组FC1,FC2,FC3。
- 特征1的q1和所有特征的k 进行点乘,获取一串值,注意:和自己的k也进行点乘;点乘向量变标量,表示相似性。多个K可不就是一串标量。
- 3中的那一串值进行softmax操作,作为权重 对所有v加权求和,获得特征1输出;
- 其他所有的特征和特征1的操作一样,注意所有特征是一块并行计算的;
- 最后获取的和输入一样长度的特征序列再经过FC进行长度(特征维度)调整,也可以不要;
对了,softmax之前不要忘记 除以 qkv长度开方进行scaled,其实就是标准化操作(我觉得可以理解为各种N(BN,GN,LN等))。
就是这么简单,你学会了吗?