A demo of text sentiment classification with an LSTM

import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
from torchtext import data, datasets  # legacy Field/BucketIterator API (torchtext < 0.9, or torchtext.legacy in 0.9-0.11)

class Args:
    max_vocab_size = 25000  # maximum vocabulary size
    n_labels = 5            # 5 fine-grained SST sentiment classes
    epochs = 5
    embedding_dim = 300     # must match the GloVe vectors loaded below
    hidden_dim = 512
    n_layers = 3
    batch_size = 64
    display_freq = 50       # print training stats every 50 iterations
    lr = 0.01

args = Args()
TEXT = data.Field()
LABEL = data.LabelField(dtype=torch.float)  # cast to long later for cross_entropy
# Stanford Sentiment Treebank with all 5 labels (fine_grained=True)
train_data, valid_data, test_data = datasets.SST.splits(
    TEXT, LABEL, fine_grained=True
)
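
After loading the splits, a quick look at the data confirms the sizes and the example format (a sketch; len() and vars() work on the legacy torchtext datasets):

print(len(train_data), len(valid_data), len(test_data))  # SST split sizes
print(vars(train_data.examples[0]))  # {'text': [tokens...], 'label': '...'}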
TEXT.build_vocab(
    train_data,
    max_size=args.max_vocab_size,
    vectors="glove.6B.300d",       # pretrained 300-d GloVe embeddings
    unk_init=torch.Tensor.normal_  # random-normal init for out-of-vocabulary words
)
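
It can also help to sanity-check what build_vocab produced. A minimal check, assuming the legacy Vocab attributes itos and vectors:

print(len(TEXT.vocab))           # at most max_vocab_size + 2 (<unk>, <pad>)
print(TEXT.vocab.itos[:10])      # special tokens plus the most frequent words
print(TEXT.vocab.vectors.shape)  # (len(TEXT.vocab), 300) GloVe matrix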

LABEL.build_vocab(train_data)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# BucketIterator groups examples of similar length to minimize padding
train_iter, valid_iter, test_iter = data.BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=args.batch_size,
    device=device
)
input_dim = len(TEXT.vocab)
output_dim = args.n_labels

class Model(nn.Module):
    def __init__(self, in_dim, emb_dim, hid_dim, out_dim, n_layer):
        super(Model, self).__init__()
        self.embedding = nn.Embedding(in_dim, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layer)
        self.linear = nn.Linear(hid_dim, out_dim)
        self.n_layer = n_layer
        self.hid_dim = hid_dim

    def forward(self, text):
        # text: (seq_len, batch_size) -> embedded: (seq_len, batch_size, emb_dim)
        embedded = self.embedding(text)
        # Zero initial hidden and cell states, one per layer
        h0 = embedded.new_zeros(
            self.n_layer, embedded.size(1), self.hid_dim
        )
        c0 = embedded.new_zeros(
            self.n_layer, embedded.size(1), self.hid_dim
        )
        output, (hn, cn) = self.rnn(embedded, (h0, c0))
        # Classify from the hidden state at the last time step
        return self.linear(output[-1])
    
    
model = Model(input_dim, args.embedding_dim, args.hidden_dim, output_dim, args.n_layers)
# Initialize the embedding layer with the pretrained GloVe vectors
pretrained_embeddings = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embeddings)
model.to(device)
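
Before training, a quick shape check with a dummy batch (hypothetical input, not SST data): BucketIterator yields text as (seq_len, batch_size), which matches the LSTM's default batch_first=False layout, and the model should return one score per class.

dummy = torch.randint(0, input_dim, (20, 4), device=device)  # 20 tokens, batch of 4
with torch.no_grad():
    print(model(dummy).shape)  # expected: torch.Size([4, 5])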
optimizer = optim.Adam(model.parameters(), lr=args.lr)
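
An optional variant (not part of the original demo): keep the GloVe vectors frozen so only the LSTM and the classifier are trained. If you do this, rebuild the optimizer afterwards so it only receives trainable parameters.

# model.embedding.weight.requires_grad = False
# optimizer = optim.Adam(
#     filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr
# )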
def train(epoch, model, iterator, optimizer):
    loss_list = []
    acc_list = []
    model.train()

    for i, batch in tqdm(enumerate(iterator), total=len(iterator)):
        optimizer.zero_grad()
        text = batch.text.to(device)
        label = batch.label.long().to(device)
        predictions = model(text)
        loss = F.cross_entropy(predictions, label)
        loss.backward()
        optimizer.step()

        # Accuracy of the current batch
        acc = (predictions.max(1)[1] == label).float().mean()
        loss_list.append(loss.item())
        acc_list.append(acc.item())

        if i % args.display_freq == 0:
            print("Epoch %02d, Iter [%03d/%03d], "
                  "train loss = %.4f, train acc = %.4f" %
                  (epoch, i, len(iterator), np.mean(loss_list), np.mean(acc_list)))
            loss_list.clear()
            acc_list.clear()

def evaluate(epoch, model, iterator):
    val_loss = 0
    val_acc = 0
    model.eval()
    with torch.no_grad():
        for batch in iterator:
            text = batch.text.to(device)
            label = batch.label.long().to(device)
            predictions = model(text)
            loss = F.cross_entropy(predictions, label)
            acc = (predictions.max(1)[1] == label).float().mean()
            val_loss += loss.item()
            val_acc += acc.item()

    val_loss = val_loss / len(iterator)
    val_acc = val_acc / len(iterator)
    print('...Epoch %02d, val loss = %.4f, val acc = %.4f' % (
        epoch, val_loss, val_acc))
    return val_loss, val_acc

best_acc = 0
best_epoch = -1
for epoch in range(1, args.epochs + 1):
    train(epoch, model, train_iter, optimizer)
    valid_loss, valid_acc = evaluate(epoch, model, valid_iter)
    # Keep the checkpoint with the best validation accuracy
    if valid_acc > best_acc:
        best_acc = valid_acc
        best_epoch = epoch
        torch.save(model.state_dict(), 'best-model.pth')

print('Test best model @ Epoch %02d' % best_epoch)
model.load_state_dict(torch.load('best-model.pth'))
test_loss, test_acc = evaluate(best_epoch, model, test_iter)
print('Finally, test loss = %.4f, test acc = %.4f' % (test_loss, test_acc))
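
Finally, to try the trained model on a raw sentence, one option is a small helper (a sketch assuming the legacy Field.preprocess/Field.process API; the example sentence is just an illustration):

def predict_sentiment(model, sentence):
    model.eval()
    tokens = TEXT.preprocess(sentence)          # tokenize with the Field's pipeline
    tensor = TEXT.process([tokens]).to(device)  # numericalize + pad, shape (seq_len, 1)
    with torch.no_grad():
        probs = F.softmax(model(tensor), dim=1)
    return LABEL.vocab.itos[probs.argmax(1).item()]

print(predict_sentiment(model, "This film is wonderful"))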