NLP文本分类学习笔记4.1:基于RCNN的文本分类

循环卷积神经网络RCNN#

1、CNN与RNN缺点

  • CNN通过窗口获取特征,窗口尺寸不合适就会捕获不到好特征,窗口也不能太大,这样就捕获不到全局的特征,所以它类似于传统的N-gram
  • RNN使用最后的输出作为特征,使得序列后的词会比前面的词更加重要,从而影响捕获准确的特征

2、CNN与RNN优点

  • CNN使用池化,能够捕获重要的特征
  • RNN处理序列有优势,能够捕获全局特征

所以Recurrent Convolutional Neural Networks for Text Classification这篇论文将两者优点结合起来,提出下图模型RCNN

  • 图中虚线圈出部分,实际上是一个双向循环网络(之后用双向LSTM实现,尽管论文中并不是,但也类似)
  • 之后将所有时刻的输出和输入的词向量拼接起来(即图中的y3(2)y4(2)等,图中并未表示完整),论文中拼接的公式为,其中cl(wi)cr(wi)为双向LSTM的两个输出,e(wi)为词向量

xi=[cl(wi);e(wi);cr(wi)]

  • 然后经过激活函数tanh(图中未画出,实现时采用relu)
  • 之后对每一维进行最大池化,组成新的特征向量
  • 最后连接全连接层实现分类

pytorch实现基于RCNN的文本分类#

对于10分类任务,在测试集分类准确率为87.27%,关于网络结构代码如下,更多代码详细介绍见NLP文本分类学习笔记0

Copy
import json import pickle import torch import torch.nn.functional as F import torch.nn as nn import numpy as np class Config(object): def __init__(self, embedding_pre): self.embedding_path = 'data/embedding.npz' self.embedding_model_path = "mymodel/word2vec.model" self.train_path = 'data/train.df' # 训练集 self.dev_path = 'data/valid.df' # 验证集 self.test_path = 'data/test.df' # 测试集 self.class_path = 'data/class.json' # 类别名单 self.vocab_path = 'data/vocab.pkl' # 词表 self.save_path ='mymodel/rcnn.pth' # 模型训练结果 self.embedding_pretrained = torch.tensor(np.load(self.embedding_path, allow_pickle=True)["embeddings"].astype( 'float32')) if embedding_pre == True else None # 预训练词向量 self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 self.dropout = 0.5 # 随机失活 self.num_classes = len(json.load(open(self.class_path, encoding='utf-8'))) # 类别数 self.n_vocab = 0 # 词表大小,在运行时赋值 self.epochs = 10 # epoch数 self.batch_size = 128 # mini-batch大小 self.maxlen = 32 # 每句话处理成的长度(短填长切) self.learning_rate = 1e-3 # 学习率 self.embed_size = self.embedding_pretrained.size(1) \ if self.embedding_pretrained is not None else 200 # 字向量维度 self.hidden_size = 128 # lstm隐藏层 self.num_layers = 1 # lstm层数 class Model(nn.Module): def __init__(self, config): super(Model, self).__init__() if config.embedding_pretrained is not None: self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False) else: vocab = pickle.load(open(config.vocab_path, 'rb')) config.n_vocab=len(vocab.dict) self.embedding = nn.Embedding(config.n_vocab, config.embed_size, padding_idx=config.n_vocab - 1) self.lstm = nn.LSTM(config.embed_size, config.hidden_size, config.num_layers, bidirectional=True, batch_first=True, dropout=config.dropout) self.maxpool = nn.MaxPool1d(config.maxlen) self.fc = nn.Linear(config.hidden_size * 2 + config.embed_size, config.num_classes) def forward(self, x): embed = self.embedding(x) out, _ = self.lstm(embed) out = torch.cat((embed, out), 2) out = F.relu(out) out = out.permute(0, 2, 1) out = self.maxpool(out).squeeze() out = self.fc(out) return out
posted @   启林O_o  阅读(132)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· PowerShell开发游戏 · 打蜜蜂
· 在鹅厂做java开发是什么体验
· 百万级群聊的设计实践
· WPF到Web的无缝过渡:英雄联盟客户端的OpenSilver迁移实战
· 永远不要相信用户的输入:从 SQL 注入攻防看输入验证的重要性
点击右上角即可分享
微信分享提示
CONTENTS