PyTorch Implementation of Seq2Seq with Attention

Reference video: Seq2Seq(attention)的PyTorch实现 (bilibili)

Illustrated attention mechanism: https://wmathor.com/index.php/archives/1450/

https://wmathor.com/index.php/archives/1432/

Attention Mechanism

First, the figure below shows an encoder structure.

[Figure: encoder structure]

Here h1 through hm are called the outputs, and the output at the last time step, hm, is denoted s0; their values are equal. Next, s0 and every hi are passed through a function that can be understood as computing their similarity:

\[a_{i} = align(h_{i},s_{0}) \]

The following computation then yields c0, the weighted average of all the hi with weights ai:

\[ c_0 = \sum_{i=1}^{m} a_i \cdot h_i \]

If some ai is particularly large, then a large part of c0 comes from that hi, which means the model pays close attention to the corresponding word xi.

[Figure: attention weights over the encoder states]
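
As a minimal sketch of these two steps (the shapes and tensor names below are illustrative assumptions, and a simple dot product stands in for the similarity function), the weights ai and context c0 can be computed as:

import torch
import torch.nn.functional as F

m, hidden = 4, 8
h = torch.randn(m, hidden)         # encoder outputs h_1 .. h_m
s0 = h[-1]                         # s0 equals the last encoder output h_m

scores = h @ s0                    # similarity of s0 with every h_i, shape (m,)
a = F.softmax(scores, dim=0)       # attention weights a_i, which sum to 1
c0 = (a.unsqueeze(-1) * h).sum(0)  # context c0: weighted average of the h_i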

Next, c0, s0, and the Decoder input at the initial time step are combined, passed through a hidden layer A', and the result is s1. Then the new correlations ai between s1 and all the hi are computed, which in turn gives c1.

[Figure: computing s1 and c1 in the decoder]

The align function concatenates hi and s0, multiplies by a matrix W, applies the tanh activation, and then multiplies by a vector:

\[ a_i = v^{T} \tanh\left(W \begin{bmatrix} h_i \\ s_0 \end{bmatrix}\right) \]
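
A minimal sketch of this additive align function (W, v, and the dimensions are illustrative assumptions):

import torch

hidden = 8
W = torch.randn(hidden, 2 * hidden)  # learned matrix applied to [h_i; s0]
v = torch.randn(hidden)              # learned vector that turns the result into a scalar score

def align(h_i, s0):
    # concatenate h_i and s0, project with W, squash with tanh, then dot with v
    return v @ torch.tanh(W @ torch.cat((h_i, s0)))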

Another version:

09 什么是注意力机制(Attention ) - 二十三岁的有德 - 博客园 (cnblogs.com)

Imports

import torch
import torch.nn as nn
import torch.nn.functional as F

Seq2SeqEncoder

class Seq2SeqEncoder(nn.Module):
    def __init__(self, embedding_dim, hidden_size, source_vocab_size):
        super(Seq2SeqEncoder, self).__init__()
        
        self.lstm_layer = nn.LSTM(input_size=embedding_dim,
                                  hidden_size=hidden_size,
                                  batch_first=True)
        self.embedding_table = torch.nn.Embedding(source_vocab_size, embedding_dim)
        
    def forward(self, input_ids):
        # input_ids is a batch of id sequences, so the embedded result is a 3-D tensor
        input_sequence = self.embedding_table(input_ids)  # batch * source_length * embedding_dim
        output_states, (final_h, final_c) = self.lstm_layer(input_sequence)
        
        return output_states, final_h
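
For example, feeding a random batch through the encoder (the sizes below are placeholders):

encoder = Seq2SeqEncoder(embedding_dim=8, hidden_size=16, source_vocab_size=100)
ids = torch.randint(100, size=(2, 3))  # batch of 2 sequences of length 3
states, final_h = encoder(ids)
print(states.shape)                    # torch.Size([2, 3, 16])
print(final_h.shape)                   # torch.Size([1, 2, 16])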

Attention Mechanism

[Figure: attention in terms of Q, K, and V]

Here K can be understood as encoder_states in the code below; encoder_states holds all of the encoder's hidden states. Moreover K == V, i.e. V is also encoder_states, and Q can be understood as the decoder_state here.

class Seq2SeqAttentionMechanism(nn.Module):
    def __init__(self):
        super(Seq2SeqAttentionMechanism, self).__init__()
    
    # executes a single decoding step
    def forward(self, decoder_state_t, encoder_states):
        bs, source_length, hidden_size = encoder_states.shape
        
        # decoder_state_t is 2-D (batch * hidden), so expand it to match encoder_states
        decoder_state_t = decoder_state_t.unsqueeze(1)
        decoder_state_t = torch.tile(decoder_state_t, (1, source_length, 1))
        
        score = torch.sum(decoder_state_t * encoder_states, dim=-1)  # bs * source_length
        
        attn_prob = F.softmax(score, dim=-1)  # bs * source_length
        
        context = torch.sum(attn_prob.unsqueeze(-1) * encoder_states, 1)  # bs * hidden_size
        
        return attn_prob, context
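
A quick sanity check of the mechanism (the shapes are illustrative):

attn = Seq2SeqAttentionMechanism()
decoder_state = torch.randn(2, 16)      # bs * hidden_size
encoder_states = torch.randn(2, 3, 16)  # bs * source_length * hidden_size
prob, context = attn(decoder_state, encoder_states)
print(prob.sum(dim=-1))                 # each row sums to 1 after the softmax
print(context.shape)                    # torch.Size([2, 16])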

Seq2SeqDecoder

class Seq2SeqDecoder(nn.Module):
    def __init__(self, embedding_dim, hidden_size, num_classes, target_vocab_size, start_id, end_id):
        super(Seq2SeqDecoder, self).__init__()
        
        # LSTMCell runs a single time step
        self.lstm_cell = torch.nn.LSTMCell(embedding_dim, hidden_size)
        self.proj_layer = nn.Linear(hidden_size * 2, num_classes)
        self.attention_mechanism = Seq2SeqAttentionMechanism()
        self.num_classes = num_classes
        self.embedding_table = torch.nn.Embedding(target_vocab_size, embedding_dim)
        # start / end token ids
        self.start_id = start_id
        self.end_id = end_id
    
    # used during training (teacher forcing)
    def forward(self, shifted_target_ids, encoder_states):
        shifted_target = self.embedding_table(shifted_target_ids)
        
        bs, target_length, embedding_dim = shifted_target.shape
        bs, source_length, hidden_size = encoder_states.shape  # the second dim is source_length
        
        logits = torch.zeros(bs, target_length, self.num_classes)
        probs = torch.zeros(bs, target_length, source_length)
        
        for t in range(target_length):
            decoder_input_t = shifted_target[:, t, :]
            if t == 0:
                h_t, c_t = self.lstm_cell(decoder_input_t)
            else:
                h_t, c_t = self.lstm_cell(decoder_input_t, (h_t, c_t))
            
            attn_prob, context = self.attention_mechanism(h_t, encoder_states)
            
            decoder_output = torch.cat((context, h_t), -1)
            logits[:, t, :] = self.proj_layer(decoder_output)
            probs[:, t, :] = attn_prob
        
        return probs, logits
    
    def inference(self, encoder_states):
        # inference: greedy decoding, one step at a time
        bs = encoder_states.shape[0]
        target_id = torch.full((bs,), self.start_id, dtype=torch.long)  # start token for each sequence
        h_t = None
        result = []
        
        while True:
            decoder_input_t = self.embedding_table(target_id)
            if h_t is None:
                h_t, c_t = self.lstm_cell(decoder_input_t)
            else:
                h_t, c_t = self.lstm_cell(decoder_input_t, (h_t, c_t))
            
            attn_prob, context = self.attention_mechanism(h_t, encoder_states)
            
            decoder_output = torch.cat((context, h_t), -1)
            logits = self.proj_layer(decoder_output)
            
            # the prediction at this step becomes the input at the next step
            target_id = torch.argmax(logits, -1)
            result.append(target_id)
            
            if torch.any(target_id == self.end_id):
                print('stop decoding')
                break
        
        predicted_ids = torch.stack(result, dim=0)
        
        return predicted_ids
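
A usage sketch of greedy decoding on its own (the dimensions are placeholders; with random weights the loop simply stops whenever end_id happens to be predicted):

decoder = Seq2SeqDecoder(embedding_dim=8, hidden_size=16, num_classes=10,
                         target_vocab_size=100, start_id=0, end_id=0)
encoder_states = torch.randn(2, 3, 16)  # bs * source_length * hidden_size
predicted_ids = decoder.inference(encoder_states)
print(predicted_ids.shape)              # (number of decoded steps, bs)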

Model

class Model(nn.Module):
    def __init__(self, embedding_dim, hidden_size, num_classes,
                 source_vocab_size, target_vocab_size, start_id, end_id):
        super(Model, self).__init__()
        
        self.encoder = Seq2SeqEncoder(embedding_dim, hidden_size, source_vocab_size)
        
        self.decoder = Seq2SeqDecoder(embedding_dim, hidden_size, num_classes,
                                      target_vocab_size, start_id, end_id)
        
    def forward(self, input_sequence_ids, shifted_target_ids):
        
        encoder_states, final_h = self.encoder(input_sequence_ids)
        
        probs, logits = self.decoder(shifted_target_ids, encoder_states)
        
        return probs, logits
    
    def infer(self, input_sequence_ids):
        # encode the source, then decode greedily
        encoder_states, final_h = self.encoder(input_sequence_ids)
        return self.decoder.inference(encoder_states)
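
With infer filled in, end-to-end greedy decoding is a single call (a usage sketch; with untrained weights the output ids are meaningless):

model = Model(embedding_dim=8, hidden_size=16, num_classes=10,
              source_vocab_size=100, target_vocab_size=100, start_id=0, end_id=0)
src = torch.randint(100, size=(2, 3))
print(model.infer(src))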

Main

if __name__ == '__main__':
    source_length = 3
    target_length = 4
    embedding_dim = 8
    hidden_size = 16
    num_classes = 10
    bs = 2
    start_id = end_id = 0
    source_vocab_size = 100
    target_vocab_size = 100
    
    input_sequence_ids = torch.randint(source_vocab_size, size=(bs, source_length))
    
    # append the end token to every target sequence
    target_ids = torch.randint(target_vocab_size, size=(bs, target_length))
    target_ids = torch.cat((target_ids, torch.full((bs, 1), end_id)), dim=1)
    
    # decoder input: the targets shifted right by one step, starting with start_id
    shifted_target_ids = torch.cat((torch.full((bs, 1), start_id), target_ids[:, :-1]), dim=1)
    
    model = Model(embedding_dim, hidden_size, num_classes, source_vocab_size, target_vocab_size, start_id, end_id)
    probs, logits = model(input_sequence_ids, shifted_target_ids)
    print(probs.shape)
    print(logits.shape)

Running the script prints the shapes of probs and logits: torch.Size([2, 5, 3]) and torch.Size([2, 5, 10]).
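
These logits are what a training loop would feed into a cross-entropy loss. A minimal sketch of one training step (the optimizer and learning rate are illustrative assumptions, not from the post):

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# CrossEntropyLoss expects (N, num_classes) logits and (N,) class-index targets
loss = criterion(logits.reshape(-1, num_classes), target_ids.reshape(-1).long())
optimizer.zero_grad()
loss.backward()
optimizer.step()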
