A PyTorch Implementation of Attention-Based Seq2Seq
Video: Seq2Seq (attention) PyTorch implementation (bilibili)
Illustrated attention mechanism: https://wmathor.com/index.php/archives/1450/
https://wmathor.com/index.php/archives/1432/
Attention mechanism
First, the figure below shows the structure of an encoder.
Here h1 through hm are called the outputs, and the output at the last time step, hm, is also written as s0 (the two have the same value). Next, s0 and every hi are passed through a function that produces a scalar ai; this function can be understood as measuring how similar they are:
\[a_{i} = align(h_{i}, s_{0})\]
Then c0 is obtained as the weighted average of all the hi, using the ai (normalized to sum to 1) as weights:
\[c_{0} = \sum_{i=1}^{m} a_{i} h_{i}\]
If a particular ai is especially large, then a large part of c0 comes from the corresponding hi, which means the model pays particular attention to the word xi.
Next, c0, s0, and the decoder's input at the initial time step are combined and fed through a hidden layer A' to compute s1; then a new set of weights ai between s1 and all the hi is computed, which gives c1, and so on.
The align function concatenates hi and s0, multiplies the result by a matrix W, applies the tanh activation, and then multiplies by a vector v to obtain a scalar score:
\[a_{i} = v^{T}\tanh(W[h_{i}; s_{0}])\]
Another version: 09 What is the attention mechanism (Attention) - 二十三岁的有德 - 博客园 (cnblogs.com)
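To make the align step concrete, here is a minimal sketch of this additive scoring in PyTorch. The sizes, and the parameters W and v, are stand-ins chosen purely for illustration, not part of the original post.

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, m = 16, 5                    # hypothetical sizes
h = torch.randn(m, hidden_size)           # encoder outputs h_1..h_m
s0 = h[-1]                                # s_0 equals the last output h_m

W = nn.Linear(2 * hidden_size, hidden_size, bias=False)   # the matrix W
v = nn.Linear(hidden_size, 1, bias=False)                 # the final vector

# align(h_i, s_0) = v^T tanh(W [h_i; s_0]), computed for every i at once
scores = v(torch.tanh(W(torch.cat((h, s0.expand_as(h)), dim=-1)))).squeeze(-1)
a = F.softmax(scores, dim=-1)             # weights a_1..a_m, sum to 1
c0 = (a.unsqueeze(-1) * h).sum(dim=0)     # context c_0: weighted average of the h_i
print(a.shape, c0.shape)                  # torch.Size([5]) torch.Size([16])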
Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
Seq2SeqEncoder
class Seq2SeqEncoder(nn.Module):
    def __init__(self, embedding_dim, hidden_size, source_vocab_size):
        super(Seq2SeqEncoder, self).__init__()
        self.lstm_layer = nn.LSTM(input_size=embedding_dim,
                                  hidden_size=hidden_size,
                                  batch_first=True)
        self.embedding_table = torch.nn.Embedding(source_vocab_size, embedding_dim)

    def forward(self, input_ids):
        # input_ids: (batch, source_length); embedding lifts it to a 3-D tensor
        input_sequence = self.embedding_table(input_ids)  # (batch, source_length, embedding_dim)
        output_states, (final_h, final_c) = self.lstm_layer(input_sequence)
        return output_states, final_h
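A quick smoke test of the encoder; the sizes below are assumptions for illustration only.

enc = Seq2SeqEncoder(embedding_dim=8, hidden_size=16, source_vocab_size=100)
ids = torch.randint(100, size=(2, 3))            # (batch, source_length)
output_states, final_h = enc(ids)
print(output_states.shape)                       # torch.Size([2, 3, 16])
print(final_h.shape)                             # torch.Size([1, 2, 16]): (num_layers, batch, hidden)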
Attention mechanism
Here K can be understood as encoder_states in the code below, i.e. all of the encoder's hidden states; moreover K == V, so V is also encoder_states, while Q corresponds to decoder_state_t.
class Seq2SeqAttentionMechanism(nn.Module):
    def __init__(self):
        super(Seq2SeqAttentionMechanism, self).__init__()

    # runs a single decoding step
    def forward(self, decoder_state_t, encoder_states):
        bs, source_length, hidden_size = encoder_states.shape
        # decoder_state_t is 2-D (batch, hidden_size); expand it to match encoder_states
        decoder_state_t = decoder_state_t.unsqueeze(1)
        decoder_state_t = torch.tile(decoder_state_t, (1, source_length, 1))
        score = torch.sum(decoder_state_t * encoder_states, dim=-1)  # (batch, source_length)
        attn_prob = F.softmax(score, dim=-1)  # (batch, source_length)
        context = torch.sum(attn_prob.unsqueeze(-1) * encoder_states, 1)  # (batch, hidden_size)
        return attn_prob, context
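A small sketch of calling the module on its own, with made-up tensors: the decoder state plays the role of Q, and encoder_states plays the role of both K and V.

attn = Seq2SeqAttentionMechanism()
decoder_state_t = torch.randn(2, 16)             # (batch, hidden_size), the query
encoder_states = torch.randn(2, 3, 16)           # (batch, source_length, hidden_size), keys == values
attn_prob, context = attn(decoder_state_t, encoder_states)
print(attn_prob.shape)                           # torch.Size([2, 3]): one weight per source position
print(context.shape)                             # torch.Size([2, 16]): weighted sum of encoder states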
Seq2SeqDecoder
class Seq2SeqDecoder(nn.Module):
    def __init__(self, embedding_dim, hidden_size, num_classes, target_vocab_size, start_id, end_id):
        super(Seq2SeqDecoder, self).__init__()
        # LSTMCell runs a single time step
        self.lstm_cell = torch.nn.LSTMCell(embedding_dim, hidden_size)
        # the projection sees [context; h_t], hence hidden_size*2 input features
        self.proj_layer = nn.Linear(hidden_size * 2, num_classes)
        self.attention_mechanism = Seq2SeqAttentionMechanism()
        self.num_classes = num_classes
        self.embedding_table = torch.nn.Embedding(target_vocab_size, embedding_dim)
        # special token ids for the shifted target sequence
        self.start_id = start_id
        self.end_id = end_id

    # used during training (teacher forcing)
    def forward(self, shifted_target_ids, encoder_states):
        shifted_target = self.embedding_table(shifted_target_ids)
        bs, target_length, embedding_dim = shifted_target.shape
        bs, source_length, hidden_size = encoder_states.shape
        logits = torch.zeros(bs, target_length, self.num_classes)
        probs = torch.zeros(bs, target_length, source_length)
        for t in range(target_length):
            decoder_input_t = shifted_target[:, t, :]
            if t == 0:
                h_t, c_t = self.lstm_cell(decoder_input_t)
            else:
                h_t, c_t = self.lstm_cell(decoder_input_t, (h_t, c_t))
            attn_prob, context = self.attention_mechanism(h_t, encoder_states)
            decoder_output = torch.cat((context, h_t), -1)
            logits[:, t, :] = self.proj_layer(decoder_output)
            probs[:, t, :] = attn_prob
        return probs, logits

    def inference(self, encoder_states):
        # autoregressive decoding, assuming batch size 1
        target_id = torch.tensor([self.start_id], dtype=torch.long)
        h_t = None
        result = []
        while True:
            decoder_input_t = self.embedding_table(target_id)
            if h_t is None:
                h_t, c_t = self.lstm_cell(decoder_input_t)
            else:
                h_t, c_t = self.lstm_cell(decoder_input_t, (h_t, c_t))
            attn_prob, context = self.attention_mechanism(h_t, encoder_states)
            decoder_output = torch.cat((context, h_t), -1)
            logits = self.proj_layer(decoder_output)
            # the prediction at this step becomes the input at the next step
            target_id = torch.argmax(logits, -1)
            result.append(target_id)
            if torch.any(target_id == self.end_id):
                print('stop decoding')
                break
        predicted_ids = torch.stack(result, dim=0)
        return predicted_ids
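A minimal teacher-forcing sketch of the decoder on its own; all sizes and ids here are assumptions for illustration.

dec = Seq2SeqDecoder(embedding_dim=8, hidden_size=16, num_classes=10,
                     target_vocab_size=100, start_id=0, end_id=0)
encoder_states = torch.randn(2, 3, 16)                      # (batch, source_length, hidden_size)
shifted_target_ids = torch.randint(100, size=(2, 5))        # (batch, target_length), decoder inputs
probs, logits = dec(shifted_target_ids, encoder_states)
print(probs.shape)                                          # torch.Size([2, 5, 3])
print(logits.shape)                                         # torch.Size([2, 5, 10])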
Model
class Model(nn.Module):
    def __init__(self, embedding_dim, hidden_size, num_classes,
                 source_vocab_size, target_vocab_size, start_id, end_id):
        super(Model, self).__init__()
        self.encoder = Seq2SeqEncoder(embedding_dim, hidden_size, source_vocab_size)
        self.decoder = Seq2SeqDecoder(embedding_dim, hidden_size, num_classes,
                                      target_vocab_size, start_id, end_id)

    def forward(self, input_sequence_ids, shifted_target_ids):
        encoder_states, final_h = self.encoder(input_sequence_ids)
        probs, logits = self.decoder(shifted_target_ids, encoder_states)
        return probs, logits

    def infer(self):
        pass
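The infer method above is left as a stub. One possible way to fill it in (a sketch, not from the original code, under the assumption that generation is run with batch size 1) is to run the encoder once and let the decoder generate greedily:

def infer(self, input_sequence_ids):
    # run the encoder once, then decode greedily until end_id is produced
    encoder_states, final_h = self.encoder(input_sequence_ids)
    return self.decoder.inference(encoder_states)

# attach it to the class defined above (hypothetical wiring)
Model.infer = infer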
Main
if __name__ == '__main__':
    source_length = 3
    target_length = 4
    embedding_dim = 8
    hidden_size = 16
    num_classes = 10
    bs = 2
    start_id = end_id = 0
    source_vocab_size = 100
    target_vocab_size = 100

    input_sequence_ids = torch.randint(source_vocab_size, size=(bs, source_length)).to(torch.int32)
    # gold target sequence with the end token appended
    target_ids = torch.randint(target_vocab_size, size=(bs, target_length))
    target_ids = torch.cat((target_ids, end_id * torch.ones(bs, 1)), dim=1).to(torch.int32)
    # decoder input for teacher forcing: start token prepended, last token dropped
    shifted_target_ids = torch.cat((start_id * torch.ones(bs, 1), target_ids[:, :-1]), dim=1).to(torch.int32)

    model = Model(embedding_dim, hidden_size, num_classes, source_vocab_size, target_vocab_size, start_id, end_id)
    probs, logits = model(input_sequence_ids, shifted_target_ids)
    print(probs.shape)
    print(logits.shape)
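One thing worth noting about the toy sizes: num_classes (10) and target_vocab_size (100) differ here, so the logits cannot directly be scored against these particular target_ids; in a real setup the two would be equal. As a hedged illustration with self-contained toy tensors, training would apply cross-entropy over the flattened time dimension:

# sketch only: cross-entropy between per-step logits and gold token ids,
# assuming num_classes == target_vocab_size (both set to 10 for this toy example)
criterion = nn.CrossEntropyLoss()
toy_logits = torch.randn(2, 5, 10)              # (batch, target_length, num_classes)
toy_targets = torch.randint(10, size=(2, 5))    # gold ids in [0, num_classes)
loss = criterion(toy_logits.reshape(-1, 10), toy_targets.reshape(-1))
print(loss.item())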