NLP文本分类学习笔记3.1:基于DPCNN的文本分类
深度卷积网络DPCNN#
在NLP文本分类学习笔记3中介绍了CNN的结构和如何用于文本分类,但是也存在一些问题(在之后将看到)。
在这篇论文Deep Pyramid Convolutional Neural Networks for Text Categorization中提出了DPCNN模型,其结构如下图所示
其中主要使用的部分方法如下:
- 参考:
https://blog.csdn.net/guleileo/article/details/87035446
https://zhuanlan.zhihu.com/p/35457093 - 1/2池化:为了捕获序列更长距离间的联系,使用1/2池化(大小为3,步长为2的池化层),使得每次序列长度减少一半
- 残差连接:为了解决DPCNN,由于网络参数初始权重小,导致训练启动慢,以及梯度爆炸或消失的问题,使用残差连接,将输出直接连接在卷积的输出后
DPCNN用于文本分类的主要思想#
1、CNN用于文本分类,实际上捕获的是N-gram的特征,而且是2-gram,3-gram这类特征,这其实使得它与传统的提取N-gram特征的文本分类没有什么区别。并且CNN捕捉不到文本全局的联系。
2、DPCNN使用1/2池化,使得每次都在N-gram特征上再进行N-gram,捕获到更多的特征种类。
pytorch实现基于DPCNN的文本分类#
在10分类的测试集上准确率为85.61%,这里仅介绍实现的结构,更多代码的详细说明见NLP文本分类学习笔记0
- 词向量输入到region_embedding层,在这一层经过一次卷积,卷积核为250,卷积核大小为3(函数自动生成卷积核大小为3*通道数,这里也就是词向量维数embed_size),步长使用默认的1,序列左右两边都进行1次填充。最后再归一化,relu函数激活
- 之后经过两轮卷积,每轮卷积开始时都使用归一化,relu函数激活处理,卷积核参数如上,最后输出与region embedding的输出进行残差连接
- 然后进行不断循环,直至序列长度小于等于2,每次循环开始时,经过1/2池化,池化核大小为3,步长为2,序列左右两侧都进行填充,将池化后的输出也经过两轮卷积,每轮卷积开始时都使用归一化,relu函数激活处理,卷积核参数也如上,最后输出与池化后的输出进行残差连接
- 最后再进行一次池化,然后连接到全连接层,进行分类。
关于nn.Conv1d和nn.Conv2d#
下述代码与之前笔记中代码不一样的是使用一维卷积nn.Conv1d。其实对于文本任务使用一维卷积和二维卷积都可,使用二维卷积只不过把卷积核的一维设置为词向量维度即可,并且还要使用unsqueeze加上一个维度。
因为在pytorch的文档中
nn.Conv2d输入为(批次大小batch_size,通道数,文本序列长度,词向量维数)
nn.Conv1d为(批次大小batch_size,通道数,文本序列长度)这里的通道数指词向量维数
而pytorch一维卷积的过程如下,以前我一直理解错误。原来一维与二维的区别不是指卷积核维度的区别,是指卷积核移动的维度,一维的卷积核是沿一个维度移动卷积的,这个维度就是输入的最后一个维度,在这里就是文本的长度。
import json
import pickle
import torch
import torch.nn as nn
import numpy as np
class Config(object):
def __init__(self, embedding_pre):
self.embedding_path = 'data/embedding.npz'
self.embedding_model_path = "mymodel/word2vec.model"
self.train_path = 'data/train.df' # 训练集
self.dev_path = 'data/valid.df' # 验证集
self.test_path = 'data/test.df' # 测试集
self.class_path = 'data/class.json' # 类别名单
self.vocab_path = 'data/vocab.pkl' # 词表
self.save_path ='mymodel/dpcnn.pth' # 模型训练结果
self.embedding_pretrained = torch.tensor(np.load(self.embedding_path, allow_pickle=True)["embeddings"].astype(
'float32')) if embedding_pre == True else None # 预训练词向量
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备
self.num_classes = len(json.load(open(self.class_path, encoding='utf-8'))) # 类别数
self.n_vocab = 0 # 词表大小,在运行时赋值
self.epochs = 10 # epoch数
self.batch_size = 128 # mini-batch大小
self.maxlen = 32 # 每句话处理成的长度(短填长切)
self.learning_rate = 1e-3 # 学习率
self.embed_size = self.embedding_pretrained.size(1) \
if self.embedding_pretrained is not None else 200 # 字向量维度
self.num_filters = 250 # 卷积核数量(channels数)
class Model(nn.Module):
def __init__(self, config):
super(Model, self).__init__()
if config.embedding_pretrained is not None:
self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)
else:
vocab = pickle.load(open(config.vocab_path, 'rb'))
config.n_vocab=len(vocab.dict)
self.embedding = nn.Embedding(config.n_vocab, config.embed_size, padding_idx=config.n_vocab - 1)
self.batchNorm=nn.BatchNorm1d(config.num_filters)
self.relu=nn.ReLU()
self.max_pool = nn.MaxPool1d(kernel_size=3, stride=2,padding=1)
self.conv2 = nn.Conv1d(config.num_filters, config.num_filters,3,padding=1)
self.fc=nn.Linear(config.num_filters,config.num_classes)
self.Region_embedding=nn.Sequential(
nn.Conv1d(config.embed_size, config.num_filters, 3, padding=1),
nn.BatchNorm1d(config.num_filters),
nn.ReLU()
)
self.conv1=nn.Sequential(
nn.BatchNorm1d(config.num_filters),
nn.ReLU(),
nn.Conv1d(config.num_filters, config.num_filters, 3, padding=1),
nn.BatchNorm1d(config.num_filters),
nn.ReLU(),
nn.Conv1d(config.num_filters, config.num_filters, 3, padding=1),
)
def repeat_block(self,input):
x_pool=self.max_pool(input)
x=self.batchNorm(x_pool)
x = self.relu(x)
x=self.conv2(x)
x=self.batchNorm(x)
x=self.relu(x)
out=self.conv2(x)
return out+x_pool
def forward(self,input):
x=self.embedding(input)
x = x.permute(0, 2, 1)
region_embed=self.Region_embedding(x)
x=self.conv1(region_embed)
x=x+region_embed
while x.size()[2] > 2:
x = self.repeat_block(x)
x=self.max_pool(x)
x = x.squeeze()
x = self.fc(x)
return x
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· PowerShell开发游戏 · 打蜜蜂
· 在鹅厂做java开发是什么体验
· 百万级群聊的设计实践
· WPF到Web的无缝过渡:英雄联盟客户端的OpenSilver迁移实战
· 永远不要相信用户的输入:从 SQL 注入攻防看输入验证的重要性