【文本摘要项目】5-性能提升之PGN模型

背景

    经过前几篇文章的内容,基本跑通了整个文本摘要的基本流程,主要包括:文本预处理、基于注意力机制的seq2seq文本摘要生成、解码算法、模型生成结果评估等。因此,经过前面的操作,基本可以得到一个完整的文本摘要生成流程,本文旨在对生成效果进行进一步的提升。本文主要实现的是transformer和bert模型出现之前的一个较为经典的模型——Pointer-Generator Network(PGN),其理论部分的内容已在其他文章交代,本文重在其代码实现部分。

核心内容

    本文分别就数据加载、模型构建、模型训练等几个主要部分,在原来baseline的基础上进行修改。其中模型评估部分与基于seq2seq模型的过程基本相同,因此不再叙述其具体实现。完整代码附于文末。

整体流程

    整体流程在前面基于seq2seq和Attention的文本摘要模型中已做具体介绍。本文旨在利用新方法对模型性能进行提升,因此整体架构基本不变,在此不再赘述,仅对大体流程中的部分细节进行优化,例如:使用生成器进行数据加载、改进学习率及损失函数、Attention计算时考虑mask等,具体将在代码中描述。

基于generator的数据批量加载

# Build the train/valid datasets.
logger.info('Building the dataset for train/valid ...')
train_dataset, params['train_steps_per_epoch'] = batcher(vocab, params)
# The validation pipeline must run in 'valid' mode: reusing the training params
# unchanged would make example_generator read (and shuffle) the training files
# again, so "validation" loss would be measured on training data.
valid_params = dict(params, mode='valid')
valid_dataset, params['valid_steps_per_epoch'] = batcher(vocab, valid_params)

    其中,batcher为本次加载数据方法的不同之处。

def batcher(vocab, params):
    """Create the batched input pipeline for ``params['mode']``.

    Args:
        vocab: vocabulary object, forwarded to ``batch_generator``.
        params: configuration dict; reads ``max_enc_len``, ``max_dec_len``,
            ``batch_size``, ``mode`` and ``buffer_size``.

    Returns:
        ``(dataset, steps_per_epoch)`` where ``dataset`` has been prefetched
        with ``params['buffer_size']`` and ``steps_per_epoch`` comes from
        ``get_steps_per_epoch(params)``.
    """
    batched = batch_generator(
        example_generator,
        params,
        vocab,
        params['max_enc_len'],
        params['max_dec_len'],
        params['batch_size'],
        params['mode'],
    )
    prefetched = batched.prefetch(params['buffer_size'])
    return prefetched, get_steps_per_epoch(params)

    example_generator为数据加载时的具体生成器,可根据数据格式,进行具体编写。而batch_generator会根据example_generator的返回的结果,构造成符合之前tensorflow读取数据的dataset样式,因此如何在example_generator中依旧保持符合tensorflow的常用工具dataset就成为核心。整体流程和过去保持一致。具体如下:

def batch_generator(generator, params, vocab, max_enc_len, max_dec_len, batch_size, mode):
    """Wrap ``generator`` in a padded, batched ``tf.data.Dataset``.

    ``generator(params, vocab, max_enc_len, max_dec_len, mode)`` must yield
    dicts with exactly the fields listed in ``field_types`` below. Examples are
    grouped into batches of ``batch_size`` (a trailing partial batch is
    dropped). ``dec_input``, ``target`` and ``decoder_pad_mask`` are padded to
    ``max_dec_len``; the other sequence fields are padded to the longest
    example in the batch.

    Returns:
        A dataset of ``(encoder_features, decoder_features)`` dict pairs;
        ``encoder_features['max_oov_len']`` is the padded width of the batch's
        ``article_oovs``.
    """
    # dtype of each field produced by the generator.
    field_types = {
        'enc_len': tf.int32,
        'enc_input': tf.int32,
        'enc_input_extend_vocab': tf.int32,
        'article_oovs': tf.string,
        'dec_input': tf.int32,
        'target': tf.int32,
        'dec_len': tf.int32,
        'article': tf.string,
        'abstract': tf.string,
        'abstract_sents': tf.string,
        'decoder_pad_mask': tf.int32,
        'encoder_pad_mask': tf.int32,
    }
    # Per-example shape of each field; None marks a variable-length axis.
    field_shapes = {
        'enc_len': [],
        'enc_input': [None],
        'enc_input_extend_vocab': [None],
        'article_oovs': [None],
        'dec_input': [None],
        'target': [None],
        'dec_len': [],
        'article': [],
        'abstract': [],
        'abstract_sents': [],
        'decoder_pad_mask': [None],
        'encoder_pad_mask': [None],
    }

    def make_examples():
        return generator(params, vocab, max_enc_len, max_dec_len, mode)

    examples = tf.data.Dataset.from_generator(make_examples,
                                              output_types=field_types,
                                              output_shapes=field_shapes)

    pad_id = vocab.word2index[vocab.PAD_TOKEN]

    # Batch shapes match the per-example shapes, except that the decoder-side
    # sequences are padded to a fixed max_dec_len.
    batch_shapes = dict(field_shapes)
    for fixed_len_field in ('dec_input', 'target', 'decoder_pad_mask'):
        batch_shapes[fixed_len_field] = [max_dec_len]

    # Lengths pad with -1, token ids with the PAD id, strings with b'', masks with 0.
    pad_values = {
        'enc_len': -1,
        'enc_input': pad_id,
        'enc_input_extend_vocab': pad_id,
        'article_oovs': b'',
        'dec_input': pad_id,
        'target': pad_id,
        'dec_len': -1,
        'article': b'',
        'abstract': b'',
        'abstract_sents': b'',
        'decoder_pad_mask': 0,
        'encoder_pad_mask': 0,
    }

    batched = examples.padded_batch(batch_size=batch_size,
                                    padded_shapes=batch_shapes,
                                    padding_values=pad_values,
                                    drop_remainder=True)

    def split_encoder_decoder(entry):
        # Regroup the flat feature dict into (encoder inputs, decoder inputs/targets).
        encoder_features = {
            "enc_input": entry["enc_input"],
            "extended_enc_input": entry["enc_input_extend_vocab"],
            "article_oovs": entry["article_oovs"],
            "enc_len": entry["enc_len"],
            "article": entry["article"],
            "max_oov_len": tf.shape(entry["article_oovs"])[1],
            "encoder_pad_mask": entry["encoder_pad_mask"],
        }
        decoder_features = {
            "dec_input": entry["dec_input"],
            "dec_target": entry["target"],
            "dec_len": entry["dec_len"],
            "abstract": entry["abstract"],
            "decoder_pad_mask": entry["decoder_pad_mask"],
        }
        return encoder_features, decoder_features

    return batched.map(split_encoder_decoder)

    在batch_generator中,使用tf.data.Dataset.from_generator()接口,使得数据可以由生成器逐条生成。在该方法中需要指定输出形状output_shapes和输出类型output_types参数。并且,可以在生成数据的同时进行一定的预处理:用padded_batch对齐长度,并将长度不足的序列自动按照指定的值进行填充。

def example_generator(params, vocab, max_enc_len, max_dec_len, mode):
    """Yield one feature dict per example, for ``tf.data.Dataset.from_generator``.

    Any mode other than ``'test'`` reads paired article/abstract lines from
    ``params[f'{mode}_seg_x_dir']`` and ``params[f'{mode}_seg_y_dir']``.
    ``'test'`` mode reads articles only (from ``params['valid_seg_x_dir']``)
    and leaves the decoder-side fields empty.

    Args:
        params: configuration dict (reads the paths above, ``pointer_gen``,
            and in 'test' mode ``max_dec_len``, ``decode_mode``, ``batch_size``).
        vocab: vocabulary providing ``word_to_index`` and the
            ``START_DECODING`` / ``STOP_DECODING`` tokens.
        max_enc_len: articles are truncated to this many words.
        max_dec_len: forwarded to ``get_dec_inp_targ_seqs``.
        mode: 'train', 'valid' or 'test'; only 'train' shuffles.

    Yields:
        dict with the fields declared in ``batch_generator``'s ``output_types``.

    Raises:
        ValueError: in 'test' mode, when ``params['decode_mode']`` is neither
            'beam' nor 'greedy' (such examples used to be dropped silently).
    """
    if mode != 'test':
        # Marker ids are loop-invariant: look them up once, not per record.
        start_decoding = vocab.word_to_index(vocab.START_DECODING)
        stop_decoding = vocab.word_to_index(vocab.STOP_DECODING)

        dataset_x = tf.data.TextLineDataset(params[f'{mode}_seg_x_dir'])
        dataset_y = tf.data.TextLineDataset(params[f'{mode}_seg_y_dir'])

        # NOTE(review): only the first 10000 pairs are used; this looks like a
        # debugging cap. Confirm before training on the full corpus.
        train_dataset = tf.data.Dataset.zip((dataset_x, dataset_y)).take(count=10000)

        if mode == 'train':
            train_dataset = train_dataset.shuffle(10, reshuffle_each_iteration=True).repeat(1)

        for raw_record in train_dataset:
            article = raw_record[0].numpy().decode('utf-8')
            article_words = article.split()[:max_enc_len]

            enc_input = [vocab.word_to_index(w) for w in article_words]
            enc_input_extend_vocab, article_oovs = article_to_index(article_words, vocab)

            # Add the start and stop flags to both encoder id sequences.
            enc_input = get_enc_inp_targ_seqs(enc_input,
                                              max_enc_len,
                                              start_decoding,
                                              stop_decoding)
            enc_input_extend_vocab = get_enc_inp_targ_seqs(enc_input_extend_vocab,
                                                           max_enc_len,
                                                           start_decoding,
                                                           stop_decoding)

            # Length of the encoder input and its all-ones mask
            # (padded_batch later pads the mask with 0).
            enc_len = len(enc_input)
            encoder_pad_mask = [1] * enc_len

            abstract = raw_record[1].numpy().decode('utf-8')
            abstract_words = abstract.split()
            abs_ids = [vocab.word_to_index(w) for w in abstract_words]

            dec_input, target = get_dec_inp_targ_seqs(abs_ids,
                                                      max_dec_len,
                                                      start_decoding,
                                                      stop_decoding)

            if params['pointer_gen']:
                # With the pointer enabled, targets use article-extended ids so
                # in-article OOV words can be produced by copying.
                abs_ids_extend_vocab = abstract_to_index(abstract_words, vocab, article_oovs)
                _, target = get_dec_inp_targ_seqs(abs_ids_extend_vocab,
                                                  max_dec_len,
                                                  start_decoding,
                                                  stop_decoding)

            dec_len = len(target)
            decoder_pad_mask = [1] * dec_len

            yield {
                "enc_len": enc_len,
                "enc_input": enc_input,
                "enc_input_extend_vocab": enc_input_extend_vocab,
                "article_oovs": article_oovs,
                "dec_input": dec_input,
                "target": target,
                "dec_len": dec_len,
                "article": article,
                "abstract": abstract,
                "abstract_sents": abstract,
                "decoder_pad_mask": decoder_pad_mask,
                "encoder_pad_mask": encoder_pad_mask
            }
        return

    decode_mode = params["decode_mode"]
    if decode_mode == "beam":
        # Beam search uses one article per batch, so each example is repeated
        # batch_size times to fill the batch.
        copies = params["batch_size"]
    elif decode_mode == "greedy":
        copies = 1
    else:
        raise ValueError(f"Unsupported decode_mode {decode_mode!r}; expected 'beam' or 'greedy'")

    # NOTE(review): test mode reads valid_seg_x_dir; confirm no separate test file is intended.
    test_dataset = tf.data.TextLineDataset(params['valid_seg_x_dir'])
    for raw_record in test_dataset:
        article = raw_record.numpy().decode('utf-8')
        article_words = article.split()[: max_enc_len]
        enc_len = len(article_words)

        # NOTE(review): unlike the non-test branch, enc_input is not passed through
        # get_enc_inp_targ_seqs here; confirm the inference-time encoder input is
        # meant to differ from the training one.
        enc_input = [vocab.word_to_index(w) for w in article_words]
        enc_input_extend_vocab, article_oovs = article_to_index(article_words, vocab)

        encoder_pad_mask = [1] * enc_len

        output = {
            "enc_len": enc_len,
            "enc_input": enc_input,
            "enc_input_extend_vocab": enc_input_extend_vocab,
            "article_oovs": article_oovs,
            "dec_input": [],
            "target": [],
            "dec_len": params['max_dec_len'],
            "article": article,
            "abstract": '',
            "abstract_sents": '',
            "decoder_pad_mask": [],
            "encoder_pad_mask": encoder_pad_mask
        }
        for _ in range(copies):
            yield output

    example_generator函数负责对具体数据进行处理,其中包括添加开始、结束标志(vocab.START_DECODING / vocab.STOP_DECODING)等。而在PGN模型中,一个很重要的部分就是模型的复制(copy)能力。PGN的复制能力使得该模型具有一定的解决oov问题的能力,具体体现在哪里呢?主要是pointer network复制的词来自输入数据(input text),所以一定程度上能生成出现在输入数据中、但不在词汇表中的词。说到底,是在一定程度上利用了构建词汇表时被过滤掉的低频词等。

    PGN模型中另一个重要的点在于,其最终预测的概率分布是 p_gen × 词汇表上的概率分布 + (1 − p_gen) × 输入数据(input text)上的attention分布。(个人理解:同时出现在词汇表和input text中的词,其概率会被增大;按照贪心解码取概率最大的1个或者K个词的思路,原始词汇表以外的词得到的概率较小,因此该模型只能在一定程度上缓解oov问题。)

    代码中实现上述两个内容的基础,是在数据处理阶段的article_to_index和abstract_to_index两个函数中,将原来标注为UNK的单词重新编号为扩展词表中的id(vocab.size() + 该词在输入数据OOV列表中的序号),即减少了UNK。

def article_to_index(article_words, vocab):
    """Map article words to ids in the article-extended vocabulary.

    In-vocabulary words keep their normal id. Each distinct word that the
    vocabulary maps to ``vocab.UNKNOWN_TOKEN_INDEX`` gets the temporary id
    ``vocab.size() + k``, where ``k`` is the order in which that word first
    appears in the article. These extended ids let the pointer copy words
    that are missing from the fixed vocabulary.

    Args:
        article_words: list of article tokens.
        vocab: vocabulary providing ``word_to_index``, ``size`` and
            ``UNKNOWN_TOKEN_INDEX``.

    Returns:
        ``(extend_vocab_index, oov_words)``: one id per input word, and the
        article's OOV words in first-appearance order.
    """
    unk_index = vocab.UNKNOWN_TOKEN_INDEX
    base_size = vocab.size()

    oov_words = []
    # word -> position in oov_words; a dict gives O(1) lookups where
    # `in list` / `list.index` were O(len(oov_words)) per word.
    oov_position = {}
    extend_vocab_index = []

    for word in article_words:
        word_index = vocab.word_to_index(word)
        if word_index != unk_index:
            extend_vocab_index.append(word_index)
            continue
        if word not in oov_position:
            oov_position[word] = len(oov_words)
            oov_words.append(word)
        extend_vocab_index.append(base_size + oov_position[word])

    return extend_vocab_index, oov_words

    将原来的词汇表进行扩展的代码样例如上,这里只展示了对训练数据X的处理,对于y的处理类似,后续完整代码可见。

模型保存与加载部分的代码优化

# Build the checkpoint manager (keeps at most 5 checkpoints on disk).
checkpoint = tf.train.Checkpoint(PGN=model)
checkpoint_manager = tf.train.CheckpointManager(checkpoint, params['checkpoint_dir'], max_to_keep=5)

latest_checkpoint = checkpoint_manager.latest_checkpoint
if latest_checkpoint:
    # CheckpointManager has no restore() method; restoring goes through the
    # Checkpoint object the manager wraps.
    checkpoint.restore(latest_checkpoint)
    # Checkpoint paths end in "-<save number>" (e.g. ".../ckpt-12"): parse the
    # whole number, not just its last character.
    # NOTE(review): this is the manager's save counter; train_model saves only
    # when validation loss improves, so it may lag the real epoch count.
    params['trained_epoch'] = int(latest_checkpoint.rsplit('-', 1)[-1])
    logger.info(f'Building model by restoring {latest_checkpoint}')
else:
    params['trained_epoch'] = 1
    logger.info('Building model from initial ...')

# Decay the learning rate by a factor of 0.95 per trained epoch.
params['learning_rate'] *= np.power(0.95, params['trained_epoch'])
logger.info(f'Learning rate : {params["learning_rate"]}')

    上述代码的优化点在于对动态学习率的设置、自动加载上一次训练保存的最优模型,以及训练了多少epoch。

PGN模型构建

class PGN(keras.Model):
    """Pointer-generator network: Encoder + attention Decoder + Pointer (p_gen).

    ``call`` decodes the whole target sequence one step at a time, feeding the
    ground-truth token ``dec_input[:, t]`` at each step (teacher forcing). It
    then mixes the decoder's vocabulary distribution with the copy
    (attention) distribution via ``_calc_final_dist``.
    """

    def __init__(self, params):
        super(PGN, self).__init__()
        self.embedding_matrix = load_embedding_matrix(max_vocab_size=params['vocab_size'])

        self.vocab_size = params['vocab_size']
        self.batch_size = params['batch_size']

        self.encoder = Encoder(self.embedding_matrix,
                               params['enc_units'],
                               params['batch_size'])

        self.decoder = Decoder(self.embedding_matrix,
                               params['dec_units'],
                               params['batch_size'])

        self.pointer = Pointer()

    def call_one_step(self, dec_input, dec_hidden, enc_output, enc_pad_mask, use_coverage, prev_coverage):
        """Decode a single time step.

        Returns:
            ``(prediction, dec_hidden, context_vector, attention_weights, p_gens, coverage)``.
            ``prediction`` is the decoder's distribution over the base
            vocabulary (before mixing with the copy distribution), and
            ``p_gens`` is the generation probability from ``Pointer``.
        """
        context_vector, dec_hidden, dec_x, prediction, attention_weights, coverage = self.decoder(dec_input,
                                                                                                  dec_hidden,
                                                                                                  enc_output,
                                                                                                  enc_pad_mask,
                                                                                                  prev_coverage,
                                                                                                  use_coverage)

        p_gens = self.pointer(context_vector, dec_hidden, dec_x)

        return prediction, dec_hidden, context_vector, attention_weights, p_gens, coverage

    def call(self, dec_input, dec_hidden, enc_output, enc_extended_input, batch_oov_len, enc_pad_mask, use_coverage,
             coverage=None):
        """Decode every step of ``dec_input`` and build the extended-vocabulary distributions.

        Args:
            dec_input: decoder input ids; step ``t`` feeds ``dec_input[:, t]``.
            dec_hidden: initial decoder state.
            enc_output: encoder outputs attended over at every step.
            enc_extended_input: encoder ids in the article-extended vocabulary,
                used to scatter attention mass onto copied words.
            batch_oov_len: number of extra (in-article OOV) ids in this batch.
            enc_pad_mask: encoder padding mask passed to the attention.
            use_coverage: whether to use the coverage mechanism.
            coverage: initial coverage (None to start without one).

        Returns:
            ``(final_dists, attentions, coverages)``. ``final_dists`` and
            ``attentions`` are the per-step tensors stacked on axis 1.
            ``coverages`` is the per-step coverage stacked on axis 1 when
            ``use_coverage`` is true, otherwise None.
        """
        vocab_dists = []
        attentions = []
        p_gens = []
        coverages = []

        for t in range(dec_input.shape[1]):
            vocab_dist, dec_hidden, _, attention_weights, p_gen, coverage = self.call_one_step(
                dec_input[:, t],
                dec_hidden,
                enc_output,
                enc_pad_mask,
                use_coverage,
                coverage)

            coverages.append(coverage)
            vocab_dists.append(vocab_dist)
            attentions.append(attention_weights)
            p_gens.append(p_gen)

        final_dists = _calc_final_dist(enc_extended_input,
                                       vocab_dists,
                                       attentions,
                                       p_gens,
                                       batch_oov_len,
                                       self.vocab_size,
                                       self.batch_size)

        attentions = tf.stack(attentions, axis=1)

        # Stack the per-step coverage *list* (stacking only the last step's
        # tensor was a bug). Without coverage every step yields [], which
        # tf.stack cannot stack along axis 1, so return None instead.
        stacked_coverages = tf.stack(coverages, 1) if use_coverage else None

        return tf.stack(final_dists, 1), attentions, stacked_coverages


def _calc_final_dist(_enc_batch_extend_vocab, vocab_dists, attn_dists, p_gens, batch_oov_len, vocab_size, batch_size):
    """Mix generation and copy distributions into the pointer-generator output.

    For each decoder step the result is
    ``p_gen * vocab_dist`` (zero-padded with ``batch_oov_len`` columns) plus
    ``(1 - p_gen) * attn_dist`` scattered onto the extended-vocabulary ids in
    ``_enc_batch_extend_vocab``. ``tf.scatter_nd`` sums repeated indices, so
    a word occurring several times in the article collects all its attention.

    Args:
        _enc_batch_extend_vocab: (batch_size, attn_len) encoder ids in the
            article-extended vocabulary.
        vocab_dists: list (one per decoder step) of (batch_size, vocab_size)
            vocabulary distributions.
        attn_dists: list (one per decoder step) of (batch_size, attn_len)
            attention distributions.
        p_gens: list (one per decoder step) of generation probabilities.
        batch_oov_len: number of extra in-article OOV ids in this batch.
        vocab_size: size of the base vocabulary.
        batch_size: number of examples per batch.

    Returns:
        List (one per decoder step) of (batch_size, vocab_size + batch_oov_len)
        distributions. Entries for [PAD] decoder steps are meaningless and
        should be ignored by the loss.
    """
    extended_vsize = vocab_size + batch_oov_len
    # Empty slots for the in-article OOV words on the generation side.
    extra_zeros = tf.zeros((batch_size, batch_oov_len))

    # Build (batch_index, extended_word_id) pairs once; they are the same at every step.
    attn_len = tf.shape(_enc_batch_extend_vocab)[1]
    row_ids = tf.tile(tf.expand_dims(tf.range(0, limit=batch_size), 1), [1, attn_len])
    scatter_indices = tf.stack((row_ids, _enc_batch_extend_vocab), axis=2)
    scatter_shape = [batch_size, extended_vsize]

    final_dists = []
    for p_gen, vocab_dist, attn_dist in zip(p_gens, vocab_dists, attn_dists):
        generated = tf.concat(axis=1, values=[p_gen * vocab_dist, extra_zeros])
        copied = tf.scatter_nd(scatter_indices, (1 - p_gen) * attn_dist, scatter_shape)
        final_dists.append(generated + copied)

    return final_dists

    PGN模型主要由以下几部分构成:encoder、decoder和pointer(用于计算p_gen)。本文的实现在训练时采用Teacher Forcing进行单词的生成。函数_calc_final_dist()用于计算扩展词表上的最终概率分布,call_one_step()函数供call调用,每次解码一个时间步。下面将继续介绍模型的几个组件。

class Encoder(keras.Model):
    """Bidirectional-GRU encoder over frozen, pre-trained embeddings."""

    def __init__(self, embedding_matrix, enc_units, batch_size):
        super(Encoder, self).__init__()

        self.batch_size = batch_size
        self.enc_units = enc_units
        self.vocab_size, self.embedding_dim = embedding_matrix.shape

        # Embeddings are initialised from embedding_matrix and not trained.
        self.embedding = keras.layers.Embedding(self.vocab_size, self.embedding_dim,
                                                weights=[embedding_matrix], trainable=False)

        # One GRU definition, wrapped so it runs in both directions.
        self.gru = keras.layers.GRU(self.enc_units, return_state=True, return_sequences=True,
                                    recurrent_initializer='glorot_uniform')
        self.bidirectional_gru = keras.layers.Bidirectional(self.gru)

    def call(self, x, enc_hidden):
        """Encode token ids ``x``.

        Args:
            x: token ids, looked up in the embedding layer.
            enc_hidden: initial state, used for both the forward and the
                backward direction.

        Returns:
            ``(enc_output, enc_hidden)``: the per-timestep bidirectional
            outputs, and the forward and backward final states concatenated
            on the last axis (so ``2 * enc_units`` wide).
        """
        embedded = self.embedding(x)
        outputs, fwd_state, bwd_state = self.bidirectional_gru(embedded, initial_state=[enc_hidden, enc_hidden])
        final_state = keras.layers.concatenate([fwd_state, bwd_state], axis=-1)
        return outputs, final_state

    def initialize_hidden_state(self):
        """Return an all-zero initial state of shape (batch_size, enc_units)."""
        return tf.zeros(shape=(self.batch_size, self.enc_units))

    encoder部分和前面基于seq2seq模型基本相同,差别在于此处使用了一个双向的gru进行编码表示。

def masked_attention(enc_pad_mask, attn_dist):
    """Zero the attention on padded encoder positions and renormalize.

    Args:
        enc_pad_mask: mask aligned with the encoder positions (non-zero for
            real tokens, 0 for padding); cast to the attention dtype.
        attn_dist: attention weights with a trailing size-1 axis (axis 2),
            which is squeezed for the computation and restored afterwards.

    Returns:
        Attention weights of the same shape as ``attn_dist``, where the
        unmasked entries of each row sum to 1 (a 1e-12 guard avoids division
        by zero for fully masked rows).
    """
    weights = tf.squeeze(attn_dist, axis=2)
    weights *= tf.cast(enc_pad_mask, dtype=weights.dtype)
    totals = tf.reshape(tf.reduce_sum(weights, axis=1) + 1e-12, [-1, 1])
    return tf.expand_dims(weights / totals, axis=2)


class BahdanauAttention(keras.layers.Layer):
    """Additive (Bahdanau) attention with a padding mask and optional coverage."""

    def __init__(self, units):
        super(BahdanauAttention, self).__init__()

        self.W_s = keras.layers.Dense(units)
        self.W_h = keras.layers.Dense(units)
        self.W_c = keras.layers.Dense(units)
        self.V = keras.layers.Dense(1)

    def call(self, dec_hidden, enc_output, enc_pad_mask, use_coverage=False, pre_coverage=None):
        """Attend over ``enc_output`` with decoder state ``dec_hidden``.

        Args:
            dec_hidden: decoder state; a time axis is inserted at axis 1 so it
                broadcasts against the encoder outputs.
            enc_output: encoder outputs; the weighted sum is taken over axis 1.
            enc_pad_mask: encoder padding mask, applied by ``masked_attention``.
            use_coverage: whether to use the coverage mechanism.
            pre_coverage: coverage accumulated over earlier steps, or None on
                the first step.

        Returns:
            ``(context_vector, attention_weights, coverage)``.
            ``attention_weights`` has its trailing size-1 axis squeezed.
            ``coverage`` is the running sum of attention (unsqueezed) when
            ``use_coverage`` is true, otherwise ``[]``.
        """
        hidden_with_time_axis = tf.expand_dims(dec_hidden, 1)
        has_prev_coverage = use_coverage and pre_coverage is not None

        features = self.W_s(enc_output) + self.W_h(hidden_with_time_axis)
        if has_prev_coverage:
            features = features + self.W_c(pre_coverage)
        score = self.V(tf.nn.tanh(features))

        # Normalize over the encoder time axis (axis 1). score has a trailing
        # size-1 axis, so softmax's default axis=-1 would make every weight 1.0
        # and degrade attention to a uniform average over unpadded tokens.
        attention_weights = tf.nn.softmax(score, axis=1)
        attention_weights = masked_attention(enc_pad_mask, attention_weights)

        if has_prev_coverage:
            coverage = attention_weights + pre_coverage
        elif use_coverage:
            coverage = attention_weights
        else:
            coverage = []

        context_vector = tf.reduce_sum(attention_weights * enc_output, axis=1)

        return context_vector, tf.squeeze(attention_weights, -1), coverage

    Attention的计算一般可分为三个步骤:1. 计算Attention score;2. softmax;3. 加权求和(reduce_sum)。此处Attention过程和之前的实现稍有不同,主要体现在:1. 考虑了mask:在masked_attention函数中,softmax之后将padding位置的权重置零并重新归一化,因此被mask的部分不参与计算;2. 考虑之前各步attention累加得到的coverage向量,并通过use_coverage参数指定是否使用覆盖(coverage)机制。

class Decoder(keras.Model):
    """One-step attention decoder: embedding -> GRUCell -> attention -> softmax projection."""

    def __init__(self, embedding_matrix, dec_units, batch_size):
        super(Decoder, self).__init__()
        self.batch_size = batch_size
        self.dec_units = dec_units
        self.vocab_size, self.embedding_dim = embedding_matrix.shape

        # Embeddings are initialised from embedding_matrix and not trained.
        self.embedding = keras.layers.Embedding(self.vocab_size, self.embedding_dim,
                                                weights=[embedding_matrix], trainable=False)
        # A cell rather than a full GRU layer, so decoding is driven one step at a time.
        self.cell = keras.layers.GRUCell(units=self.dec_units, recurrent_initializer='glorot_uniform')
        # Projects [cell output; context] to a softmax distribution over the base vocabulary.
        self.fc = keras.layers.Dense(self.vocab_size, activation=keras.activations.softmax)

        self.attention = BahdanauAttention(self.dec_units)

    def call(self, dec_input, dec_hidden, enc_output, enc_pad_mask, pre_coverage, use_covarage=True):
        """Decode one time step.

        Returns:
            ``(context_vector, dec_hidden, dec_x, prediction, attention_weights, coverage)``,
            where ``dec_x`` is the embedded input token, ``dec_hidden`` the
            updated GRU state and ``prediction`` the vocabulary distribution.
        """
        step_embedding = self.embedding(dec_input)
        cell_output, (new_hidden,) = self.cell(step_embedding, [dec_hidden])

        # Attention is queried with the updated hidden state.
        context, attn_weights, new_coverage = self.attention(new_hidden,
                                                             enc_output,
                                                             enc_pad_mask,
                                                             use_covarage,
                                                             pre_coverage)

        fused = tf.concat([cell_output, context], axis=-1)
        vocab_dist = self.fc(fused)

        return context, new_hidden, step_embedding, vocab_dist, attn_weights, new_coverage

    decoder部分一个比较重要的点在于:解码是按时间步(timesteps)逐步进行的,上一个时间步的隐状态(以及使用coverage时累积的attention)会参与下一步attention的计算,因此此处采用cell级别的GRU(GRUCell)实现。

class Pointer(keras.layers.Layer):
    """Compute the generation probability p_gen of the pointer-generator.

    p_gen = sigmoid(Dense_s(dec_hidden) + Dense_c(context_vector) + Dense_i(dec_inp)),
    where each Dense has one output unit (with bias). p_gen weights the
    vocabulary distribution, and 1 - p_gen weights the copy distribution.
    """

    def __init__(self):
        super(Pointer, self).__init__()

        self.w_s_reduce = keras.layers.Dense(1)
        self.w_i_reduce = keras.layers.Dense(1)
        self.w_c_reduce = keras.layers.Dense(1)

    def call(self, context_vector, dec_hidden, dec_inp):
        state_term = self.w_s_reduce(dec_hidden)
        context_term = self.w_c_reduce(context_vector)
        input_term = self.w_i_reduce(dec_inp)
        logit = state_term + context_term + input_term
        return tf.nn.sigmoid(logit)

    根据PGN模型的理论,在计算最终概率分布时,采用两个概率分布加权进行最终概率分布计算,而以多大概率进行加权呢?Pointer类主要用于计算这个值。到此为止,模型部分的细节基本介绍完毕。

模型训练以及评估

def train_model(model, train_dataset, valid_dataset, params, checkpoint_manager):
    """Train ``model`` for ``params['epochs']`` epochs with Adagrad.

    The running training loss is printed every 50 batches. After each full
    pass over ``train_dataset`` the model is evaluated once on
    ``valid_dataset``, and a checkpoint is saved whenever the validation loss
    improves.

    Args:
        model: PGN model; its encoder, decoder and pointer are updated by train_step.
        train_dataset: dataset yielding ``(encoder_batch, decoder_batch)`` dict pairs.
        valid_dataset: dataset passed to ``evaluate``.
        params: reads epochs, learning_rate, adagrad_init_acc, max_grad_norm,
            eps and use_coverage (plus whatever train_step/evaluate read).
        checkpoint_manager: used to save a checkpoint on every improvement.
    """
    epochs = params['epochs']

    optimizer = keras.optimizers.Adagrad(learning_rate=params['learning_rate'],
                                         initial_accumulator_value=params['adagrad_init_acc'],
                                         clipnorm=params['max_grad_norm'],
                                         epsilon=params['eps'])

    # Start at +inf so the first validation result is always checkpointed
    # (a fixed threshold such as 100 would never save if the loss stayed above it).
    best_loss = float('inf')
    for epoch in range(epochs):
        start = time.time()
        enc_hidden = model.encoder.initialize_hidden_state()

        total_loss = 0.
        total_log_loss = 0.
        total_cov_loss = 0.
        step = 0
        for encoder_batch_data, decoder_batch_data in train_dataset:

            batch_loss, log_loss, cov_loss = train_step(model,
                                                        enc_hidden,
                                                        encoder_batch_data['enc_input'],
                                                        encoder_batch_data['extended_enc_input'],
                                                        encoder_batch_data['max_oov_len'],
                                                        decoder_batch_data['dec_input'],
                                                        decoder_batch_data['dec_target'],
                                                        enc_pad_mask=encoder_batch_data['encoder_pad_mask'],
                                                        dec_pad_mask=decoder_batch_data['decoder_pad_mask'],
                                                        params=params,
                                                        optimizer=optimizer,
                                                        mode='train')

            step += 1
            total_loss += batch_loss
            total_log_loss += log_loss
            total_cov_loss += cov_loss
            if step % 50 == 0:
                if params['use_coverage']:

                    print('Epoch {} Batch {} avg_loss {:.4f} log_loss {:.4f} cov_loss {:.4f}'.format(epoch + 1,
                                                                                                     step,
                                                                                                     total_loss / step,
                                                                                                     total_log_loss / step,
                                                                                                     total_cov_loss / step))
                else:
                    print('Epoch {} Batch {} avg_loss {:.4f}'.format(epoch + 1,
                                                                     step,
                                                                     total_loss / step))

        # Validate and checkpoint once per epoch, after the full training pass.
        # NOTE(review): train_step returns (loss, log_loss, cov_loss), while this
        # unpacking assumes evaluate returns (loss, cov_loss, log_loss); confirm.
        valid_total_loss, valid_total_cov_loss, valid_total_log_loss = evaluate(model, valid_dataset, params)
        # max(step, 1) avoids a ZeroDivisionError when train_dataset is empty.
        print('Epoch {} Loss {:.4f}, valid Loss {:.4f}'.format(epoch + 1, total_loss / max(step, 1), valid_total_loss))
        print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

        if valid_total_loss < best_loss:
            best_loss = valid_total_loss
            ckpt_save_path = checkpoint_manager.save()
            print('Saving checkpoint for epoch {} at {}, best valid loss {}'.format(epoch + 1,
                                                                                    ckpt_save_path,
                                                                                    best_loss))


def train_step(model, enc_hidden, enc_input, extend_enc_input, max_oov_len, dec_input, dec_target, enc_pad_mask, dec_pad_mask, params, optimizer=None, mode='train'):
    """Run one forward pass and, in 'train' mode, one optimizer update.

    Encodes ``enc_input``, decodes ``dec_input`` through ``model`` (the decoder
    initial state is the encoder's final state) and computes the PGN loss
    with ``calc_loss``.

    Args:
        model: PGN model.
        enc_hidden: initial encoder state.
        enc_input / extend_enc_input: encoder ids in the base / article-extended vocabulary.
        max_oov_len: number of in-article OOV ids in the batch.
        dec_input / dec_target: decoder inputs and targets.
        enc_pad_mask / dec_pad_mask: padding masks for encoder and decoder.
        params: reads use_coverage, cov_loss_wt and eps.
        optimizer: required when ``mode == 'train'``.
        mode: gradients are applied only when this is 'train'.

    Returns:
        ``(batch_loss, log_loss, cov_loss)`` as returned by ``calc_loss``.

    Raises:
        ValueError: if ``mode == 'train'`` and no optimizer is given.
    """
    if mode == 'train' and optimizer is None:
        raise ValueError("train_step(mode='train') requires an optimizer")

    with tf.GradientTape() as tape:
        # Encoder pass.
        enc_output, enc_hidden = model.encoder(enc_input, enc_hidden)

        # Decoder pass, starting from the encoder's final state.
        dec_hidden = enc_hidden
        final_dists, attentions, _ = model(dec_input, dec_hidden, enc_output, extend_enc_input, max_oov_len, enc_pad_mask=enc_pad_mask, use_coverage=params['use_coverage'], coverage=None)

        batch_loss, log_loss, cov_loss = calc_loss(dec_target, final_dists, dec_pad_mask, attentions, params['cov_loss_wt'], params['eps'])

    # Compute gradients after leaving the tape context so the gradient
    # computation itself is not recorded.
    if mode == 'train':
        variables = (model.encoder.trainable_variables + model.decoder.trainable_variables + model.pointer.trainable_variables)
        gradients = tape.gradient(batch_loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))

    return batch_loss, log_loss, cov_loss

    训练过程的整体框架基本维持不变,差别在于损失函数的计算。

def calc_loss(real, pred, dec_mask, attentions, cov_loss_wt, eps):
    """Combine the PGN negative log-likelihood with the weighted coverage loss.

    Args:
        real: target ids (extended vocabulary).
        pred: per-step final distributions from the model.
        dec_mask: decoder padding mask.
        attentions: per-step attention weights.
        cov_loss_wt: weight of the coverage term.
        eps: small constant added inside the log.

    Returns:
        ``(log_loss + cov_loss_wt * cov_loss, log_loss, cov_loss)``.
    """
    nll = pgn_log_loss_function(real, pred, dec_mask, eps)
    coverage_penalty = _coverage_loss(attentions, dec_mask)
    total = nll + cov_loss_wt * coverage_penalty
    return total, nll, coverage_penalty

    损失函数的计算包括两个部分:一个是pgn模型原有的损失,一个是使用coverage机制时带来的损失。并将最终损失做一个加权。

def pgn_log_loss_function(real, final_dists, padding_mask, eps):
    """Masked mean negative log-probability of the target ids.

    Args:
        real: (batch, dec_steps) target ids; ``real[:, t]`` is step t.
        final_dists: (batch, dec_steps, extended_vsize) distributions.
        padding_mask: decoder padding mask forwarded to ``_mask_and_avg``.
        eps: added to each gold probability before the log.

    Returns:
        Scalar loss from ``_mask_and_avg``.
    """
    row_ids = tf.range(0, limit=real.shape[0])
    per_step_dists = tf.transpose(final_dists, perm=[1, 0, 2])
    step_losses = [
        -tf.math.log(tf.gather_nd(dist, tf.stack((row_ids, real[:, t]), axis=1)) + eps)
        for t, dist in enumerate(per_step_dists)
    ]
    return _mask_and_avg(step_losses, padding_mask)
def _coverage_loss(attn_dists, padding_mask):
    """Masked mean coverage loss: sum_i min(a_t[i], c_t[i]) per decoder step.

    ``c_t`` is the sum of the attention of all previous steps (zero at step 0).

    Args:
        attn_dists: (batch, dec_steps, attn_len) attention weights.
        padding_mask: decoder padding mask forwarded to ``_mask_and_avg``.

    Returns:
        Scalar loss from ``_mask_and_avg``.
    """
    per_step_attn = tf.transpose(attn_dists, perm=[1, 0, 2])
    running_coverage = tf.zeros_like(per_step_attn[0])

    step_losses = []
    for attn in per_step_attn:
        step_losses.append(tf.reduce_sum(tf.minimum(attn, running_coverage), [1]))
        running_coverage += attn

    return _mask_and_avg(step_losses, padding_mask)


def _mask_and_avg(values, padding_mask):
    """Average per-step values over real (unpadded) decoder steps, then over the batch.

    Args:
        values: list with one per-example tensor per decoder step.
        padding_mask: (batch, dec_steps) mask, 1 for real steps and 0 for
            padding; column ``dec_step`` masks ``values[dec_step]``.

    Returns:
        Scalar: mean over the batch of (sum of masked step values / number of
        real steps). An example with no real steps contributes 0 instead of NaN.
    """
    padding_mask = tf.cast(padding_mask, dtype=values[0].dtype)
    dec_lens = tf.reduce_sum(padding_mask, axis=1)
    values_per_step = [v * padding_mask[:, dec_step] for dec_step, v in enumerate(values)]
    # Element-wise sum over decoder steps, then divide by each example's real length.
    values_per_ex = tf.math.divide_no_nan(tf.add_n(values_per_step), dec_lens)

    return tf.reduce_mean(values_per_ex)

完整代码

代码

posted @ 2021-07-28 22:14  温良Miner  阅读(465)  评论(1编辑  收藏  举报
分享到: