transformer的Pytorch简易实现

Transformer(Pytorch) from scratch

code by Tae Hwan Jung(Jeff Jung) @graykode, Derek Miller @dmmiller612, modified by shwei; modified again by LittleHenry

Reference:
https://blog.csdn.net/BXD1314/article/details/126187598?spm=1001.2014.3001.5506
https://blog.csdn.net/BXD1314/article/details/125759352?spm=1001.2014.3001.5502
https://github.com/jadore801120/attention-is-all-you-need-pytorch
https://github.com/JayParks/transformer

import math
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

数据构建

# Training set

# Special tokens:
# S: start symbol, prepended to the decoder input
# E: end symbol, appended to the decoder output (the prediction target)
# P: padding symbol, fills a sentence that is shorter than the fixed sequence length

sentences = [
    # The source (Chinese) and target (English) sentences need not have the same number of tokens.
    # enc_input                 dec_input                  dec_output
    ['我 有 一 个 女 朋 友 。 P', 'S I have a girl friend .', 'I have a girl friend . E'],
    ['我 有 一 个 好 朋 友 。 P', 'S I have a good friend .', 'I have a good friend . E'],
    ['我 有 一 个 男 朋 友 。 P', 'S I have a boy friend .', 'I have a boy friend . E'],
    ['我 有 零 个 女 朋 友 。 P', 'S I have zero girl friend .', 'I have zero girl friend . E'],
    # Deliberately left out of training: this is the held-out test case described below.
    # ['我 有 零 个 好 朋 友 。 P', 'S I have zero good friend .', 'I have zero good friend . E'],
    ['我 有 零 个 男 朋 友 。 P', 'S I have zero boy friend .', 'I have zero boy friend . E']
]

# Test set (what we hope the Transformer can generalise to):
# input:           "我 有 零 个 好 朋 友 。"  (excluded from the training data above)
# expected output: "I have zero good friend ."
# Build the vocabularies.
# Source (Chinese) and target (English) words get separate vocabularies.
# The padding token 'P' must map to index 0 (padding masks test `== 0`, and the loss
# uses ignore_index=0).
src_vocab = {'P': 0, '我': 1, '有': 2, '一': 3,
             '个': 4, '好': 5, '朋': 6, '友': 7, '零': 8, '女': 9, '男': 10, '。':11}
# Invert from the stored indices, not from insertion order, so the reverse lookup stays
# correct even if the vocabulary is ever written out of index order.
src_idx2word = {idx: word for word, idx in src_vocab.items()}
src_vocab_size = len(src_vocab)

tgt_vocab = {'P': 0, 'I': 1, 'have': 2, 'a': 3, 'good': 4,
             'friend': 5, 'zero': 6, 'girl': 7,  'boy': 8, 'S': 9, 'E': 10, '.': 11}
tgt_idx2word = {idx: word for word, idx in tgt_vocab.items()}
tgt_vocab_size = len(tgt_vocab)

src_len = 9     # max encoder-input (source sentence) length, padding included
tgt_len = 7     # max decoder-input (= decoder-output) length
# Hyperparameters
# Fall back to the CPU when no GPU is present instead of failing on the first .to(device).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
epochs = 100

d_model = 512   # embedding size (dimension of token embeddings and positional encodings)
d_ff = 2048     # hidden size of the position-wise feed-forward net (512 -> 2048 -> 512)
d_k = d_v = 64  # per-head dimension of K (= Q, required by the dot product) and of V
n_layers = 6    # number of stacked encoder / decoder layers
n_heads = 8     # number of attention heads
def make_data(sentences):
    """Turn whitespace-tokenised sentence triples into index tensors.

    Each item of ``sentences`` is ``[enc_input, dec_input, dec_output]``. The first string
    is looked up token by token in ``src_vocab``, the other two in ``tgt_vocab``.

    Returns:
        ``(enc_inputs, dec_inputs, dec_outputs)`` as ``torch.LongTensor``s with one row per
        sentence, so every row of a given column must have the same number of tokens.
    """
    def encode(text, vocab):
        return [vocab[token] for token in text.split()]

    # One tuple per sentence, encoded in the same enc -> dec_in -> dec_out order.
    encoded = [
        (encode(item[0], src_vocab), encode(item[1], tgt_vocab), encode(item[2], tgt_vocab))
        for item in sentences
    ]
    enc_rows = [row[0] for row in encoded]
    dec_in_rows = [row[1] for row in encoded]
    dec_out_rows = [row[2] for row in encoded]

    return torch.LongTensor(enc_rows), torch.LongTensor(dec_in_rows), torch.LongTensor(dec_out_rows)

# Encode the training sentences, then show the encoded source batch as a sanity check.
enc_inputs, dec_inputs, dec_outputs = make_data(sentences)

for caption, shown in (("enc_inputs.shape:\n", enc_inputs.shape), ("enc_inputs:\n", enc_inputs)):
    print(caption, shown)
enc_inputs.shape:
 torch.Size([5, 9])
enc_inputs:
 tensor([[ 1,  2,  3,  4,  9,  6,  7, 11,  0],
        [ 1,  2,  3,  4,  5,  6,  7, 11,  0],
        [ 1,  2,  3,  4, 10,  6,  7, 11,  0],
        [ 1,  2,  8,  4,  9,  6,  7, 11,  0],
        [ 1,  2,  8,  4, 10,  6,  7, 11,  0]])
class MyDataSet(Data.Dataset):
    """Map-style dataset over three aligned tensors: encoder input, decoder input, decoder target.

    Item ``idx`` is the tuple of the ``idx``-th row of each tensor. The length is the size
    of the first dimension of ``enc_inputs``.
    """

    def __init__(self, enc_inputs, dec_inputs, dec_outputs):
        super().__init__()
        self.enc_inputs, self.dec_inputs, self.dec_outputs = enc_inputs, dec_inputs, dec_outputs

    def __len__(self):
        return self.enc_inputs.size(0)

    def __getitem__(self, idx):
        columns = (self.enc_inputs, self.dec_inputs, self.dec_outputs)
        return tuple(column[idx] for column in columns)


# Mini-batches of two sentences, reshuffled every epoch.
loader = Data.DataLoader(
    dataset=MyDataSet(enc_inputs, dec_inputs, dec_outputs),
    batch_size=2,
    shuffle=True,
)

# Peek at one batch; its elements are (enc_inputs, dec_inputs, dec_outputs) for that batch.
test_loader = next(iter(loader))
print(test_loader[0])
tensor([[ 1,  2,  8,  4,  9,  6,  7, 11,  0],
        [ 1,  2,  8,  4, 10,  6,  7, 11,  0]])

位置编码

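image-2.png

下方代码实现的位置编码公式为:$PE_{(pos,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$,$PE_{(pos,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$。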

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first input, then apply dropout.

        PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

    The table is precomputed for ``max_len`` positions and registered as a buffer. It follows
    ``.to(device)`` with the module but is not a trainable parameter.

    Args:
        d_model: feature dimension of the input. Odd values are supported; the last column
            then only gets a sine term.
        dropout: dropout probability applied after adding the encoding.
        max_len: largest sequence length that can be encoded.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        # 1 / 10000^(2i / d_model) for every even column index 2i, computed via exp/log.
        div_term = torch.exp(torch.arange(
            0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column, so trim div_term.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, d_model] -> [max_len, 1, d_model]
        self.register_buffer('pe', pe)  # fixed, not updated by the optimizer

    def forward(self, x):
        """
        x: [seq_len, batch_size, d_model], with seq_len <= max_len.
        Returns a tensor of the same shape: x plus the encoding of each position, after dropout.
        """
        x = x + self.pe[:x.size(0), ...]
        return self.dropout(x)
# debug: check the shapes involved in positional encoding
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
# One frequency per even feature column.
assert div_term.shape == torch.arange(0, d_model, 2).shape
p = PositionalEncoding(d_model)
em = nn.Embedding(src_vocab_size, d_model)
a = em(test_loader[0])  # [batch_size, src_len, d_model]
assert a.shape[-1] == d_model
# PositionalEncoding expects sequence-first input, hence the transposes around it.
enc_outputs = p(a.transpose(0, 1)).transpose(
            0, 1)  # [batch_size, src_len, d_model]
# forward() slices the table by the sequence length, which is dim 1 of `a`, not its batch dim 0.
print(enc_outputs.shape, p.pe[:a.size(1), ...].shape)
torch.Size([2, 9, 512]) torch.Size([2, 1, 512])
def get_attn_pad_mask(seq_q, seq_k):
    """Build a boolean mask that hides the padding tokens (index 0) of ``seq_k``.

    ``seq_q`` and ``seq_k`` are token-index sequences (e.g. encoder and decoder inputs),
    not the attention Q/K projections. ``seq_q`` only supplies the number of query rows.
    Masking the padded key positions keeps pad vectors out of the weighted sum of values.
    Both the encoder and the decoder call this, so the lengths depend on the caller.

    Args:
        seq_q: [batch_size, len_q] token indices.
        seq_k: [batch_size, len_k] token indices; 0 marks padding.

    Returns:
        [batch_size, len_q, len_k] bool tensor (an expanded view), True where the key
        position is padding. All query rows of one batch item are identical.
    """
    _, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    key_is_pad = (seq_k == 0).unsqueeze(1)  # [batch_size, 1, len_k]
    return key_is_pad.expand(batch_size, len_q, len_k)

# Toy check of the padding mask (0 = padding): key columns holding 0 come out True.
test_seq_q = torch.tensor([[1, 2, 3, 4, 0], [1, 2, 0, 0, 0]])
test_seq_k = torch.tensor([[1, 2, 3, 0, 0], [1, 2, 4, 5, 0]])
print(
    "test get_attn_pad_mask():\n",
    get_attn_pad_mask(test_seq_q, test_seq_k),
)
test get_attn_pad_mask():
 tensor([[[False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True]],

        [[False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True]]])
# Look-ahead (strictly upper-triangular) mask
def get_attn_subsequence_mask(seq):
    """Build the decoder look-ahead mask that hides future positions.

    Args:
        seq: [batch_size, tgt_len] tensor; only its shape and device are used.

    Returns:
        Contiguous bool tensor [batch_size, tgt_len, tgt_len] on ``seq.device``.
        Entry [b, i, j] is True when j > i, i.e. query position i must not see key position j.
    """
    batch_size, tgt_len = seq.size(0), seq.size(1)
    # Built directly in torch on the input's device: no numpy float64 round trip and no
    # host-to-device copy each call.
    positions = torch.arange(tgt_len, device=seq.device)
    future = positions.unsqueeze(0) > positions.unsqueeze(1)  # [tgt_len, tgt_len]
    # repeat (not expand) so the result is a real, writable tensor like before.
    return future.unsqueeze(0).repeat(batch_size, 1, 1)  # [batch_size, tgt_len, tgt_len]

# Toy check of the look-ahead mask: row i is True for every column j > i.
test_seq = torch.tensor([[1, 2, 3, 4, 0], [1, 2, 0, 0, 0]])
print(
    "test get_attn_subsequence_mask():\n",
    get_attn_subsequence_mask(test_seq),
)
test get_attn_subsequence_mask():
 tensor([[[False,  True,  True,  True,  True],
         [False, False,  True,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False, False,  True],
         [False, False, False, False, False]],

        [[False,  True,  True,  True,  True],
         [False, False,  True,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False, False,  True],
         [False, False, False, False, False]]])

注意力机制

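image.png

下方代码实现的缩放点积注意力为:$\mathrm{Attention}(Q,K,V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$。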

# Scaled dot-product attention
class ScaledDotProductAttention(nn.Module):
    """softmax(Q K^T / sqrt(d_k)) V with optional boolean masking. Has no parameters."""

    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask=None):
        """
        Q: [batch_size, n_heads, len_q, d_k]
        K: [batch_size, n_heads, len_k, d_k]
        V: [batch_size, n_heads, len_v(=len_k), d_v]
        attn_mask: bool tensor broadcastable to [batch_size, n_heads, len_q, len_k]; True marks
            positions that must not be attended to. None disables masking.
        len_q (q1..qt) and len_k (k1..km) may differ, e.g. in encoder-decoder attention.

        Returns:
            context: [batch_size, n_heads, len_q, d_v]
            attn: [batch_size, n_heads, len_q, len_k] attention weights (kept for visualisation)
        """
        # Scale by the actual key dimension of the inputs instead of a global constant.
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(K.size(-1))
        if attn_mask is not None:
            # A huge negative score drives the softmax weight of masked positions to ~0.
            scores = scores.masked_fill(attn_mask, -1e9)
        attn = torch.softmax(scores, dim=-1)  # normalise over the key axis
        context = torch.matmul(attn, V)  # [batch_size, n_heads, len_q, d_v]
        return context, attn
# Multi-head attention
class MultiHeadAttention(nn.Module):
    """Multi-head attention followed by a residual connection and LayerNorm (post-norm).

    Serves all three attention sub-layers of the Transformer:
    encoder self-attention, decoder masked self-attention, and encoder-decoder attention.
    Input / output: [batch_size, len_q, d_model].
    """

    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        # Q and K must share a dimension, otherwise the dot product is undefined.
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        # Projects the concatenated heads back to d_model, so the output shape matches the input.
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
        self.attention = ScaledDotProductAttention()  # parameter-free
        # Registered once so its affine parameters are actually trained and move with
        # model.to(device). Building a new LayerNorm inside forward() reset them on every call.
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, input_Q, input_K, input_V, attn_mask):
        """
        input_Q: [batch_size, len_q, d_model]
        input_K: [batch_size, len_k, d_model]
        input_V: [batch_size, len_v(=len_k), d_model]
        attn_mask: [batch_size, len_q, len_k] bool, True = masked

        Returns (output [batch_size, len_q, d_model], attn [batch_size, n_heads, len_q, len_k]).
        """
        residual, batch_size = input_Q, input_Q.size(0)
        # All heads are projected with one big linear layer, then split (an engineering trick):
        # (B, S, D) -proj-> (B, S, H*W) -split-> (B, S, H, W) -transpose-> (B, H, S, W)

        # Q: [batch_size, n_heads, len_q, d_k]
        Q = self.W_Q(input_Q).view(batch_size, -1,
                                   n_heads, d_k).transpose(1, 2)
        # K: [batch_size, n_heads, len_k, d_k]. K and V share a length; their widths may differ.
        K = self.W_K(input_K).view(batch_size, -1,
                                   n_heads, d_k).transpose(1, 2)
        # V: [batch_size, n_heads, len_v(=len_k), d_v]
        V = self.W_V(input_V).view(batch_size, -1,
                                   n_heads, d_v).transpose(1, 2)

        # Same mask for every head: [B, len_q, len_k] -> [B, n_heads, len_q, len_k]
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)

        # context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]
        context, attn = self.attention(Q, K, V, attn_mask)
        # Concatenate the heads: [batch_size, n_heads, len_q, d_v] -> [batch_size, len_q, n_heads * d_v]
        context = context.transpose(1, 2).reshape(
            batch_size, -1, n_heads * d_v)

        output = self.fc(context)  # [batch_size, len_q, d_model]
        return self.layer_norm(output + residual), attn

FFN

# nn.Linear only acts on the last dimension, so the same fully connected network is applied
# independently at every sequence position.
class PoswiseFeedForwardNet(nn.Module):
    """FFN sub-layer: Linear(d_model->d_ff) -> ReLU -> Linear(d_ff->d_model), then residual + LayerNorm."""

    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False)
        )
        # Registered once so its affine parameters are learned and follow model.to(device).
        # A LayerNorm created inside forward() was re-initialised on every call.
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inputs):
        """
        inputs: [batch_size, seq_len, d_model]
        returns: [batch_size, seq_len, d_model]
        """
        residual = inputs
        output = self.fc(inputs)
        return self.layer_norm(output + residual)

EncoderLayer&DecoderLayer

image.png

class EncoderLayer(nn.Module):
    """One encoder block: a self-attention sub-layer followed by a position-wise FFN sub-layer."""

    def __init__(self):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        """
        enc_inputs: [batch_size, src_len, d_model]
        enc_self_attn_mask: [batch_size, src_len, src_len] mask (True = masked)

        Returns (enc_outputs [batch_size, src_len, d_model],
                 attn [batch_size, n_heads, src_len, src_len]).
        """
        # Self-attention: Q, K and V are all projected from the same input.
        attended, attn = self.enc_self_attn(
            enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask
        )
        return self.pos_ffn(attended), attn
    
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder-decoder attention, then a position-wise FFN."""

    def __init__(self):
        super().__init__()
        self.dec_self_attn = MultiHeadAttention()
        self.dec_enc_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        """
        dec_inputs: [batch_size, tgt_len, d_model]
        enc_outputs: [batch_size, src_len, d_model]
        dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        dec_enc_attn_mask: [batch_size, tgt_len, src_len]

        Returns (dec_outputs [batch_size, tgt_len, d_model],
                 dec_self_attn [batch_size, n_heads, tgt_len, tgt_len],
                 dec_enc_attn [batch_size, n_heads, tgt_len, src_len]).
        """
        # 1) Self-attention over the decoder's own inputs (Q, K and V all come from dec_inputs).
        self_attended, dec_self_attn = self.dec_self_attn(
            dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        # 2) Cross-attention: queries from the decoder, keys/values from the encoder output.
        cross_attended, dec_enc_attn = self.dec_enc_attn(
            self_attended, enc_outputs, enc_outputs, dec_enc_attn_mask)
        # 3) Position-wise feed-forward network.
        dec_outputs = self.pos_ffn(cross_attended)
        # The two attention maps are only returned for visualisation.
        return dec_outputs, dec_self_attn, dec_enc_attn

Encoder&Decoder

image.png

class Encoder(nn.Module):
    """Token embedding + sinusoidal positional encoding + a stack of n_layers EncoderLayers."""

    def __init__(self):
        super().__init__()
        self.src_emb = nn.Embedding(src_vocab_size, d_model)  # token embedding
        self.pos_emb = PositionalEncoding(d_model)  # fixed encoding, nothing to learn
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])

    def forward(self, enc_inputs):
        """
        enc_inputs: [batch_size, src_len] token indices (0 = padding)

        Returns (enc_outputs [batch_size, src_len, d_model],
                 list with one [batch_size, n_heads, src_len, src_len] attention map per layer).
        """
        embedded = self.src_emb(enc_inputs)  # [batch_size, src_len, d_model]
        # PositionalEncoding works sequence-first, so transpose in and back out.
        enc_outputs = self.pos_emb(embedded.transpose(0, 1)).transpose(0, 1)
        # Padding keys of the source are hidden from every query.
        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)  # [batch_size, src_len, src_len]

        enc_self_attns = []  # not needed for the computation; kept for attention heat maps
        for layer in self.layers:
            # Each block consumes the previous block's output.
            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
            enc_self_attns.append(enc_self_attn)
        return enc_outputs, enc_self_attns

class Decoder(nn.Module):
    """Target embedding + positional encoding + a stack of n_layers DecoderLayers."""

    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(
            tgt_vocab_size, d_model)  # embedding table for the decoder input
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer()
                                    for _ in range(n_layers)])  # the decoder blocks

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        """
        dec_inputs: [batch_size, tgt_len] target token indices
        enc_inputs: [batch_size, src_len] source token indices (used for the source padding mask)
        enc_outputs: [batch_size, src_len, d_model]   # keys/values of the encoder-decoder attention

        Returns (dec_outputs [batch_size, tgt_len, d_model],
                 per-layer self-attention maps, per-layer encoder-decoder attention maps).
        """
        dec_outputs = self.tgt_emb(dec_inputs)  # [batch_size, tgt_len, d_model]
        # PositionalEncoding is sequence-first. The embedding already lives on the model's
        # device, so no extra .to(device) is needed.
        dec_outputs = self.pos_emb(dec_outputs.transpose(0, 1)).transpose(0, 1)

        # Padding mask of the decoder input (no padding in this toy data, but usual in practice).
        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs)  # [batch_size, tgt_len, tgt_len]
        # Masked self-attention: position t must not see future positions.
        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).to(
            dec_inputs.device)  # [batch_size, tgt_len, tgt_len]
        # A key is hidden if it is padding OR lies in the future (logical OR of the two masks).
        dec_self_attn_mask = dec_self_attn_pad_mask | dec_self_attn_subsequence_mask

        # Encoder-decoder attention: hide the padding of the source (keys/values come from the
        # encoder). dec_inputs only provides the number of query rows.
        dec_enc_attn_mask = get_attn_pad_mask(
            dec_inputs, enc_inputs)  # [batch_size, tgt_len, src_len]

        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            # Each block takes the previous block's output (changing) and the encoder output (fixed).
            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask,
                                                             dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        # dec_outputs: [batch_size, tgt_len, d_model]
        return dec_outputs, dec_self_attns, dec_enc_attns

Transformer

image.png


class Transformer(nn.Module):
    """Encoder-decoder Transformer with a final linear projection onto the target vocabulary."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder().to(device)
        self.decoder = Decoder().to(device)
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False).to(device)

    def forward(self, enc_inputs, dec_inputs):
        """
        enc_inputs: [batch_size, src_len]
        dec_inputs: [batch_size, tgt_len]

        Returns:
            logits flattened to [batch_size * tgt_len, tgt_vocab_size] (the layout
            CrossEntropyLoss expects), followed by the lists of encoder self-attention,
            decoder self-attention and encoder-decoder attention maps (one per layer).
        """
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)  # [batch_size, src_len, d_model]
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(
            dec_inputs, enc_inputs, enc_outputs)  # [batch_size, tgt_len, d_model]
        logits = self.projection(dec_outputs)  # [batch_size, tgt_len, tgt_vocab_size]
        flat_logits = logits.view(-1, logits.size(-1))
        return flat_logits, enc_self_attns, dec_self_attns, dec_enc_attns


训练

model = Transformer().to(device)
# ignore_index=0: target positions holding the padding token 'P' (index 0) contribute no loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
# SGD with high momentum; Adam reportedly gave worse results on this toy task.
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)

for epoch in range(epochs):
    # Each batch: enc_inputs [batch_size, src_len], dec_inputs / dec_outputs [batch_size, tgt_len]
    for batch in loader:
        enc_inputs, dec_inputs, dec_outputs = (t.to(device) for t in batch)
        # outputs: [batch_size * tgt_len, tgt_vocab_size]
        outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
        # Flatten the targets to [batch_size * tgt_len] to line up with the flattened logits.
        loss = criterion(outputs, dec_outputs.view(-1))
        print('Epoch:', f'{epoch + 1:04d}', 'loss =', f'{loss.item():.6f}')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Epoch: 0001 loss = 2.444907
Epoch: 0001 loss = 2.334698
Epoch: 0001 loss = 2.093289
Epoch: 0002 loss = 1.812472
Epoch: 0002 loss = 1.675709
Epoch: 0002 loss = 1.458901
Epoch: 0003 loss = 1.401762
Epoch: 0003 loss = 1.369465
Epoch: 0003 loss = 0.982791
Epoch: 0004 loss = 0.966825
Epoch: 0004 loss = 0.952247
Epoch: 0004 loss = 0.873539
Epoch: 0005 loss = 0.627767
Epoch: 0005 loss = 0.795261
Epoch: 0005 loss = 0.421385
Epoch: 0006 loss = 0.331635
Epoch: 0006 loss = 0.696892
Epoch: 0006 loss = 0.285732
Epoch: 0007 loss = 0.275137
Epoch: 0007 loss = 0.602334
Epoch: 0007 loss = 0.273436
Epoch: 0008 loss = 0.179638
Epoch: 0008 loss = 0.506163
Epoch: 0008 loss = 0.233821
Epoch: 0009 loss = 0.226149
Epoch: 0009 loss = 0.239603
Epoch: 0009 loss = 0.203738
Epoch: 0010 loss = 0.187203
Epoch: 0010 loss = 0.435071
Epoch: 0010 loss = 0.069702
Epoch: 0011 loss = 0.314074
Epoch: 0011 loss = 0.239697
Epoch: 0011 loss = 0.401073
Epoch: 0012 loss = 0.393589
Epoch: 0012 loss = 0.223106
Epoch: 0012 loss = 0.223769
Epoch: 0013 loss = 0.282822
Epoch: 0013 loss = 0.257790
Epoch: 0013 loss = 0.048830
Epoch: 0014 loss = 0.312652
Epoch: 0014 loss = 0.112380
Epoch: 0014 loss = 0.424507
Epoch: 0015 loss = 0.139422
Epoch: 0015 loss = 0.274820
Epoch: 0015 loss = 0.298400
Epoch: 0016 loss = 0.217399
Epoch: 0016 loss = 0.112750
Epoch: 0016 loss = 0.067188
Epoch: 0017 loss = 0.065543
Epoch: 0017 loss = 0.417677
Epoch: 0017 loss = 0.256891
Epoch: 0018 loss = 0.110178
Epoch: 0018 loss = 0.307671
Epoch: 0018 loss = 0.564578
Epoch: 0019 loss = 0.162470
Epoch: 0019 loss = 0.190827
Epoch: 0019 loss = 0.188836
Epoch: 0020 loss = 0.101207
Epoch: 0020 loss = 0.172687
Epoch: 0020 loss = 0.157646
Epoch: 0021 loss = 0.081377
Epoch: 0021 loss = 0.317234
Epoch: 0021 loss = 0.190117
Epoch: 0022 loss = 0.417913
Epoch: 0022 loss = 0.060791
Epoch: 0022 loss = 0.108195
Epoch: 0023 loss = 0.073002
Epoch: 0023 loss = 0.402348
Epoch: 0023 loss = 0.009389
Epoch: 0024 loss = 0.103685
Epoch: 0024 loss = 0.238067
Epoch: 0024 loss = 0.008065
Epoch: 0025 loss = 0.176196
Epoch: 0025 loss = 0.026496
Epoch: 0025 loss = 0.075038
Epoch: 0026 loss = 0.062269
Epoch: 0026 loss = 0.040344
Epoch: 0026 loss = 0.243392
Epoch: 0027 loss = 0.089456
Epoch: 0027 loss = 0.014893
Epoch: 0027 loss = 0.005858
Epoch: 0028 loss = 0.076116
Epoch: 0028 loss = 0.002074
Epoch: 0028 loss = 0.291679
Epoch: 0029 loss = 0.064219
Epoch: 0029 loss = 0.131802
Epoch: 0029 loss = 0.000633
Epoch: 0030 loss = 0.000856
Epoch: 0030 loss = 0.225335
Epoch: 0030 loss = 0.027927
Epoch: 0031 loss = 0.281215
Epoch: 0031 loss = 0.007803
Epoch: 0031 loss = 0.021516
Epoch: 0032 loss = 0.190256
Epoch: 0032 loss = 0.002374
Epoch: 0032 loss = 0.001287
Epoch: 0033 loss = 0.001356
Epoch: 0033 loss = 0.084635
Epoch: 0033 loss = 0.001708
Epoch: 0034 loss = 0.001284
Epoch: 0034 loss = 0.049037
Epoch: 0034 loss = 0.001768
Epoch: 0035 loss = 0.001370
Epoch: 0035 loss = 0.092342
Epoch: 0035 loss = 0.005393
Epoch: 0036 loss = 0.022871
Epoch: 0036 loss = 0.038931
Epoch: 0036 loss = 0.003027
Epoch: 0037 loss = 0.026564
Epoch: 0037 loss = 0.005689
Epoch: 0037 loss = 0.005196
Epoch: 0038 loss = 0.059345
Epoch: 0038 loss = 0.004768
Epoch: 0038 loss = 0.036875
Epoch: 0039 loss = 0.108965
Epoch: 0039 loss = 0.040516
Epoch: 0039 loss = 0.000462
Epoch: 0040 loss = 0.037638
Epoch: 0040 loss = 0.007827
Epoch: 0040 loss = 0.004610
Epoch: 0041 loss = 0.023541
Epoch: 0041 loss = 0.003524
Epoch: 0041 loss = 0.187471
Epoch: 0042 loss = 0.087047
Epoch: 0042 loss = 0.056470
Epoch: 0042 loss = 0.000788
Epoch: 0043 loss = 0.099710
Epoch: 0043 loss = 0.028800
Epoch: 0043 loss = 0.005353
Epoch: 0044 loss = 0.017328
Epoch: 0044 loss = 0.043522
Epoch: 0044 loss = 0.006211
Epoch: 0045 loss = 0.041052
Epoch: 0045 loss = 0.001367
Epoch: 0045 loss = 0.001629
Epoch: 0046 loss = 0.004877
Epoch: 0046 loss = 0.024950
Epoch: 0046 loss = 0.000697
Epoch: 0047 loss = 0.106069
Epoch: 0047 loss = 0.004080
Epoch: 0047 loss = 0.001662
Epoch: 0048 loss = 0.001224
Epoch: 0048 loss = 0.003590
Epoch: 0048 loss = 0.095277
Epoch: 0049 loss = 0.003333
Epoch: 0049 loss = 0.019699
Epoch: 0049 loss = 0.001417
Epoch: 0050 loss = 0.001680
Epoch: 0050 loss = 0.002900
Epoch: 0050 loss = 0.003057
Epoch: 0051 loss = 0.003243
Epoch: 0051 loss = 0.003001
Epoch: 0051 loss = 0.001339
Epoch: 0052 loss = 0.002604
Epoch: 0052 loss = 0.006472
Epoch: 0052 loss = 0.005859
Epoch: 0053 loss = 0.003660
Epoch: 0053 loss = 0.016574
Epoch: 0053 loss = 0.013217
Epoch: 0054 loss = 0.006983
Epoch: 0054 loss = 0.007660
Epoch: 0054 loss = 0.008758
Epoch: 0055 loss = 0.018442
Epoch: 0055 loss = 0.004823
Epoch: 0055 loss = 0.005564
Epoch: 0056 loss = 0.005049
Epoch: 0056 loss = 0.010866
Epoch: 0056 loss = 0.003027
Epoch: 0057 loss = 0.033244
Epoch: 0057 loss = 0.009304
Epoch: 0057 loss = 0.005155
Epoch: 0058 loss = 0.028059
Epoch: 0058 loss = 0.003198
Epoch: 0058 loss = 0.017911
Epoch: 0059 loss = 0.004290
Epoch: 0059 loss = 0.006959
Epoch: 0059 loss = 0.000333
Epoch: 0060 loss = 0.004318
Epoch: 0060 loss = 0.002820
Epoch: 0060 loss = 0.025311
Epoch: 0061 loss = 0.024404
Epoch: 0061 loss = 0.005432
Epoch: 0061 loss = 0.000047
Epoch: 0062 loss = 0.054393
Epoch: 0062 loss = 0.000886
Epoch: 0062 loss = 0.000078
Epoch: 0063 loss = 0.036457
Epoch: 0063 loss = 0.010535
Epoch: 0063 loss = 0.014770
Epoch: 0064 loss = 0.000291
Epoch: 0064 loss = 0.014036
Epoch: 0064 loss = 0.027240
Epoch: 0065 loss = 0.001039
Epoch: 0065 loss = 0.013150
Epoch: 0065 loss = 0.026792
Epoch: 0066 loss = 0.019471
Epoch: 0066 loss = 0.001618
Epoch: 0066 loss = 0.016640
Epoch: 0067 loss = 0.009114
Epoch: 0067 loss = 0.002208
Epoch: 0067 loss = 0.005768
Epoch: 0068 loss = 0.001546
Epoch: 0068 loss = 0.002858
Epoch: 0068 loss = 0.002059
Epoch: 0069 loss = 0.001180
Epoch: 0069 loss = 0.003597
Epoch: 0069 loss = 0.001288
Epoch: 0070 loss = 0.002646
Epoch: 0070 loss = 0.002658
Epoch: 0070 loss = 0.004571
Epoch: 0071 loss = 0.002683
Epoch: 0071 loss = 0.002644
Epoch: 0071 loss = 0.005884
Epoch: 0072 loss = 0.000819
Epoch: 0072 loss = 0.002505
Epoch: 0072 loss = 0.006722
Epoch: 0073 loss = 0.005530
Epoch: 0073 loss = 0.001527
Epoch: 0073 loss = 0.000118
Epoch: 0074 loss = 0.004990
Epoch: 0074 loss = 0.002871
Epoch: 0074 loss = 0.001252
Epoch: 0075 loss = 0.000809
Epoch: 0075 loss = 0.006991
Epoch: 0075 loss = 0.002197
Epoch: 0076 loss = 0.000221
Epoch: 0076 loss = 0.002987
Epoch: 0076 loss = 0.002131
Epoch: 0077 loss = 0.000891
Epoch: 0077 loss = 0.003040
Epoch: 0077 loss = 0.001191
Epoch: 0078 loss = 0.000241
Epoch: 0078 loss = 0.002150
Epoch: 0078 loss = 0.006512
Epoch: 0079 loss = 0.002099
Epoch: 0079 loss = 0.000765
Epoch: 0079 loss = 0.000012
Epoch: 0080 loss = 0.000059
Epoch: 0080 loss = 0.003267
Epoch: 0080 loss = 0.001664
Epoch: 0081 loss = 0.002103
Epoch: 0081 loss = 0.003268
Epoch: 0081 loss = 0.000175
Epoch: 0082 loss = 0.001716
Epoch: 0082 loss = 0.000571
Epoch: 0082 loss = 0.000016
Epoch: 0083 loss = 0.000502
Epoch: 0083 loss = 0.000572
Epoch: 0083 loss = 0.002209
Epoch: 0084 loss = 0.001335
Epoch: 0084 loss = 0.001363
Epoch: 0084 loss = 0.002480
Epoch: 0085 loss = 0.001117
Epoch: 0085 loss = 0.000618
Epoch: 0085 loss = 0.000094
Epoch: 0086 loss = 0.000166
Epoch: 0086 loss = 0.000221
Epoch: 0086 loss = 0.000893
Epoch: 0087 loss = 0.000046
Epoch: 0087 loss = 0.000139
Epoch: 0087 loss = 0.002012
Epoch: 0088 loss = 0.001610
Epoch: 0088 loss = 0.000137
Epoch: 0088 loss = 0.000092
Epoch: 0089 loss = 0.000516
Epoch: 0089 loss = 0.000052
Epoch: 0089 loss = 0.000045
Epoch: 0090 loss = 0.000147
Epoch: 0090 loss = 0.000875
Epoch: 0090 loss = 0.000081
Epoch: 0091 loss = 0.000062
Epoch: 0091 loss = 0.001444
Epoch: 0091 loss = 0.000064
Epoch: 0092 loss = 0.000133
Epoch: 0092 loss = 0.003250
Epoch: 0092 loss = 0.000123
Epoch: 0093 loss = 0.002410
Epoch: 0093 loss = 0.000098
Epoch: 0093 loss = 0.000065
Epoch: 0094 loss = 0.002879
Epoch: 0094 loss = 0.000046
Epoch: 0094 loss = 0.000010
Epoch: 0095 loss = 0.001614
Epoch: 0095 loss = 0.000076
Epoch: 0095 loss = 0.000250
Epoch: 0096 loss = 0.000199
Epoch: 0096 loss = 0.003597
Epoch: 0096 loss = 0.000101
Epoch: 0097 loss = 0.000119
Epoch: 0097 loss = 0.004447
Epoch: 0097 loss = 0.000014
Epoch: 0098 loss = 0.001480
Epoch: 0098 loss = 0.000134
Epoch: 0098 loss = 0.000026
Epoch: 0099 loss = 0.000760
Epoch: 0099 loss = 0.000048
Epoch: 0099 loss = 0.000210
Epoch: 0100 loss = 0.000086
Epoch: 0100 loss = 0.000100
Epoch: 0100 loss = 0.000513

预测


def greedy_decoder(model, enc_input, start_symbol, max_len=100):
    """Greedy decoding, i.e. beam search with beam width 1.

    The target sequence is unknown at inference time, so the decoder input is grown one
    token at a time. Starting from ``start_symbol``, the model predicts the next token from
    the output at the last decoder position and appends it. This repeats until it predicts
    the end symbol ``'E'`` or ``max_len`` tokens have been generated.
    Reference: http://nlp.seas.harvard.edu/2018/04/03/attention.html#greedy-decoding

    For the duration of the call the model runs in eval mode (dropout off) without
    gradient tracking. Its previous train/eval mode is restored afterwards.

    :param model: Transformer model (its ``encoder``, ``decoder`` and ``projection`` are used)
    :param enc_input: encoder input of shape [1, src_len], on the model's device
    :param start_symbol: index of the start symbol; here ``tgt_vocab['S']``, which is 9
    :param max_len: maximum number of generated tokens, so decoding always terminates even
        if the model never predicts ``'E'``
    :return: predicted token indices, shape [1, num_tokens], without the start and end symbols
    """
    end_symbol = tgt_vocab["E"]
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            enc_outputs, _ = model.encoder(enc_input)
            # Decoder input so far, [1, t]; starts with only the start symbol.
            dec_input = torch.tensor([[start_symbol]], dtype=enc_input.dtype, device=enc_input.device)
            for _ in range(max_len):
                dec_outputs, _, _ = model.decoder(dec_input, enc_input, enc_outputs)
                projected = model.projection(dec_outputs)  # [1, t, tgt_vocab_size]
                # Only the output at the newest position z_t predicts the next word.
                next_symbol = projected[0, -1].argmax(dim=-1).item()
                if next_symbol == end_symbol:
                    break
                next_token = torch.tensor([[next_symbol]], dtype=dec_input.dtype, device=dec_input.device)
                dec_input = torch.cat([dec_input, next_token], dim=-1)
    finally:
        model.train(was_training)
    return dec_input[:, 1:]
# Prediction
# Source sentences to translate. Only the encoder input is used, so the target columns stay empty.
# Padding 'P' goes at the END, exactly as in the training data. The last sentence is the
# held-out case that was never seen during training.
sentences = [
    # enc_input                dec_input           dec_output
    ['我 有 一 个 男 朋 友 。 P', '', ''],
    ['我 有 零 个 男 朋 友 。 P', '', ''],
    ['我 有 零 个 女 朋 友 。 P', '', ''],
    ['我 有 零 个 好 朋 友 。 P', '', ''],
]

enc_inputs, dec_inputs, dec_outputs = make_data(sentences)
# One batch holding every sentence, in the order listed above (no shuffling at inference).
test_loader = Data.DataLoader(
    MyDataSet(enc_inputs, dec_inputs, dec_outputs), batch_size=len(sentences), shuffle=False)
enc_inputs, _, _ = next(iter(test_loader))

print()
print("="*30)
print("利用训练好的Transformer模型将中文句子翻译成英文句子: ")
for enc_input in enc_inputs:
    greedy_dec_predict = greedy_decoder(
        model, enc_input.view(1, -1).to(device), start_symbol=tgt_vocab["S"])
    # Take row 0 rather than squeeze() so a one-token prediction stays 1-D and iterable.
    predicted = greedy_dec_predict[0]
    print(enc_input, '->', predicted)
    print([src_idx2word[t.item()] for t in enc_input], '->',
          [tgt_idx2word[n.item()] for n in predicted])
==============================
利用训练好的Transformer模型将中文句子翻译成英文句子: 
tensor([ 1,  2,  3,  4, 10,  6,  7,  0, 11]) -> tensor([ 1,  2,  3,  8,  5, 11], device='cuda:0')
['我', '有', '一', '个', '男', '朋', '友', 'P', '。'] -> ['I', 'have', 'a', 'boy', 'friend', '.']
tensor([ 1,  2,  8,  4,  9,  6,  7,  0, 11]) -> tensor([ 1,  2,  6,  7,  5, 11], device='cuda:0')
['我', '有', '零', '个', '女', '朋', '友', 'P', '。'] -> ['I', 'have', 'zero', 'girl', 'friend', '.']
tensor([ 1,  2,  8,  4, 10,  6,  7,  0, 11]) -> tensor([ 1,  2,  6,  8,  5, 11], device='cuda:0')
['我', '有', '零', '个', '男', '朋', '友', 'P', '。'] -> ['I', 'have', 'zero', 'boy', 'friend', '.']
posted on 2024-05-28 22:03  LittleHenry  阅读(16)  评论(0)  编辑  收藏  举报