【文本摘要项目】3-基于tensorflow2的seq2seq模型

背景

    基于前两篇文章数据预处理和数据集构造的过程后,可以开始针对我们的文本摘要任务进行处理了。这里选用的是经典的seq2seq+attention模型,作为我们的baseline,先完成一个跑通吧。后续基于seq2seq进行其他改进。

核心内容

    seq2seq模型的理论,此处不再介绍。简单描述下,其有两(三)部分组成:EncoderDecoder。当前提到seq2seq模型,一般都是Attention放到一起说的,因此也可以认为该模型有三部分组成,其中第三部分就是所谓的注意力机制Attention。下面分别基于tensorflow2实现对应部分的内容。

Encoder层

    项目中中的词向量,是基于gensimword2vec接口,进行训练并保存到本地的文件;encoder使用GRU模型进行编码。具体内容前面文章已描述。
    encoder层初始化的输入有:embedding matrix,即词向量矩阵;enc_units,即GRU模型对句子编码后的输出维度;以及batch_size大小,因为tensorflow都是用batch进行批量训练的,因此需要指定一个batch维度。encoder前向传播时,接受的输入为:x,即某一个批次的词向量,以及隐变量hidden。具体代码如下:

class Encoder(keras.Model):

    def __init__(self, embedding_matrix, enc_units, batch_sz):
        super(Encoder, self).__init__()
        self.batch_sz = batch_sz
        self.enc_units = enc_units
        vocab_size, embedding_dim = embedding_matrix.shape

        self.embedding = keras.layers.Embedding(vocab_size,
                                                embedding_dim,
                                                weights=[embedding_matrix],
                                                trainable=False)

        self.gru = keras.layers.GRU(units=self.enc_units,
                                    return_state=True,
                                    return_sequences=True,
                                    recurrent_initializer='glorot_uniform')

    def call(self, x, hidden):

        # embedding前x维度:batch_size * max_len -> 32 * 341
        x = self.embedding(x)
        # embedding后x维度:batch_size * max_len * embedding_dim -> 32 * 341 * 300

        # output 维度:batch_size * max_len * enc_units  -> 32 * 341 * 400
        # state  维度:batch_size * enc_units  -> 32 * 400
        output, state = self.gru(x, initial_state=hidden)

        return output, state
        
    def initialize_hidden_state(self):

        return tf.zeros(shape=(self.batch_sz, self.enc_units))

Decoder层

    decoder层和encoder层的过程基本类似,因为代码结构也类似,不同在于:decoder需要输出一个词汇表长度的概率分布。本项目中decoder层中,也是采用GRU模型,来进行解码。其中,如果模型中要包含Attention机制的话,注意力机制的作用时发生在解码阶段,因此decoder部分还要包括一个Attention层。因此decoder的输入有:词向量x、上一个timestep的输出隐层向量,以及encoder的输出hidden。输出为一个词汇表长度的概率分布。具体代码如下:

class Decoder(keras.Model):

    def __init__(self, embedding_matrix, dec_units, batch_sz):
        super(Decoder, self).__init__()

        self.batch_sz = batch_sz
        self.dec_units = dec_units
        vocab_size, embedding_dim = embedding_matrix.shape

        self.embedding = keras.layers.Embedding(vocab_size,
                                                embedding_dim,
                                                weights=[embedding_matrix],
                                                trainable=False)

        self.gru = keras.layers.GRU(self.dec_units,
                                    return_state=True,
                                    return_sequences=True,
                                    recurrent_initializer='glorot_uniform')

        self.fc = keras.layers.Dense(vocab_size)

        self.attention = BahdanauAttention(self.dec_units)

    def call(self, x, hidden, enc_output):

        # hidden维度:batch_size * dec_units  -> 32 * 400
        # enc_output维度:batch_size * max_len * dec_units -> 32 * 341 * 400
        # context_vector维度:batch_size * dec_units -> 32 * 400
        # attention_weights维度:batch_size * max_len  * 1
        context_vector, attention_weight = self.attention(hidden, enc_output)

        # embedding后x维度:batch_size * 1 * embedding_dim-> 32 * 1 * 300
        x = self.embedding(x)

        # x拼接后的维度:batch_size * 1 * dec_units + embedding_dim -> 32 * 341 * (400 + 300)
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

        # output维度:batch_size * 1 * 400
        # state 维度:batch_size * 400
        output, state = self.gru(x, hidden)

        # output维度:batch_size * 400
        output = tf.reshape(output, shape=(-1, output.shape[2]))

        # prediction维度:batch_size * len(vocab)
        prediction = self.fc(output)

        return prediction, state, attention_weight

Attention层

    attention作用于decoder,其具体原理,本文不再细数。本文的实现采用的时加性注意力的感知机。即:将decoder输出的上一个时间步的隐变量(即query)、以及encoder的编码结果enc_output(即value/key)进行一次线性变换后,进行相加,然后通过一个激活函数tanh后,再经过一个线性变换,输出注意力分数值。因此,Attention需要三个线形层,以及query/key/value的值。具体实现如下:

class BahdanauAttention(keras.Model):

    def __init__(self, units):
        super(BahdanauAttention, self).__init__()

        self.W1 = keras.layers.Dense(units)
        self.W2 = keras.layers.Dense(units)
        self.V = keras.layers.Dense(1)

    def call(self, query, values):

        # query为decoder中,上一个时间步的隐变量St-1
        # values为encoder的编码结果enc_output
        # seq2seq模型中,st是decoder中的query向量;而encoder的隐变量hi是values

        # query 维度:batch_size * dec_units -> 32 * 400
        # values维度:batch_size * max_len * dec_units -> 32 * 341 * 400

        # hidden_with_time_axis维度:batch_size * 1 * dec_units
        hidden_with_time_axis = tf.expand_dims(query, axis=1)

        # self.W1(values): batch_size * max_len * dec_units
        # self.W2(hidden_with_time_axis): batch_size * 1 * dec_units
        # tanh(...)维度:batch_size * max_len * dec_units  tf加法性质:对应相加

        # score维度:batch_size * max_len * 1 -> 32 * 341 * 1
        score = self.V(
            tf.nn.tanh(self.W1(values) + self.W2(hidden_with_time_axis))
        )

        # attention_weights维度:batch_size * max_len  * 1
        attention_weights = tf.nn.softmax(score, axis=1)

        # context_vector维度:batch_size * dec_units -> 32 * 400
        context_vector = attention_weights * values
        context_vector = tf.reduce_sum(context_vector, axis=1)

        return context_vector, attention_weights

    到此为止,构成seq2seq模型所需要的所有组件,基本已经搭建完成,我们可以根据自己的需要,利用组件构造模型了。

Seq2Seq模型

    seq2seq模型的输出为decoder解码出的一系列概率分布,因此采用何种方式进行解码,就显得尤为重要。如贪心解码teacher forcing以及介于两种之间的beam search等。其具体细节不是本文讨论的重点,只大概简述这几个方法的区别。
    贪心解码的思想是,预测 \(t\) 时刻输出的单词时,直接将\(t-1\)时刻的输出词汇表中概率最大的单词,作为\(t\)时刻的输入,因此可能导致如果前一个预测值就不准的话,后面一系列都不准的问题
    Teacher Forcing的方法是,预测 \(t\) 时刻输出的单词时,直接将\(t-1\)时刻的实际单词,作为输入,因此可能带来的问题是,训练过程预测良好(因为有标签,即实际单词),但是测试过程极差(因为测试过程不会给对应的真实单词)。
    实际应用中,往往采用介于这两种极端方式之间的解码方式,如beam search 等,具体思路是预测 \(t\) 时刻输出的单词时,保留\(t-1\)时刻的输出词汇表中概率最大的前K个单词,以此带来更多的可能性(解决第一个方法的缺陷);而且在训练过程,采用一定的概率P,来决定是否使用真实单词作为输入(解决第二个方法的缺陷)。
    本文旨在搭建一个baseline,因此采用了teacher Forcing的方法进行解码,后续再再提升过程中尝试其他方法。代码如下:

class Seq2Seq(keras.Model):

    def __init__(self, params, vocab):
        super(Seq2Seq, self).__init__()

        self.embedding_matrix = load_embedding_matrix()
        self.params = params
        self.vocab = vocab

        self.batch_size = params['batch_size']

        self.enc_units = params['enc_units']
        self.dec_units = params['dec_units']
        self.att_units = params['att_units']

        self.encoder = Encoder(self.embedding_matrix, self.enc_units, self.batch_size)
        self.decoder = Decoder(self.embedding_matrix, self.dec_units, self.batch_size)

        self.attention = BahdanauAttention(self.att_units)

    def teacher_decoder(self, dec_hidden, enc_output, dec_target):

        prediction = []

        # 第一个输入<START>
        dec_input = tf.expand_dims([self.vocab.START_DECODING_INDEX] * self.batch_size, axis=1)

        # teacher forcing 讲target作为下一次的输入,依次解码
        for t in range(1, dec_target.shape[1]):  # dec_target shape: batch_size * max_len
            pred, dec_hidden, _ = self.decoder(dec_input, dec_hidden, enc_output)

            # 预测下一个值需要的输入
            dec_input = tf.expand_dims(dec_target[:, t], axis=1)

            prediction.append(pred)

        return tf.stack(prediction, axis=1), dec_hidden

模型训练Pipeline

    到此,构建基于seq2seq的文本摘要模型基本搭建完成,可以开始训练了。主要函数如下:

def train(params):
    # 1. 配置计算资源
    config_gpus()

    # 2. vocab
    vocab = Vocab(params['vocab_path'], params['vocab_size'])
    params['vocab_size'] = vocab.count

    # 3. 构造模型
    model = Seq2Seq(params, vocab)  # 确保传入到模型的参数params的所有制不会再被修改,不然会报错。

    # # 4. 模型存储
    checkpoint = tf.train.Checkpoint(Seq2Seq=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint,
                                                    directory=path.join(params['checkpoint_dir'], 'seq2seq_model'),
                                                    max_to_keep=5)

    # 5. 模型训练
    train_model(model, vocab, params, checkpoint_manager)

    主要训练函数train_model,可分开写在其他文件。样例代码:

# 损失函数
def loss_function(real, pred, pad_index):

    loss_obj = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
    mask = tf.math.logical_not(tf.math.equal(real, pad_index))

    loss_ = loss_obj(real, pred)
    mask = tf.cast(mask, dtype=loss_.dtype)

    loss_ *= mask
    return tf.reduce_mean(loss_)


# 批次训练
def train_step(model, enc_inputs, dec_target, initial_enc_hidden, loss_function=None, optimizer=None, mode='train'):

    with tf.GradientTape() as tape:

        # encoder部分
        enc_output, enc_hidden = model.encoder(enc_inputs, initial_enc_hidden)

        # decoder部分
        initial_dec_hidden = enc_hidden  # 用encoder的最终输出,作为第一个S_0

        # 逐个预测序列
        prediction, _ = model.teacher_decoder(initial_dec_hidden, enc_output, dec_target)

        # 预测损失
        batch_loss = loss_function(dec_target[:, 1:], prediction)

        if mode == 'train':
            variables = (model.encoder.trainable_variables + model.decoder.trainable_variables + model.attention.trainable_variables)
            gradients = tape.gradient(batch_loss, variables)
            gradients, _ = tf.clip_by_global_norm(gradients, clip_norm=5.)

            optimizer.apply_gradients(zip(gradients, variables))

        return batch_loss


# 模型评估
def evaluate_model(model, valid_dataset, valid_steps_per_epoch, pad_index):

    print('starting evaluating ...')

    total_loss = 0
    initial_enc_hidden = model.encoder.initialize_hidden_state()
    for batch, data in enumerate(valid_dataset.take(valid_steps_per_epoch), start=1):

        inputs, target = data

        batch_loss = train_step(model,
                                inputs,
                                target,
                                initial_enc_hidden,
                                loss_function=partial(loss_function, pad_index=pad_index),
                                mode='eval')
        total_loss += batch_loss

    return total_loss / valid_steps_per_epoch


def train_model(model, vocab, params, checkpoint_manager):

    epochs = params['epochs']

    pad_index = vocab.word2index[vocab.PAD_TOKEN]

    optimizer = keras.optimizers.Adam(name='Adam', learning_rate=params['learning_rate'])

    train_dataset, valid_dataset, train_steps_per_epoch, valid_steps_per_epoch = train_batch_generator(params['batch_size'], params['max_enc_len'], params['max_dec_len'], sample_sum=2 ** 7)

    for epoch in range(epochs):
        start_time = time.time()

        # 第一个隐状态h_0
        initial_enc_hidden = model.encoder.initialize_hidden_state()

        total_loss = 0.
        running_loss = 0.
        # 模型训练
        for batch_index, (inputs, target) in enumerate(train_dataset.take(train_steps_per_epoch), start=1):

            batch_loss = train_step(model,
                                    inputs,
                                    target,
                                    initial_enc_hidden,
                                    loss_function=partial(loss_function, pad_index=pad_index),
                                    optimizer=optimizer)

            total_loss += batch_loss

            if batch_index % 5 == 0:
                print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                             batch_index,
                                                             (total_loss - running_loss) / 5))
                running_loss = total_loss

        # 模型保存
        if (epoch + 1) % 1 == 0:
            ckpt_save_path = checkpoint_manager.save()
            print('Saving checkpoint for epoch {} at {}'.format(epoch + 1, ckpt_save_path))

        # 模型验证
        valid_loss = evaluate_model(model, valid_dataset, valid_steps_per_epoch, pad_index)

        print('Epoch {} Loss {:.4f}; val Loss {:.4f}'.format(epoch + 1,
                                                             total_loss / train_steps_per_epoch,
                                                             valid_loss))

        print('Time taken for 1 epoch {} sec\n'.format(time.time() - start_time))

完整代码

完整代码:文本摘要baseline

posted @ 2021-06-26 17:30  温良Miner  阅读(936)  评论(0编辑  收藏  举报
分享到: