NLP文本分类学习笔记7.1:基于ERNIE的文本分类

ERNIE#

相关链接:ERNIE官方使用介绍ERNIE项目地址
基于transformer的encoder,主要思想是将文本中已有的知识融入到模型训练中,因此采用实体mask的方式(实体指人名,地名等词)

预训练#

模型结构图如下所示

文本中已有的知识主要有人名,地名等实体,这些词本来就蕴含一些信息,而采用bert那种mask方式,如通过哈和滨预测中间的尔,显然多此一举,且没有关注哈尔滨这个词本来含有的信息。

  • ERNIE使用多个T-encoder,还是像bert一样输入token embedding,训练得到文本序列中的信息。其中T-encoder为transformer的encoder
  • 再使用多个K-encoder,将文本中的实体embedding输入与T-encoder输入“拼接”在一起,最后输出。
  • 实体embedding采用知识嵌入模型TransE得到(TransE这组要思想是构造实体向量和关系向量,不断使两个实体向量相加接近关系向量),然后实体embedding通过多头注意力机制提取信息
  • T-encoder的输出w再经过多头注意力机制后与实体提取的信息e“拼接”,经过information fusion层,最后得到输出
  • “拼接”方式采用下图中公式一,w经过全连接层,e经过全连接层,两者相加(实体要拼接到最开始的那个token上,如实体哈尔滨要拼到哈上),通过GELU激活函数,得到h
  • 在information fusion层,再用h,分别通过全连接层得到新的w和e

  • 如果这段文本没有实体信息,就采用下述方法

  • 预训练的任务是用5%时间,随机替换实体,让模型预测正确的实体,15%的时间,随机mask实体知识与token拼接的信息,用模型去预测这个信息,剩下的时间不变

微调#

以文本分类为例:与bert时相同,先对输入的句子按字进行切分,最后将[cls]对应的输出用作分类

pytorch实现基于ERNIE的文本分类#

使用Hugging Face的预训练模型nghuyong/ernie-1.0 ,在10分类任务上准确率为76.98%,更多代码详情见NLP文本分类学习笔记0
结构代码 myERNIE.py

Copy
import torch import torch.nn as nn from transformers import AutoTokenizer, AutoModel class Config(object): def __init__(self): self.pre_bert_path="nghuyong/ernie-1.0" self.train_path = 'data/dataset_train.csv' # 训练集 self.dev_path = 'data/dataset_valid.csv' # 验证集 self.test_path = 'data/test.csv' # 测试集 self.class_path = 'data/class.json' # 类别名单 self.save_path ='mymodel/ernie.pth' # 模型训练结果 self.num_classes=10 self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 self.epochs = 10 # epoch数 self.batch_size = 128 # mini-batch大小 self.maxlen = 32 # 每句话处理成的长度(短填长切) self.learning_rate = 5e-4 # 学习率 self.hidden_size=768 self.tokenizer = AutoTokenizer.from_pretrained(self.pre_bert_path) class Model(nn.Module): def __init__(self, config): super(Model, self).__init__() self.ernie=AutoModel.from_pretrained(config.pre_bert_path) #设置不更新预训练模型的参数 for param in self.ernie.parameters(): param.requires_grad = False self.fc = nn.Linear(config.hidden_size, config.num_classes) def forward(self, input): out=self.ernie(input_ids =input['input_ids'],attention_mask=input['attention_mask'],token_type_ids=input['token_type_ids']) #只取最后一层CLS对应的输出 out = self.fc(out.pooler_output) return out

运行代码run.py

Copy
import json from mymodel import myBert,myAlbertl,myERNIE import mydataset import torch import pandas as pd from torch import nn,optim from torch.utils.data import DataLoader config=myERNIE.Config() label_dict=json.load(open(config.class_path,'r',encoding='utf-8')) # 加载训练,验证,测试数据集 train_df = pd.read_csv(config.train_path) #这里将标签转化为数字 train_ds=mydataset.GetLoader(train_df['review'],[label_dict[i] for i in train_df['cat']]) train_dl=DataLoader(train_ds,batch_size=config.batch_size,shuffle=True) valid_df = pd.read_csv(config.dev_path) valid_ds=mydataset.GetLoader(valid_df['review'],[label_dict[i] for i in valid_df['cat']]) valid_dl=DataLoader(valid_ds,batch_size=config.batch_size,shuffle=True) test_df = pd.read_csv(config.test_path) test_ds=mydataset.GetLoader(test_df['review'],[label_dict[i] for i in test_df['cat']]) test_dl=DataLoader(test_ds,batch_size=config.batch_size,shuffle=True) #计算准确率 def accuracys(pre,label): pre=torch.max(pre.data,1)[1] accuracy=pre.eq(label.data.view_as(pre)).sum() return accuracy,len(label) #导入网络结构 model=myERNIE.Model(config).to(config.device) #训练 criterion=nn.CrossEntropyLoss() optimizer=optim.Adam(model.parameters(),lr=config.learning_rate) best_loss=float('inf') for epoch in range(config.epochs): train_acc = [] for batch_idx,(data,target)in enumerate(train_dl): inputs = config.tokenizer(list(data),truncation=True, return_tensors="pt",padding=True,max_length=config.maxlen) model.train() out = model(inputs) loss=criterion(out,target) optimizer.zero_grad() loss.backward() optimizer.step() train_acc.append(accuracys(out,target)) train_r = (sum(tup[0] for tup in train_acc), sum(tup[1] for tup in train_acc)) print('当前epoch:{}\t[{}/{}]{:.0f}%\t损失:{:.6f}\t训练集准确率:{:.2f}%\t'.format( epoch, batch_idx, len(train_dl), 100. * batch_idx / len(train_dl), loss.data, 100. * train_r[0].numpy() / train_r[1] )) #每100批次进行一次验证 if batch_idx%100==0 and batch_idx!=0: model.eval() val_acc=[] loss_total=0 with torch.no_grad(): for (data,target) in valid_dl: inputs = config.tokenizer(list(data), truncation=True, return_tensors="pt", padding=True, max_length=config.maxlen) out = model(inputs) loss_total = criterion(out, target).data+loss_total val_acc.append(accuracys(out,target)) val_r = (sum(tup[0] for tup in val_acc), sum(tup[1] for tup in val_acc)) print('损失:{:.6f}\t验证集准确率:{:.2f}%\t'.format(loss_total/len(valid_dl),100. * val_r[0].numpy() / val_r[1])) #如果验证损失低于最好损失,则保存模型 if loss_total < best_loss: best_loss = loss_total torch.save(model.state_dict(), config.save_path) #测试 model.load_state_dict(torch.load(config.save_path)) model.eval() test_acc=[] with torch.no_grad(): for (data, target) in test_dl: inputs = config.tokenizer(list(data),truncation=True, return_tensors="pt",padding=True,max_length=config.maxlen) out = model(inputs) test_acc.append(accuracys(out, target)) test_r = (sum(tup[0] for tup in test_acc), sum(tup[1] for tup in test_acc)) print('测试集准确率:{:.2f}%\t'.format(100. * test_r[0].numpy() / test_r[1]))
posted @   启林O_o  阅读(1474)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· PowerShell开发游戏 · 打蜜蜂
· 在鹅厂做java开发是什么体验
· 百万级群聊的设计实践
· WPF到Web的无缝过渡:英雄联盟客户端的OpenSilver迁移实战
· 永远不要相信用户的输入:从 SQL 注入攻防看输入验证的重要性
点击右上角即可分享
微信分享提示
CONTENTS