NLP (12): Computing Text Similarity with word2vec + Siamese BiLSTM
1. Model: my_bilstm.py
import torch
from torch import nn


class SiameseLSTM(nn.Module):
    def __init__(self, input_size):
        super(SiameseLSTM, self).__init__()
        # Shared bidirectional encoder: hidden_size=64 per direction,
        # so each sentence is encoded into a 128-dim vector.
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=64, num_layers=1,
                            batch_first=True, bidirectional=True)
        # The two 128-dim sentence vectors are concatenated (256) and
        # mapped to a single similarity logit.
        self.fc = nn.Sequential(
            nn.Linear(256, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 1),
        )

    def forward(self, data1, data2):
        # The same LSTM (shared weights) encodes both inputs.
        out1, (h1, c1) = self.lstm(data1)
        out2, (h2, c2) = self.lstm(data2)
        # Take the output at the last time step as the sentence vector.
        pre1 = out1[:, -1, :]
        pre2 = out2[:, -1, :]
        pre = torch.cat([pre1, pre2], dim=1)
        out = self.fc(pre)
        return out


if __name__ == '__main__':
    d1 = torch.rand(2, 16, 128)  # (batch, seq_len, embedding_dim)
    d2 = torch.rand(2, 16, 128)
    model = SiameseLSTM(128)
    model(d1, d2)
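The model returns a raw logit rather than a probability; the training script below pairs it with nn.BCEWithLogitsLoss, which applies the sigmoid internally. A minimal sketch of turning the output into similarity scores at inference time (hypothetical tensors, not part of the original files):

import torch
from my_bilstm import SiameseLSTM

model = SiameseLSTM(128)
model.eval()
with torch.no_grad():
    d1 = torch.rand(2, 16, 128)               # (batch, seq_len, embedding_dim)
    d2 = torch.rand(2, 16, 128)
    logits = model(d1, d2)                    # shape: (2, 1)
    probs = torch.sigmoid(logits.squeeze(1))  # similarity scores in [0, 1]
print(probs)

One caveat about out1[:, -1, :]: at the last time step the backward direction has only processed the final token, and since the embedding step below zero-pads sentences at the end, that step is often padding. Concatenating the final hidden states instead, torch.cat([h1[0], h1[1]], dim=1), is a common alternative.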
2. Dataset: my_dataset.py
import torch.utils.data as data


class MyDataset(data.Dataset):
    def __init__(self, texta, textb, label):
        self.texta = texta
        self.textb = textb
        self.label = label

    def __getitem__(self, item):
        texta = self.texta[item]
        textb = self.textb[item]
        label = self.label[item]
        return texta, textb, label

    def __len__(self):
        return len(self.texta)
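The dataset returns raw strings and leaves embedding to the training loop, so the DataLoader's default collate simply groups each text field into a batch of strings while stacking the labels. A hypothetical smoke test (toy data, not from the original):

import torch
from torch.utils.data import DataLoader
from my_dataset import MyDataset

# hypothetical toy pairs
texta = ["浙江杭州富阳区", "银湖街道新常村"]
textb = ["浙江富阳区", "富阳区银湖街"]
label = torch.LongTensor([1, 0])

loader = DataLoader(MyDataset(texta, textb, label), batch_size=2)
for a, b, y in loader:
    print(a)  # a batch of raw strings
    print(y)  # tensor([1, 0])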
3. Word embedding
my_word2vec.py
from gensim.models.fasttext import FastText
import torch
import numpy as np
import os


class WordEmbedding(object):
    def __init__(self):
        parent_path = os.path.split(os.path.realpath(__file__))[0]
        self.root = parent_path[:parent_path.find("models")]  # E:\personas\semantics\
        self.word_fasttext = os.path.join(self.root, "checkpoints", "word2vec", "word_fasttext.model")
        self.char_fasttext = os.path.join(self.root, "checkpoints", "word2vec", "char_fasttext.model")
        self.model = FastText.load(self.char_fasttext)

    def sentenceTupleToEmbedding(self, data1, data2):
        # Pad both batches to the length of the longest sentence;
        # list(str(...)) tokenizes at the character level.
        aCutListMaxLen = max(len(list(str(sentence_a))) for sentence_a in data1)
        bCutListMaxLen = max(len(list(str(sentence_b))) for sentence_b in data2)
        seq_len = max(aCutListMaxLen, bCutListMaxLen)
        a = self.sequence_vec(data1, seq_len)  # (batch_size, seq_len, embedding)
        b = self.sequence_vec(data2, seq_len)
        return torch.FloatTensor(a), torch.FloatTensor(b)

    def sequence_vec(self, data, seq_len):
        data_a_vec = []
        for sequence_a in data:
            sequence_vec = []  # seq_len * 128
            for word_a in list(str(sequence_a)):
                if word_a in self.model.wv:
                    sequence_vec.append(self.model.wv[word_a])
            # reshape keeps the array 2-D even if every character is OOV,
            # so the zero-padding below can always be stacked
            sequence_vec = np.array(sequence_vec).reshape(-1, 128)
            add = np.zeros((seq_len - sequence_vec.shape[0], 128))
            sequence_vec = np.vstack((sequence_vec, add))
            data_a_vec.append(sequence_vec)
        return np.array(data_a_vec)


if __name__ == '__main__':
    word = WordEmbedding()
    data1 = ("浙江杭州富阳区银湖街黄先生的外卖", "浙江杭州富阳区银湖街黄先生的外卖")
    data2 = ("富阳区浙江富阳区银湖街道新常村", "浙江杭州富阳区银湖街黄先生的外卖")
    a, b = word.sentenceTupleToEmbedding(data1, data2)
    print(a.shape)
    print(b)
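This class assumes a character-level FastText model has already been trained and saved at checkpoints/word2vec/char_fasttext.model. A hedged sketch of how such a model could be produced with gensim (assuming gensim 4.x, where the dimension argument is vector_size; the corpus file addresses.txt and its format are hypothetical):

from gensim.models.fasttext import FastText

# hypothetical corpus: one address per line, split into characters
with open("addresses.txt", encoding="utf-8") as f:
    sentences = [list(line.strip()) for line in f if line.strip()]

model = FastText(vector_size=128, window=5, min_count=1)  # 128 matches the LSTM input_size
model.build_vocab(corpus_iterable=sentences)
model.train(corpus_iterable=sentences, total_examples=len(sentences), epochs=10)
model.save("char_fasttext.model")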
4. Runner
run__bilstm.py
import torch
import os
from torch.utils.data import DataLoader
from my_dataset import MyDataset
import pandas as pd
import numpy as np
from my_bilstm import SiameseLSTM
import torch.nn as nn
from my_word2vec import WordEmbedding


class RunBiLSTM():
    def __init__(self):
        self.learning_rate = 0.001
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        parent_path = os.path.split(os.path.realpath(__file__))[0]
        self.root = parent_path[:parent_path.find("models")]  # E:\personas\semantics\
        self.train_path = os.path.join(self.root, "datas", "bert_data", "sim_data", "train.csv")
        self.val_path = os.path.join(self.root, "datas", "bert_data", "sim_data", "val.csv")
        self.test_path = os.path.join(self.root, "datas", "bert_data", "sim_data", "test.csv")
        self.batch_size = 256
        self.epoch = 50
        self.criterion = nn.BCEWithLogitsLoss().to(self.device)
        self.word = WordEmbedding()
        self.check_point = os.path.join(self.root, "checkpoints", "char_bilstm", "char_bilstm.pth")

    def get_loader(self, path):
        data = pd.read_csv(path, sep="\t")
        d1, d2, y = data["s1"], data["s2"], list(data["y"])
        dataset = MyDataset(d1, d2, torch.LongTensor(y))
        data_iter = DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
        return data_iter

    def binary_acc(self, preds, y):
        # The model outputs raw logits, so apply sigmoid before rounding.
        preds = torch.round(torch.sigmoid(preds))
        correct = torch.eq(preds, y).float()
        acc = correct.sum() / len(correct)
        return acc

    def train(self, mynet, train_iter, optimizer, criterion, epoch, device):
        avg_acc = []
        avg_loss = []
        mynet.train()
        for batch_id, (data1, data2, label) in enumerate(train_iter):
            try:
                a, b = self.word.sentenceTupleToEmbedding(data1, data2)
            except Exception as e:
                print("embedding error, skipping batch:", e)
                continue  # without this, a and b would be undefined below
            a, b, label = a.to(device), b.to(device), label.to(device)
            distance = mynet(a, b)
            distance = distance.squeeze(1)
            loss = criterion(distance, label.float())
            acc = self.binary_acc(distance, label.float()).item()
            avg_acc.append(acc)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if batch_id % 100 == 0:
                print("epoch:", epoch, "batch:", batch_id, "train loss:", loss.item(), "acc:", acc)
            avg_loss.append(loss.item())
        avg_acc = np.array(avg_acc).mean()
        avg_loss = np.array(avg_loss).mean()
        print('train acc:', avg_acc)
        print("train loss", avg_loss)

    def eval(self, mynet, test_iter, criteon, epoch, device):
        mynet.eval()
        avg_acc = []
        avg_loss = []
        with torch.no_grad():
            for batch_id, (data1, data2, label) in enumerate(test_iter):
                try:
                    a, b = self.word.sentenceTupleToEmbedding(data1, data2)
                except Exception:
                    continue
                a, b, label = a.to(device), b.to(device), label.to(device)
                distance = mynet(a, b)
                distance = distance.squeeze(1)
                loss = criteon(distance, label.float())
                acc = self.binary_acc(distance, label.float()).item()
                avg_acc.append(acc)
                avg_loss.append(loss.item())
                if batch_id > 50:  # evaluate on at most ~50 batches per epoch
                    break
        avg_acc = np.array(avg_acc).mean()
        avg_loss = np.array(avg_loss).mean()
        print('>>test acc:', avg_acc)
        print(">>test loss:", avg_loss)
        return (avg_acc, avg_loss)

    def run_train(self):
        model = SiameseLSTM(128).to(self.device)
        max_acc = 0
        train_iter = self.get_loader(self.train_path)
        val_iter = self.get_loader(self.val_path)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
        for epoch in range(self.epoch):
            self.train(model, train_iter, optimizer, self.criterion, epoch, self.device)
            eval_acc, eval_loss = self.eval(model, val_iter, self.criterion, epoch, self.device)
            # Keep the checkpoint with the best validation accuracy.
            if eval_acc > max_acc:
                print("save model")
                torch.save(model.state_dict(), self.check_point)
                max_acc = eval_acc

    def test(self):
        da = self.get_loader(self.val_path)
        for batch_id, (data1, data2, label) in enumerate(da):
            print(label)
            break


if __name__ == '__main__':
    RunBiLSTM().run_train()
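The script only saves the best checkpoint and defines no inference entry point. A hedged sketch of scoring one pair with the saved weights (the local path "char_bilstm.pth" is hypothetical; in this project the checkpoint lives at the self.check_point path above):

import torch
from my_bilstm import SiameseLSTM
from my_word2vec import WordEmbedding

word = WordEmbedding()
model = SiameseLSTM(128)
model.load_state_dict(torch.load("char_bilstm.pth", map_location="cpu"))
model.eval()

s1 = ("浙江杭州富阳区银湖街黄先生的外卖",)
s2 = ("富阳区浙江富阳区银湖街道新常村",)
a, b = word.sentenceTupleToEmbedding(s1, s2)
with torch.no_grad():
    prob = torch.sigmoid(model(a, b)).item()  # similarity in [0, 1]
print("similarity:", prob)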
5. Experimental results
epoch: 32 batch: 0 train loss: 0.1690833866596222 acc: 0.91796875
epoch: 32 batch: 100 train loss: 0.16252592206001282 acc: 0.9296875
epoch: 32 batch: 200 train loss: 0.16619177162647247 acc: 0.9375
epoch: 32 batch: 300 train loss: 0.1599806845188141 acc: 0.9453125
train acc: 0.9276657348242812
train loss 0.18327004048294915
>>test acc: 0.9079337269067764
>>test loss: 0.24136937782168388
train acc: 0.9688872803514377
train loss 0.08603085891697734
>>test acc: 0.9298221915960312
>>test loss: 0.22169270366430283