Transformer Code Details
Optimization
The authors use a learning-rate warmup schedule: the learning rate increases linearly during the warmup steps and then decays in proportion to the inverse square root of the step number.
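Concretely, at step \(n\) the wrapper below computes
\[ \mathrm{lr} = \text{init\_lr} \cdot d_{\text{model}}^{-0.5} \cdot \min\!\left(n^{-0.5},\ n \cdot \text{n\_warmup\_steps}^{-1.5}\right), \]
which is the schedule from the paper scaled by a constant multiplier \(\text{init\_lr}\); this is exactly what _get_lr_scale evaluates.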
class ScheduledOptim():
'''A simple wrapper class for learning rate scheduling'''
def __init__(self, optimizer, init_lr, d_model, n_warmup_steps):
self._optimizer = optimizer
self.init_lr = init_lr
self.d_model = d_model
self.n_warmup_steps = n_warmup_steps
self.n_steps = 0
def step_and_update_lr(self):
"Step with the inner optimizer"
self._update_learning_rate()
self._optimizer.step()
def zero_grad(self):
"Zero out the gradients with the inner optimizer"
self._optimizer.zero_grad()
def _get_lr_scale(self):
d_model = self.d_model
n_steps, n_warmup_steps = self.n_steps, self.n_warmup_steps
return (d_model ** -0.5) * min(n_steps ** (-0.5), n_steps * n_warmup_steps ** (-1.5))
def _update_learning_rate(self):
''' Learning rate scheduling per step '''
self.n_steps += 1
lr = self.init_lr * self._get_lr_scale()
for param_group in self._optimizer.param_groups:
param_group['lr'] = lr
optimizer = ScheduledOptim(
optim.Adam(transformer.parameters(), betas=(0.9, 0.98), eps=1e-09),
2.0, opt.d_model, opt.n_warmup_steps) # Can the learning rate really be set this high? 2.0 is only a multiplier; the effective lr is further scaled by d_model ** -0.5 and the min(...) term, so it stays small.
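A minimal training-step sketch showing how the wrapper replaces a plain optimizer.step() (transformer, train_loader, gold and trg_pad_idx are placeholders, and cal_loss is the loss function defined in the next section):
for src_seq, trg_seq, gold in train_loader:
    optimizer.zero_grad()                    # delegates to the wrapped Adam
    pred = transformer(src_seq, trg_seq)     # [batch_size * trg_len, n_trg_vocab]
    loss = cal_loss(pred, gold, trg_pad_idx, smoothing=True)
    loss.backward()
    optimizer.step_and_update_lr()           # update the learning rate, then call Adam's step()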
Label Smoothing
With \(\epsilon = 0.1\) and 3 classes, the original one-hot label vector \([1, 0, 0]\) becomes \([1-\epsilon,\ \epsilon/2,\ \epsilon/2] = [0.9, 0.05, 0.05]\).
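More generally, for \(K\) classes the smoothed target built inside cal_loss below is
\[ y_i = \begin{cases} 1-\epsilon, & i = \text{gold class} \\ \epsilon/(K-1), & \text{otherwise,} \end{cases} \]
so the smoothed distribution still sums to 1.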
def cal_loss(pred, gold, trg_pad_idx, smoothing=False):
''' Calculate cross entropy loss, apply label smoothing if needed. '''
gold = gold.contiguous().view(-1)
if smoothing:
eps = 0.1
n_class = pred.size(1)
one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, dim=1)
non_pad_mask = gold.ne(trg_pad_idx)
loss = -(one_hot * log_prb).sum(dim=1)
loss = loss.masked_select(non_pad_mask).sum() # average later
else:
loss = F.cross_entropy(pred, gold, ignore_index=trg_pad_idx, reduction='sum')
return loss
Generating the Mask Matrices
# Padding positions (seq == pad_idx) are mapped to False
def get_pad_mask(seq, pad_idx):
return (seq != pad_idx).unsqueeze(-2)
def get_subsequent_mask(seq):
# returns a lower-triangular mask; the upper-triangular part (future positions) is all False
''' For masking out the subsequent info. '''
sz_b, len_s = seq.size()
subsequent_mask = (1 - torch.triu(
torch.ones((1, len_s, len_s), device=seq.device), diagonal=1)).bool()
return subsequent_mask # a lower-triangular boolean matrix of shape [1, len_s, len_s]
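A quick sanity check of the two helpers (a sketch; the token values and pad_idx = 0 are made up):
import torch
seq = torch.tensor([[5, 7, 2, 0]])           # batch_size=1, seq_len=4, 0 is the padding token
print(get_pad_mask(seq, pad_idx=0))          # shape [1, 1, 4]: [[[True, True, True, False]]]
print(get_subsequent_mask(seq).squeeze(0))   # shape [4, 4], lower-triangular:
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])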
Positional Encoding
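The table built below follows the sinusoidal encoding from the paper:
\[ PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right). \]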
class PositionalEncoding(nn.Module):
def __init__(self, d_hid, n_position=200):
super(PositionalEncoding, self).__init__()
# Not a parameter: registered as a buffer, so optimizer.step() does not update it
self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
def _get_sinusoid_encoding_table(self, n_position, d_hid):
''' Sinusoid position encoding table '''
# TODO: make it with torch instead of numpy
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
def forward(self, x):
return x + self.pos_table[:, :x.size(1)].clone().detach()
MultiHead Attention
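Both classes below implement the attention from the paper, with the temperature set to \(\sqrt{d_k}\):
\[ \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V. \]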
class ScaledDotProductAttention(nn.Module):
''' Scaled Dot-Product Attention '''
def __init__(self, temperature, attn_dropout=0.1):
super().__init__()
self.temperature = temperature
self.dropout = nn.Dropout(attn_dropout)
def forward(self, q, k, v, mask=None):
attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
if mask is not None:
attn = attn.masked_fill(mask == 0, -1e9)
attn = self.dropout(F.softmax(attn, dim=-1))
output = torch.matmul(attn, v)
return output, attn
class MultiHeadAttention(nn.Module):
''' Multi-Head Attention module '''
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
super().__init__()
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) # d_model x (n_head * d_k), e.g. 512 x 512
self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
def forward(self, q, k, v, mask=None):
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
residual = q
# Pass through the pre-attention projection: b x lq x (n*dv)
# Separate different heads: b x lq x n x dv
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
# Transpose for attention dot product: b x n x lq x dv
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
if mask is not None:
mask = mask.unsqueeze(1) # For head axis broadcasting.
q, attn = self.attention(q, k, v, mask=mask)
# Transpose to move the head dimension back: b x lq x n x dv
# Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
q = self.dropout(self.fc(q))
q += residual
q = self.layer_norm(q)
return q, attn
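A minimal shape check for the module above (a sketch; the batch size and sequence length are illustrative, not from the original notes):
import torch
mha = MultiHeadAttention(n_head=8, d_model=512, d_k=64, d_v=64)
x = torch.randn(2, 10, 512)      # [batch_size, len_q, d_model]
out, attn = mha(x, x, x)         # self-attention: q = k = v, no mask
print(out.shape)                 # torch.Size([2, 10, 512])
print(attn.shape)                # torch.Size([2, 8, 10, 10]) -> [batch, n_head, len_q, len_k]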
PositionwiseFeedForward
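The two linear layers below implement the position-wise feed-forward network \(\mathrm{FFN}(x) = \max(0,\ xW_1 + b_1)W_2 + b_2\); in the default configuration \(d_{in} = d_{model} = 512\) and \(d_{hid} = d_{inner} = 2048\).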
class PositionwiseFeedForward(nn.Module):
''' A two-feed-forward-layer module '''
def __init__(self, d_in, d_hid, dropout=0.1):
super().__init__()
self.w_1 = nn.Linear(d_in, d_hid) # position-wise
self.w_2 = nn.Linear(d_hid, d_in) # position-wise
self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
residual = x
x = self.w_2(F.relu(self.w_1(x)))
x = self.dropout(x)
x += residual
x = self.layer_norm(x)
return x
The Transformer Encoder and Decoder
class Encoder(nn.Module):
''' An encoder model with self attention mechanism. '''
def __init__(
self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
d_model, d_inner, pad_idx, dropout=0.1, n_position=200):
super().__init__()
self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx)
self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
self.dropout = nn.Dropout(p=dropout)
self.layer_stack = nn.ModuleList([
EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)])
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
def forward(self, src_seq, src_mask, return_attns=False):
enc_slf_attn_list = []
# -- Forward
enc_output = self.dropout(self.position_enc(self.src_word_emb(src_seq)))
enc_output = self.layer_norm(enc_output) # an extra LayerNorm is applied right after the embedding + positional encoding
for enc_layer in self.layer_stack:
enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask)
enc_slf_attn_list += [enc_slf_attn] if return_attns else []
if return_attns:
return enc_output, enc_slf_attn_list
return enc_output,
class Decoder(nn.Module):
''' A decoder model with self attention mechanism. '''
def __init__(
self, n_trg_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
d_model, d_inner, pad_idx, n_position=200, dropout=0.1):
super().__init__()
self.trg_word_emb = nn.Embedding(n_trg_vocab, d_word_vec, padding_idx=pad_idx)
self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
self.dropout = nn.Dropout(p=dropout)
self.layer_stack = nn.ModuleList([
DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)])
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
def forward(self, trg_seq, trg_mask, enc_output, src_mask, return_attns=False):
dec_slf_attn_list, dec_enc_attn_list = [], []
# -- Forward
dec_output = self.dropout(self.position_enc(self.trg_word_emb(trg_seq)))
dec_output = self.layer_norm(dec_output)
for dec_layer in self.layer_stack:
dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
dec_output, enc_output, slf_attn_mask=trg_mask, dec_enc_attn_mask=src_mask)
dec_slf_attn_list += [dec_slf_attn] if return_attns else []
dec_enc_attn_list += [dec_enc_attn] if return_attns else []
if return_attns:
return dec_output, dec_slf_attn_list, dec_enc_attn_list
return dec_output,
class Transformer(nn.Module):
''' A sequence to sequence model with attention mechanism. '''
def __init__(
self, n_src_vocab, n_trg_vocab, src_pad_idx, trg_pad_idx,
d_word_vec=512, d_model=512, d_inner=2048,
n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1, n_position=200,
trg_emb_prj_weight_sharing=True, emb_src_trg_weight_sharing=True):
super().__init__()
self.src_pad_idx, self.trg_pad_idx = src_pad_idx, trg_pad_idx
self.encoder = Encoder(
n_src_vocab=n_src_vocab, n_position=n_position,
d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
pad_idx=src_pad_idx, dropout=dropout)
self.decoder = Decoder(
n_trg_vocab=n_trg_vocab, n_position=n_position,
d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
pad_idx=trg_pad_idx, dropout=dropout)
self.trg_word_prj = nn.Linear(d_model, n_trg_vocab, bias=False)
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p) # Xavier initialization
assert d_model == d_word_vec, \
'To facilitate the residual connections, \
the dimensions of all module outputs shall be the same.'
self.x_logit_scale = 1.
if trg_emb_prj_weight_sharing: # the decoder's pre-softmax projection shares weights with the decoder embedding
# Share the weight between target word embedding & last dense layer
self.trg_word_prj.weight = self.decoder.trg_word_emb.weight
self.x_logit_scale = (d_model ** -0.5) # why a scaling factor? the paper multiplies the shared embedding weights by sqrt(d_model); this implementation applies the scaling to the projection logits instead
if emb_src_trg_weight_sharing: # the encoder and decoder use the same embedding matrix
self.encoder.src_word_emb.weight = self.decoder.trg_word_emb.weight
def forward(self, src_seq, trg_seq):
src_mask = get_pad_mask(src_seq, self.src_pad_idx) # src_seq has shape [batch_size, seq_len]
trg_mask = get_pad_mask(trg_seq, self.trg_pad_idx) & get_subsequent_mask(trg_seq)
enc_output, *_ = self.encoder(src_seq, src_mask)
dec_output, *_ = self.decoder(trg_seq, trg_mask, enc_output, src_mask)
seq_logit = self.trg_word_prj(dec_output) * self.x_logit_scale
return seq_logit.view(-1, seq_logit.size(2))
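A quick end-to-end shape check (a sketch: the vocabulary sizes and pad indices are made up, and EncoderLayer / DecoderLayer from the reference implementation are assumed to be in scope):
import torch
model = Transformer(n_src_vocab=100, n_trg_vocab=100, src_pad_idx=0, trg_pad_idx=0)
src = torch.randint(1, 100, (2, 12))   # [batch_size, src_len]
trg = torch.randint(1, 100, (2, 9))    # [batch_size, trg_len]
logits = model(src, trg)
print(logits.shape)                    # torch.Size([18, 100]) = [batch_size * trg_len, n_trg_vocab]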
Beam Search
beam_size is set to 5 and \(\alpha = 0.7\); \(\alpha\) is a length-penalty coefficient, with the normalized score \(S(Y|X) = \mathrm{Score}(Y|X) / \text{seq\_len}^{\alpha}\).
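For example, with \(\alpha = 0.7\) a 10-token hypothesis has its accumulated log-probability divided by \(10^{0.7} \approx 5.01\) and a 20-token one by \(20^{0.7} \approx 8.14\), so longer candidates are not unduly penalized for summing more negative log-probability terms.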
class Translator(nn.Module):
''' Load a trained model and translate in beam search fashion. '''
def __init__(
self, model, beam_size, max_seq_len,
src_pad_idx, trg_pad_idx, trg_bos_idx, trg_eos_idx):
super(Translator, self).__init__()
self.alpha = 0.7
self.beam_size = beam_size
self.max_seq_len = max_seq_len
self.src_pad_idx = src_pad_idx
self.trg_bos_idx = trg_bos_idx
self.trg_eos_idx = trg_eos_idx
self.model = model
self.model.eval() # inference mode (disables dropout)
self.register_buffer('init_seq', torch.LongTensor([[trg_bos_idx]]))
self.register_buffer(
'blank_seqs',
torch.full((beam_size, max_seq_len), trg_pad_idx, dtype=torch.long))
self.blank_seqs[:, 0] = self.trg_bos_idx
self.register_buffer(
'len_map',
torch.arange(1, max_seq_len + 1, dtype=torch.long).unsqueeze(0))
def _model_decode(self, trg_seq, enc_output, src_mask):
trg_mask = get_subsequent_mask(trg_seq)
dec_output, *_ = self.model.decoder(trg_seq, trg_mask, enc_output, src_mask)
return F.softmax(self.model.trg_word_prj(dec_output), dim=-1)
def _get_init_state(self, src_seq, src_mask):
beam_size = self.beam_size
enc_output, *_ = self.model.encoder(src_seq, src_mask) #[1,seq_len,512]
dec_output = self._model_decode(self.init_seq, enc_output, src_mask)
best_k_probs, best_k_idx = dec_output[:, -1, :].topk(beam_size) # top beam_size tokens for the first decoding step, shape [1, beam_size]; the batch size here is 1
scores = torch.log(best_k_probs).view(beam_size)
gen_seq = self.blank_seqs.clone().detach() #[beam_size,max_seq_len]
gen_seq[:, 1] = best_k_idx[0]
enc_output = enc_output.repeat(beam_size, 1, 1) #[beam_size,seq_len,512]
return enc_output, gen_seq, scores
def _get_the_best_score_and_idx(self, gen_seq, dec_output, scores, step):
assert len(scores.size()) == 1
beam_size = self.beam_size
# Get k candidates for each beam, k^2 candidates in total.
best_k2_probs, best_k2_idx = dec_output[:, -1, :].topk(beam_size)
# Include the previous scores.
scores = torch.log(best_k2_probs).view(beam_size, -1) + scores.view(beam_size, 1)
# Get the best k candidates from k^2 candidates.
scores, best_k_idx_in_k2 = scores.view(-1).topk(beam_size)
# Get the corresponding positions of the best k candidiates.
best_k_r_idxs, best_k_c_idxs = best_k_idx_in_k2 // beam_size, best_k_idx_in_k2 % beam_size
best_k_idx = best_k2_idx[best_k_r_idxs, best_k_c_idxs]
# Copy the corresponding previous tokens.
gen_seq[:, :step] = gen_seq[best_k_r_idxs, :step]
# Set the best tokens in this beam search step
gen_seq[:, step] = best_k_idx
return gen_seq, scores
def translate_sentence(self, src_seq):
# Only accept batch size equals to 1 in this function.
# TODO: expand to batch operation.
assert src_seq.size(0) == 1
src_pad_idx, trg_eos_idx = self.src_pad_idx, self.trg_eos_idx
max_seq_len, beam_size, alpha = self.max_seq_len, self.beam_size, self.alpha
with torch.no_grad():
src_mask = get_pad_mask(src_seq, src_pad_idx)
enc_output, gen_seq, scores = self._get_init_state(src_seq, src_mask)
ans_idx = 0 # default
for step in range(2, max_seq_len): # decode up to max length
dec_output = self._model_decode(gen_seq[:, :step], enc_output, src_mask) # [beam_size, step, n_trg_vocab]
gen_seq, scores = self._get_the_best_score_and_idx(gen_seq, dec_output, scores, step)
# Check if all path finished
# -- locate the eos in the generated sequences
eos_locs = gen_seq == trg_eos_idx
# -- replace the eos with its position for the length penalty use
seq_lens, _ = self.len_map.masked_fill(~eos_locs, max_seq_len).min(1)
# -- check if all beams contain eos
if (eos_locs.sum(1) > 0).sum(0).item() == beam_size: # every beam has produced the end-of-sequence token
# TODO: Try different terminate conditions.
_, ans_idx = scores.div(seq_lens.float() ** alpha).max(0)
ans_idx = ans_idx.item()
break
return gen_seq[ans_idx][:seq_lens[ans_idx]].tolist()
The call chain starts from translate_sentence.
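A sketch of how the translator might be driven (the special-token ids and the trained transformer are placeholder assumptions, not values from the original notes):
import torch
translator = Translator(
    model=transformer, beam_size=5, max_seq_len=100,
    src_pad_idx=0, trg_pad_idx=0, trg_bos_idx=2, trg_eos_idx=3)   # hypothetical special-token ids
src_seq = torch.LongTensor([[2, 15, 27, 42, 3]])    # one tokenized source sentence (batch size must be 1)
pred_ids = translator.translate_sentence(src_seq)   # list of target token ids, ending at <eos>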