NLP文本分类学习笔记4:基于RNN的文本分类

循环神经网络RNN#

RNN拥有一个环路,数据可以通过这个环路不断循环,因此拥有了记忆性,所以更针对序列数据。序列数据上一刻的输出和下一刻的数据一起作为新的输出,结构如下图所示,XtHt为t时刻的输入和输出,输入的序列数据为X1X2X3Xt

但是RNN试图学到序列数据中所有的长时间的依赖关系,这导致网络很深。又因为采用反向传播算法更新网络的参数时,使用链式法则求导,所以会出现梯度接近零,无法更新网络的参数的情况。如果网络参数初始值过大,又可能导致梯度指数级增长,使网络参数每次都大幅更新。这就是梯度消失和梯度爆炸的问题。因此一般采用长度记忆神经网络LSTM

长短记忆神经网络LSTM#

LSTM同样采用RNN的结构,只不过通过记忆门,进行选择性记忆,并且使用对应元素的乘,而不是矩阵乘,这些都缓解梯度消失和爆炸。
记忆门就是0和1的序列,1对应的数据就会被记住(保留),0对应的数据就会被忘记(忽略)。
如图所示,LSTM采用了三个门:忘记门Ft,输入门It,输出门Ot

  • 忘记门决定忽略掉保留的记忆中对这次没有作用的部分
  • 输入门决定忽略输入的信息中对这次没有作用的部分
  • 输出门决定忽略这次的结果中对下次没有作用的部分
  • Ct是LSTM中的记忆单元,保留有t时刻的记忆信息
  • C¯t是候选记忆单元
  • σ 为sigma激活函数
  • tanh为双曲正切激活函数
  • 图中的✖是对应元素的乘(点乘),不是矩阵乘

It=σ(XtWxi+Ht1Whi+bi)

Ft=σ(XtWxf+Ht1Whf+bf)

Ot=σ(XtWxo+Ht1Who+bo)

C¯t=tanh(XtWxc+Ht1Whc+bc)

Ct=FtCt1+ItC¯t

Ht=Ottanh(Ct

多层LSTM#

通过叠加多层LSTM对提升模型有一定效果,将下一层的输出作为上一层的输入,最后每一层都会有一个输出,2层LSTM如下图所示

双向LSTM#

对于一个序列不仅通过过去看未来,还通过未来看过去,从正反两个方向学习序列的特征,每个时刻的正反两个方向的输出拼接运算后作为这一时刻的总输出OT,如下图所示

门控循环单元GRU#

与LSTM相比,有一定的简化,LSTM有三个门:遗忘门,输入门,输出门,GRU只有两个门:重置门Rt和更新门Zt
对下列式子的主观理解为:

  • 重置门决定遗忘多少过去的信息Ht1
  • 候选记忆信息H¯t由新的输入Xt和经过重置门决定的过去信息Ht1得到
  • 更新门决定使用多少过去记忆信息Ht1和候选信息H¯t,来得到这一时刻的输出Ht

Rt=σ(XtWxr+Ht1Whr+br)

Zt=σ(XtWxz+Ht1Whz+bz)

H¯t=tanh(XtWxh+(RtHt1)Whh+bh)

Ht=ZtHt1+(1Zt)H¯t

基于LSTM的文本分类#

如下图所示,是一个两层的双向的LSTM的简化示意,将文本分词等处理后后,经过embedding层,转化为可学习的序列输入,经过LSTM处理后,只关注最后一层的最后输出Ht即可(也可以将其它层的最后输出Ht等通过加权等方式进行使用),最后将输出连接到全连接层后用于分类,其中在每层LSTM(除最后一层)之间加入dropout,可以防止模型的过拟合

主要思想#

文本数据是时间序列数据,前后之间相互联系,如“我的心情很(),因为我的玩具坏了”这句话,通过前面可以知道括号中可以填“好”“坏”这类的词,而不可能是“我的心情很(篮球)”之类,而通过后面,又可以知道填“坏”的可能性要大一些。通过上述模型能够捕获句子前后联系的特征。

pytorch实现基于LSTM的文本分类#

模型结构参数如下,对于10分类的任务达到了85.36%的准确率。也使用GRU进行尝试准确率为85.81%,与LSTM相差不大,进行关于代码更详细的说明参考:NLP文本分类学习笔记0:数据预处理及训练说明

  • 将nn.LSTM的参数batch_first设置为True时,其输入输出的第一维表示为batch大小,bidirectional设置为True表示为双向LSTM
  • 每个输入的批次为【128,32】,128为批次大小,32为句子填充截断后统一长度
  • 经过词嵌入层,数据变为【128,32,200】,200为word2vec预训练词向量维度
  • 之后输入到LSTM层,其中隐藏层大小设置为128,层数设置为2,即每个输出为128,最后的输出为【128,32,2*128】,128为batch大小,32为句子长度,2表示双向LSTM,128为输出维度,在pytorch中双向LSTM的前后向输出是拼在了一起(在文档中还提到了packed sequence这样的输出格式文档
  • 最后只将最后时刻的输出,输入到全连接层进行分类
Copy
import json import pickle import torch import torch.nn as nn import numpy as np class Config(object): """配置参数""" def __init__(self, embedding_pre): self.embedding_path = 'data/embedding.npz' self.embedding_model_path = "mymodel/word2vec.model" self.train_path = 'data/train.df' # 训练集 self.dev_path = 'data/valid.df' # 验证集 self.test_path = 'data/test.df' # 测试集 self.class_path = 'data/class.json' # 类别名单 self.vocab_path = 'data/vocab.pkl' # 词表 self.save_path ='mymodel/rnn.pth' # 模型训练结果 self.embedding_pretrained = torch.tensor(np.load(self.embedding_path, allow_pickle=True)["embeddings"].astype( 'float32')) if embedding_pre == True else None # 预训练词向量 self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 self.dropout = 0.5 # 随机失活 self.num_classes = len(json.load(open(self.class_path, encoding='utf-8'))) # 类别数 self.n_vocab = 0 # 词表大小,在运行时赋值 self.epochs = 10 # epoch数 self.batch_size = 128 # mini-batch大小 self.maxlen = 32 # 每句话处理成的长度(短填长切) self.learning_rate = 1e-3 # 学习率 self.embed_size = self.embedding_pretrained.size(1) \ if self.embedding_pretrained is not None else 200 # 字向量维度 self.hidden_size = 128 # lstm隐藏层 self.num_layers = 2 # lstm层数 class Model(nn.Module): def __init__(self, config): super(Model, self).__init__() if config.embedding_pretrained is not None: self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False) else: vocab = pickle.load(open(config.vocab_path, 'rb')) config.n_vocab=len(vocab.dict) self.embedding = nn.Embedding(config.n_vocab, config.embed_size, padding_idx=config.n_vocab - 1) self.lstm = nn.LSTM(config.embed_size, config.hidden_size, config.num_layers, bidirectional=True, batch_first=True, dropout=config.dropout) #使用GRU #self.lstm = nn.GRU(config.embed_size, config.hidden_size, config.num_layers,bidirectional=True, batch_first=True, dropout=config.dropout) self.fc = nn.Linear(config.hidden_size * 2, config.num_classes) def forward(self, x): out = self.embedding(x) out, i = self.lstm(out) out = self.fc(out[:, -1, :]) # 句子最后时刻的 hidden state return out
posted @   启林O_o  阅读(406)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· PowerShell开发游戏 · 打蜜蜂
· 在鹅厂做java开发是什么体验
· 百万级群聊的设计实践
· WPF到Web的无缝过渡:英雄联盟客户端的OpenSilver迁移实战
· 永远不要相信用户的输入:从 SQL 注入攻防看输入验证的重要性
点击右上角即可分享
微信分享提示
CONTENTS