基于 TensorFlow 2.0 的 RNN 文本生成学习

  • 加载数据,这里可以使用自己的数据集
# 加载数据
def get_data_from_file(train_file, batch_size, seq_size):
    """Load a whitespace-tokenised corpus and build next-word training pairs.

    Words are mapped to integer ids in order of decreasing frequency (the
    most frequent word gets id 0; ties keep first-occurrence order). The id
    stream is truncated to a whole number of batches and cut into rows of
    ``seq_size`` tokens. The target rows are the same stream shifted left by
    one token, with the very first token wrapped around to the last slot.

    Args:
        train_file: Path of a UTF-8 text file.
        batch_size: Number of sequences per batch (positive integer).
        seq_size: Number of tokens per sequence (positive integer).

    Returns:
        A tuple ``(int_to_vocab, vocab_to_int, n_vocab, in_text, out_text)``
        where the two dicts map id -> word and word -> id, ``n_vocab`` is the
        vocabulary size, and ``in_text`` / ``out_text`` are integer arrays of
        shape ``(num_batches * batch_size, seq_size)``.

    Raises:
        ValueError: If ``batch_size`` or ``seq_size`` is not positive, or if
            the file holds fewer than ``batch_size * seq_size`` tokens.
    """
    if batch_size <= 0 or seq_size <= 0:
        raise ValueError(
            'batch_size and seq_size must be positive, got {} and {}'.format(
                batch_size, seq_size))

    with open(train_file, encoding='utf-8') as f:
        text = f.read().split()

    # Rank words by frequency; sorted() is stable, so equally frequent words
    # stay in the order they first appear in the corpus.
    word_counts = Counter(text)
    sorted_vocab = sorted(word_counts, key=word_counts.get, reverse=True)
    int_to_vocab = dict(enumerate(sorted_vocab))
    vocab_to_int = {w: k for k, w in int_to_vocab.items()}
    n_vocab = len(int_to_vocab)

    print('Vocabulary size', n_vocab)

    int_text = [vocab_to_int[w] for w in text]

    # Keep only as many tokens as fill complete batches.
    tokens_per_batch = seq_size * batch_size
    num_batches = len(int_text) // tokens_per_batch
    if num_batches == 0:
        # Without this check the shift below fails on an empty sequence with
        # an IndexError that says nothing about the real cause.
        raise ValueError(
            'Not enough data: {} tokens in {!r}, but at least {} are needed '
            'for one batch'.format(len(int_text), train_file, tokens_per_batch))
    in_text = np.asarray(int_text[:num_batches * tokens_per_batch])

    # Targets are the inputs shifted one step left; the first token wraps
    # around to fill the final position.
    out_text = np.roll(in_text, -1)

    in_text = np.reshape(in_text, (-1, seq_size))
    out_text = np.reshape(out_text, (-1, seq_size))

    return int_to_vocab, vocab_to_int, n_vocab, in_text, out_text
  • 构建模型:由于这里只是原理的简单实现,因此只使用了一层单向的LSTM。在实际使用中,为了模型的效果可以使用多层双向的LSTM/GRU或是Transformer。
# 构建模型,这里使用的是单层的LSTM
class RNNModule(tf.keras.Model):
    """Single-layer LSTM language model: embedding -> LSTM -> dense projection."""

    def __init__(self, n_vocab, embedding_size, lstm_size):
        """Create the three layers.

        Args:
            n_vocab: Vocabulary size; used for the embedding table and the
                width of the output projection.
            embedding_size: Dimension of each word embedding.
            lstm_size: Number of LSTM units, which is also the width of the
                state tensors built by ``zero_state``.
        """
        super().__init__()
        self.lstm_size = lstm_size

        # Lookup table of n_vocab embeddings, each embedding_size wide.
        self.embedding = tf.keras.layers.Embedding(n_vocab, embedding_size)
        # Return the output at every time step together with the final
        # hidden and cell states.
        self.lstm = tf.keras.layers.LSTM(
            lstm_size, return_sequences=True, return_state=True)
        # One unnormalised score per vocabulary entry.
        self.dense = tf.keras.layers.Dense(n_vocab)

    def call(self, x, prev_state):
        """Run the model on a batch of token ids.

        Args:
            x: Token ids; the original author notes the shape as
                (batch_size, seq_size).
            prev_state: LSTM state handed to the LSTM layer, e.g. the pair
                returned by ``zero_state``.

        Returns:
            ``(logits, preds, (state_h, state_c))``: the dense-layer output,
            its softmax, and the LSTM's returned hidden and cell states.
        """
        lstm_out, state_h, state_c = self.lstm(self.embedding(x), prev_state)
        logits = self.dense(lstm_out)
        return logits, tf.nn.softmax(logits), (state_h, state_c)

    def zero_state(self, batch_size):
        """Return an all-zero [hidden, cell] state pair for ``batch_size`` rows."""
        state_shape = [batch_size, self.lstm_size]
        return [tf.zeros(state_shape) for _ in range(2)]
  • 训练函数
@tf.function
def train_func(inputs, targets, model, state, loss_func, optimizer):
    """Run one optimisation step on a single batch and return its loss.

    Args:
        inputs: Batch of input token ids passed to ``model``.
        targets: Batch of target token ids passed to ``loss_func``.
        model: Callable taking ``(inputs, state)`` and returning a 3-tuple
            whose first element is the logits.
        state: Initial recurrent state forwarded to ``model``.
        loss_func: Callable taking ``(targets, logits)``.
        optimizer: Optimizer whose ``apply_gradients`` updates the model.

    Returns:
        The loss computed for this batch.
    """
    with tf.GradientTape() as tape:
        # Only the forward pass and the loss have to be recorded on the tape;
        # the predictions and the new state are not used here.
        logits, _, _ = model(inputs, state)
        loss = loss_func(targets, logits)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss
  • 训练过程
# 训练
def train():
    """Train the LSTM language model on ``args.train_file``.

    Reads the corpus with ``get_data_from_file``, builds a shuffled and
    batched ``tf.data`` pipeline, and runs ``args.num_epochs`` epochs of
    ``train_func`` steps. The loss is printed every 100 batches. Every 300
    batches ``predict`` is called and the model weights are saved under
    ``checkpoint_prefix``. After each epoch the collected per-batch losses
    are passed to ``plot_loss`` and the elapsed time is printed.
    """
    int_to_vocab, vocab_to_int, n_vocab, in_text, out_text = get_data_from_file(
        args.train_file, args.batch_size, args.seq_size)
    # in_text holds one training sequence per row, so this is the number of
    # full batches available per epoch.
    steps_per_epoch = in_text.shape[0] // args.batch_size

    dataset = tf.data.Dataset.from_tensor_slices((in_text, out_text)).shuffle(args.buffer_size)
    # The batch dimension must match the initial state created below with
    # model.zero_state(args.batch_size); a hard-coded size breaks training
    # for any other configured batch size.
    dataset = dataset.batch(args.batch_size, drop_remainder=True)

    model = RNNModule(n_vocab, args.embedding_size, args.lstm_size)

    optimizer = tf.keras.optimizers.Adam()
    # The model returns raw logits, hence from_logits=True.
    loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    for epoch in range(args.num_epochs):
        start = time.time()
        total_loss = []
        # Every step of the epoch starts from this all-zero state, because
        # train_func does not hand the updated state back.
        state = model.zero_state(args.batch_size)

        # inputs and targets each have shape (args.batch_size, args.seq_size).
        for batch, (inputs, targets) in enumerate(dataset.take(steps_per_epoch)):
            loss = train_func(inputs, targets, model, state, loss_func, optimizer)
            total_loss.append(loss)

            if batch % 100 == 0:
                print('Epoch: {}/{}'.format(epoch, args.num_epochs),
                      'Batch--> {}'.format(batch),
                      'Loss--> {}'.format(loss.numpy()))

            if batch % 300 == 0:
                # Periodic sample generation and checkpoint.
                predict(model, vocab_to_int, int_to_vocab, n_vocab)
                model.save_weights(checkpoint_prefix.format(epoch = epoch))

        plot_loss(total_loss)
        print ('Total time of this epoch is {}:'.format(time.time() - start))
posted @ 2021-01-07 16:21  Mr.zzz  阅读(44)  评论(0)  编辑  收藏  举报