NLP (43): Semantic matching with Sentence-BERT and PyTorch vector search
1. Project directory
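The original directory listing is not reproduced here. Based on the paths referenced by the scripts below, the layout is roughly the following (the three script names are illustrative; the directories and data files are taken from the code):

sentence_bert/
├── common/
│   └── root_path.py              # resolves the project root (section 3)
├── config/
│   └── code_to_label.json        # code -> label mapping used at query time
├── chkpt/                        # pretrained model and fine-tuned checkpoints
├── data/
│   ├── raw_data/                 # per-role train/dev/test files (seats, semantic)
│   ├── clean_data/               # merged, deduplicated data and generated pairs
│   ├── final_data/               # train.txt / dev.txt / test.txt sentence pairs
│   └── bert_data/
│       └── index.txt             # retrieval corpus: sentence \t code
├── data_clean.py                 # section 2
├── train.py                      # section 4
└── semantic_search.py            # section 5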
2. data_clean: generating the sentence-pair data
from common.root_path import root
import os
import pandas as pd


class DataMerge(object):
    def __init__(self):
        self.data_path = os.path.join(root, "data", "raw_data")
        self.out_path = os.path.join(root, "data", "clean_data")
        self.neg_data = os.path.join(self.out_path, "negtive_data.txt")
        self.sim_path = os.path.join(self.out_path, "sim_data.txt")
        self.final = os.path.join(root, "data", "final_data")
        self.train_sim = os.path.join(self.final, "train.txt")
        self.dev_sim = os.path.join(self.final, "dev.txt")
        self.test_sim = os.path.join(self.final, "test.txt")

    def data_merge(self, role):
        """Merge the train/dev/test files of one role and deduplicate by sentence."""
        out_sentence, out_label = [], []
        data_path = os.path.join(self.data_path, role)
        # Training set
        train_path = os.path.join(data_path, "train.txt")
        train_t = pd.read_csv(train_path, sep="\t", header=None, names=["sentence", "label"])
        out_sentence.extend(train_t["sentence"])
        out_label.extend(train_t["label"])
        # Dev set
        dev_path = os.path.join(data_path, "dev.txt")
        dev_t = pd.read_csv(dev_path, sep="\t", header=None, names=["sentence", "label"])
        out_sentence.extend(dev_t["sentence"])
        out_label.extend(dev_t["label"])
        # Test set
        test_path = os.path.join(data_path, "test.txt")
        test_t = pd.read_csv(test_path, sep="\t", header=None, names=["sentence", "label"])
        out_sentence.extend(test_t["sentence"])
        out_label.extend(test_t["label"])
        # Deduplicate by sentence
        clean_sentence, clean_label = [], []
        for s, l in zip(out_sentence, out_label):
            if s not in clean_sentence:
                clean_sentence.append(s)
                clean_label.append(l)
        # Write the merged file
        df = pd.DataFrame({
            "sentence": clean_sentence,
            "label": clean_label,
        })
        all_data_path = os.path.join(self.out_path, "all_" + role + ".txt")
        df.to_csv(all_data_path, sep="\t", index=False, header=False)
        return df

    def get_all(self):
        """Merge the agent-side (seats) and customer-side (semantic) data files."""
        df_seats = self.data_merge("seats")
        df_semantic = self.data_merge("semantic")
        df = pd.concat([df_seats, df_semantic])
        all_data = os.path.join(self.out_path, "all_data.txt")
        df.to_csv(all_data, sep="\t", header=False, index=False)

    def generator_negtive_data(self):
        """Shuffle the data and pair sentences with different labels as negatives (label 0)."""
        all_data = os.path.join(self.out_path, "all_data.txt")
        df_1 = pd.read_csv(all_data, sep="\t", header=None, names=["sentence", "label"])
        df_2 = df_1.sample(frac=1)
        out_s1, out_s2, out_l = [], [], []
        for s1, l1, s2, l2 in zip(df_1["sentence"], df_1["label"], df_2["sentence"], df_2["label"]):
            l1 = l1.replace(" ", "").replace("\n", "").replace("\r", "")
            if l1 != l2:
                out_s1.append(s1)
                out_s2.append(s2)
                out_l.append("0")
        df = pd.DataFrame({"s1": out_s1, "s2": out_s2, "label": out_l})
        df.to_csv(self.neg_data, sep="\t", index=False)

    def generator_sim_data(self):
        """Pair adjacent sentences that share a label as positives (label 1)."""
        out_s1, out_s2, out_label = list(), list(), list()
        all_data = os.path.join(self.out_path, "all_data.txt")
        t = pd.read_csv(all_data, sep="\t", header=None, names=["sentence", "label"])
        # Group sentences by label
        data_dict = dict()
        for index, row in t.iterrows():
            s = row["sentence"]
            l = row["label"]
            if l not in data_dict:
                data_dict[l] = list()
            data_dict[l].append(s)
        for l, s_list in data_dict.items():
            s_list_len = len(s_list)
            for index, s in enumerate(s_list):
                if index > s_list_len - 2:
                    break
                out_s1.append(s)
                out_s2.append(s_list[index + 1])
                out_label.append("1")
        df = pd.DataFrame({"s1": out_s1, "s2": out_s2, "label": out_label})
        df.to_csv(self.sim_path, index=False, sep="\t")

    def merge_sim_neg_data(self):
        """Merge positive and negative pairs, shuffle, and split 5% / 5% / 90% into test / dev / train."""
        df1 = pd.read_csv(self.sim_path, sep="\t")
        df2 = pd.read_csv(self.neg_data, sep="\t")
        df_all = pd.concat([df1, df2])
        df = df_all.sample(frac=1.0)
        cut_idx_1 = int(round(0.05 * df.shape[0]))
        cut_idx_2 = int(round(0.1 * df.shape[0]))
        print(cut_idx_1, cut_idx_2)
        df_test, df_dev, df_train = df.iloc[:cut_idx_1], df.iloc[cut_idx_1:cut_idx_2], df.iloc[cut_idx_2:]
        df_test.to_csv(self.test_sim, index=False, sep='\t')
        df_dev.to_csv(self.dev_sim, index=False, sep='\t')
        df_train.to_csv(self.train_sim, index=False, sep='\t')


if __name__ == '__main__':
    DataMerge().merge_sim_neg_data()
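For reference, a sketch of the file formats the script above assumes; the sentences and labels are made-up examples. Raw files under data/raw_data/<role>/ are tab-separated sentence/label lines without a header, and the generated files under data/final_data/ are tab-separated s1/s2/label triples with a header row:

# data/raw_data/seats/train.txt  (no header)
请问怎么办理宽带	business_broadband
帮我查一下话费	query_bill

# data/final_data/train.txt  (header written by merge_sim_neg_data)
s1	s2	label
请问怎么办理宽带	我想装个宽带	1
请问怎么办理宽带	帮我查一下话费	0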
3. root_path
import os

__all__ = ["root"]

# Locate the project root by cutting the absolute path of this file
# at the "sentence_bert" directory.
_parent_path = os.path.split(os.path.realpath(__file__))[0]
_root = _parent_path[:_parent_path.find("sentence_bert")]
root = os.path.join(_root, "sentence_bert")
4. Training
from torch.utils.data import DataLoader
import math
from sentence_transformers import SentenceTransformer, LoggingHandler, losses, models
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.readers import InputExample
import logging
from datetime import datetime
import os
from common.root_path import root
import pandas as pd

logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])


class MySentenceBert():
    def __init__(self):
        self.train_batch_size = 16
        self.num_epochs = 4
        data_path = os.path.join(root, "data", "final_data")
        # These are the files produced by data_clean in section 2.
        self.train_data = pd.read_csv(os.path.join(data_path, "train.txt"), sep="\t")
        self.val_data = pd.read_csv(os.path.join(data_path, "dev.txt"), sep="\t")
        self.test_data = pd.read_csv(os.path.join(data_path, "test.txt"), sep="\t")
        self.model_save_path = os.path.join(
            root, "chkpt",
            "sentence_bert_model" + datetime.now().strftime("_%Y_%m_%d_%H_%M"))

    def data_generator(self):
        """Wrap the sentence pairs as InputExample objects."""
        logging.info("generator dataset")
        train_datas = []
        dev_datas = []
        test_datas = []
        for s1, s2, l in zip(self.train_data["s1"], self.train_data["s2"], self.train_data["label"]):
            train_datas.append(InputExample(texts=[s1, s2], label=float(l)))
        for s1, s2, l in zip(self.val_data["s1"], self.val_data["s2"], self.val_data["label"]):
            dev_datas.append(InputExample(texts=[s1, s2], label=float(l)))
        for s1, s2, l in zip(self.test_data["s1"], self.test_data["s2"], self.test_data["label"]):
            test_datas.append(InputExample(texts=[s1, s2], label=float(l)))
        return train_datas, dev_datas, test_datas

    def train(self, train_datas, dev_datas, model):
        train_dataloader = DataLoader(train_datas, shuffle=True, batch_size=self.train_batch_size)
        train_loss = losses.CosineSimilarityLoss(model=model)
        evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_datas, name='sts-dev')
        # Warm up the learning rate over the first 10% of training steps.
        warmup_steps = math.ceil(len(train_dataloader) * self.num_epochs * 0.1)
        logging.info("Warmup-steps: {}".format(warmup_steps))
        model.fit(train_objectives=[(train_dataloader, train_loss)],
                  evaluator=evaluator,
                  epochs=self.num_epochs,
                  evaluation_steps=1000,
                  warmup_steps=warmup_steps,
                  output_path=self.model_save_path)

    def test(self, test_samples):
        model = SentenceTransformer(self.model_save_path)
        test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='sts-test')
        test_evaluator(model, output_path=self.model_save_path)

    def main(self):
        train_datas, dev_datas, test_datas = self.data_generator()
        # Local copy of the multilingual pretrained model
        model_name = os.path.join(root, "chkpt", "distiluse-base-multilingual-cased")
        word_embedding_model = models.Transformer(model_name)
        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                                       pooling_mode_mean_tokens=True,
                                       pooling_mode_cls_token=False,
                                       pooling_mode_max_tokens=False)
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
        self.train(train_datas, dev_datas, model)
        self.test(test_datas)


if __name__ == '__main__':
    MySentenceBert().main()
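After training, the fine-tuned model can be loaded directly to score a single sentence pair. A minimal sketch, assuming the checkpoint path used in the next section; the two sentences are made-up examples:

from sentence_transformers import SentenceTransformer, util
import os
from common.root_path import root

# Load a fine-tuned checkpoint produced by the training script
# (the timestamp suffix will differ for each run).
model = SentenceTransformer(os.path.join(root, "chkpt", "sentence_bert_model_2021_08_05_18_16"))

embeddings = model.encode(["请问怎么办理宽带", "我想装个宽带"], convert_to_tensor=True)
score = util.pytorch_cos_sim(embeddings[0], embeddings[1])  # cosine similarity, higher = more similar
print(float(score))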
5. Vector search
from sentence_transformers import SentenceTransformer, util
import os
import pickle
import json
from root_path import root


class SemanticSearch():
    def __init__(self):
        model_name = os.path.join(root, "chkpt", "sentence_bert_model_2021_08_05_18_16")
        self.model = SentenceTransformer(model_name)
        embedding_cache_path = 'semantic_search_embedding.pkl'
        dataset_path = os.path.join(root, "data", "bert_data", "index.txt")
        with open(os.path.join(root, "config", "code_to_label.json"), "r", encoding="utf8") as f:
            self.d = json.load(f)
        self.sentences = list()
        self.code = list()
        if not os.path.exists(embedding_cache_path):
            # First run: read the corpus (sentence \t code), encode it,
            # and cache sentences, codes and embeddings on disk.
            with open(dataset_path, encoding='utf8') as fIn:
                for read_line in fIn:
                    read_line = read_line.split("\t")
                    self.sentences.append(read_line[0])
                    self.code.append(read_line[1].replace("\n", ""))
            print("Encode the corpus. This might take a while")
            self.embeddings = self.model.encode(self.sentences, show_progress_bar=True, convert_to_tensor=True)
            print("Store file on disc")
            with open(embedding_cache_path, "wb") as fOut:
                pickle.dump({'sentences': self.sentences,
                             'embeddings': self.embeddings,
                             "code": self.code}, fOut)
        else:
            print("Load pre-computed embeddings from disc")
            with open(embedding_cache_path, "rb") as fIn:
                cache_data = pickle.load(fIn)
                self.sentences = cache_data['sentences']
                self.embeddings = cache_data['embeddings']
                self.code = cache_data["code"]

    def query(self, query):
        inp_question = query
        question_embedding = self.model.encode(inp_question, convert_to_tensor=True)
        hits = util.semantic_search(question_embedding, self.embeddings)
        hit = hits[0][0]  # Best hit for the first (and only) query
        score = hit['score']
        text = self.sentences[hit['corpus_id']]
        kh_code = self.code[hit['corpus_id']]
        label = self.d[kh_code][1]
        return label, score, text

    def main(self):
        self.query("你好")


if __name__ == '__main__':
    SemanticSearch().main()
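A minimal usage sketch of the class above (the query string is illustrative; score is the cosine similarity of the best hit in the index):

searcher = SemanticSearch()
label, score, text = searcher.query("帮我查一下话费")
print(label, round(score, 4), text)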
6. References
https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/sts/training_stsbenchmark.py
https://github.com/UKPLab/sentence-transformers/blob/master/examples/applications/semantic-search/semantic_search_quora_pytorch.py