Transformer代码详解: attention-is-all-you-need-pytorch
from: https://zhuanlan.zhihu.com/p/463052305
参考:
attention-is-all-you-need-pytorch
Transformer代码详解-pytorch版
Transformer模型结构
Transformer模型结构如下图:
- Transformer的整体结构就是分成Encoder和Decoder两部分,并且两部分之间是有联系的,可以注意到Encoder的输出是Decoder第二个Multi-head Attention中和的输入。
- Encoder和Decoder分别由N个EncoderLayer和DecoderLayer组成。N默认为6个。
- EncoderLayer由两个SubLayers组成,分别是Multi-head Attention和Feed Forward。DecoderLayer则是由三个SubLayers组成,分别是Masked Multi-head Attention,Multi-head Attention和Feed Forward。
- Multi-head Attention是用ScaledDotProductAttention和Linear组成。Feed Forward是由Linear组成。
- Add & Norm指的是残差连接之后再进行LayerNorm。
各模块结构结构
Multi-head Attention结构
Feed Forward结构
EncoderLayer结构
DecoderLayer结构
Encoder结构
Decoder结构
ScaledDotProductAttention模块
ScaledDotProductAttention做的是一个attention计算。公式如下:
输入q k v,可以q先除以根号d_k(d_k默认为64,根号d_k就为8),再与k的转置相乘,再经过softmax,最后与v相乘。下图的操作和公式所做的东西是一样的。
class ScaledDotProductAttention(nn.Module):
''' Scaled Dot-Product Attention '''
def __init__(self, temperature, attn_dropout=0.1):
super().__init__()
# 其实就是论文中的根号d_k
self.temperature = temperature
self.dropout = nn.Dropout(attn_dropout)
def forward(self, q, k, v, mask=None):
# sz_b: batch_size 批量大小
# len_q,len_k,len_v: 序列长度 在这里他们都相等
# n_head: 多头注意力 默认为8
# d_k,d_v: k v 的dim(维度) 默认都是64
# 此时q的shape为(sz_b, n_head, len_q, d_k) (sz_b, 8, len_q, 64)
# 此时k的shape为(sz_b, n_head, len_k, d_k) (sz_b, 8, len_k, 64)
# 此时v的shape为(sz_b, n_head, len_k, d_v) (sz_b, 8, len_k, 64)
# q先除以self.temperature(论文中的根号d_k) k交换最后两个维度(这样才可以进行矩阵相乘) 最后两个张量进行矩阵相乘
# attn的shape为(sz_b, n_head, len_q, len_k)
attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
if mask is not None:
# 用-1e9代替0 -1e9是一个很大的负数 经过softmax之后接近与0
# 其一:去除掉各种padding在训练过程中的影响
# 其二,将输入进行遮盖,避免decoder看到后面要预测的东西。(只用在decoder中)
attn = attn.masked_fill(mask == 0, -1e9)
# 先在attn的最后一个维度做softmax 再dropout 得到注意力分数
attn = self.dropout(F.softmax(attn, dim=-1))
# 最后attn与v进行矩阵相乘
# output的shape为(sz_b, 8, len_q, 64)
output = torch.matmul(attn, v)
# 返回 output和注意力分数
return output, attn
MultiHeadAttention和PositionwiseFeedForward模块
MultiHeadAttention做的是将q k v先经过线性层投影,再做ScaledDotProductAttention ,最后经过一个线性层。也就是下图的操作:
对应着Transformer的模块是:
PositionwiseFeedForward其实就是MLP。对应着Transformer的模块是:
# q k v 先经过不同的线性层 再用ScaledDotProductAttention 最后再经过一个线性层
class MultiHeadAttention(nn.Module):
''' Multi-Head Attention module '''
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
# 这里的n_head, d_model, d_k, d_v分别默认为8, 512, 64, 64
super().__init__()
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
def forward(self, q, k, v, mask=None):
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
# len_q, len_k, len_v 为输入的序列长度
sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
# 用作残差连接
residual = q
# Pass through the pre-attention projection: b x lq x (n*dv)
# Separate different heads: b x lq x n x dv
# q k v 分别经过一个线性层再改变维度
# 由(sz_b, len_q, n_head*d_k) => (sz_b, len_q, n_head, d_k) (sz_b, len_q, 8*64) => (sz_b, len_q, 8, 64)
q = self.w_qs(q).view(sz_b, len_q,