NLP(四十三):sentence_bert+pytorch向量检索,进行语义匹配

一、项目目录

 

 二、data_clean生成数据

from common.root_path import root
import os
import pandas as pd

class DataMerge(object):
    """Build sentence-pair training data from labelled single sentences.

    Pipeline (each step reads the file written by the previous one):

    1. ``get_all``: merge ``raw_data/<role>/{train,dev,test}.txt`` for the
       ``seats`` and ``semantic`` roles into ``clean_data/all_data.txt``.
    2. ``generator_sim_data``: pair consecutive sentences that share a label
       -> positive pairs (label "1") in ``clean_data/sim_data.txt``.
    3. ``generator_negtive_data``: pair every sentence with a randomly chosen
       one of a different label -> negative pairs (label "0") in
       ``clean_data/negtive_data.txt``.
    4. ``merge_sim_neg_data``: shuffle positives + negatives and split them
       roughly 5% / 5% / 90% into ``final_data/{test,dev,train}.txt``.
    """

    # Split files expected under raw_data/<role>/, in priority order for dedup.
    _SPLITS = ("train.txt", "dev.txt", "test.txt")

    def __init__(self):
        self.data_path = os.path.join(root, "data", "raw_data")
        self.out_path = os.path.join(root, "data", "clean_data")
        self.neg_data = os.path.join(self.out_path, "negtive_data.txt")
        self.sim_path = os.path.join(self.out_path, "sim_data.txt")
        self.final = os.path.join(root, "data", "final_data")
        self.train_sim = os.path.join(self.final, "train.txt")
        self.dev_sim = os.path.join(self.final, "dev.txt")
        self.test_sim = os.path.join(self.final, "test.txt")

    @staticmethod
    def _normalize_label(label):
        """Return *label* as a string with spaces, ``\\n`` and ``\\r`` removed.

        Labels are always compared through this function so that stray
        whitespace in the raw files cannot make two equal labels look
        different (or vice versa).
        """
        return str(label).replace(" ", "").replace("\n", "").replace("\r", "")

    def _read_all_data(self):
        """Read ``clean_data/all_data.txt`` as a ``sentence``/``label`` frame of strings."""
        all_data = os.path.join(self.out_path, "all_data.txt")
        # dtype=str keeps label codes such as "001" from being parsed as ints.
        return pd.read_csv(all_data, sep="\t", header=None,
                           names=["sentence", "label"], dtype=str)

    def data_merge(self, role):
        """Merge the train/dev/test files of *role*, dropping duplicate sentences.

        The first occurrence of each sentence (train before dev before test)
        is kept together with its label. The result is written to
        ``clean_data/all_<role>.txt`` (tab-separated, no header) and returned.

        :param role: sub-directory of ``raw_data`` (e.g. "seats", "semantic").
        :return: DataFrame with columns ``sentence`` and ``label``.
        """
        data_path = os.path.join(self.data_path, role)
        frames = [
            pd.read_csv(os.path.join(data_path, split), sep="\t", header=None,
                        names=["sentence", "label"], dtype=str)
            for split in self._SPLITS
        ]
        # Linear-time dedup (the list-membership scan it replaces was O(n^2)).
        df = (pd.concat(frames, ignore_index=True)
              .drop_duplicates(subset="sentence", keep="first")
              .reset_index(drop=True))
        os.makedirs(self.out_path, exist_ok=True)
        all_data_path = os.path.join(self.out_path, "all_" + role + ".txt")
        df.to_csv(all_data_path, sep="\t", index=None, header=None)
        return df

    def get_all(self):
        """Merge the per-role data files ("seats" and "semantic") into all_data.txt.

        (The original note describes these as the agent and customer data.)
        """
        df_seats = self.data_merge("seats")
        df_semantic = self.data_merge("semantic")
        df = pd.concat([df_seats, df_semantic])
        all_data = os.path.join(self.out_path, "all_data.txt")
        df.to_csv(all_data, sep="\t", header=None, index=None)

    def generator_negtive_data(self, seed=None):
        """Create negative pairs by matching each sentence with a shuffled partner.

        Each row of all_data.txt is zipped with a random permutation of the
        same rows; pairs whose (normalized) labels differ are kept with label
        "0". Pairs with equal labels, including a sentence matched with
        itself, are dropped.

        :param seed: optional ``random_state`` for a reproducible shuffle.
        """
        df_1 = self._read_all_data()
        df_2 = df_1.sample(frac=1, random_state=seed)
        norm = self._normalize_label
        out_s1, out_s2 = [], []
        for s1, l1, s2, l2 in zip(df_1["sentence"], df_1["label"],
                                  df_2["sentence"], df_2["label"]):
            # Normalize BOTH sides; normalizing only one made equal labels
            # with stray whitespace look different -> wrong negatives.
            if norm(l1) != norm(l2):
                out_s1.append(s1)
                out_s2.append(s2)
        df = pd.DataFrame({
            "s1": out_s1,
            "s2": out_s2,
            "label": ["0"] * len(out_s1),
        })
        os.makedirs(self.out_path, exist_ok=True)
        df.to_csv(self.neg_data, sep="\t", index=None)

    def generator_sim_data(self):
        """Create positive pairs from consecutive sentences that share a label.

        Sentences are grouped by normalized label (groups in order of first
        appearance). Within a group of k sentences, the k-1 neighbouring pairs
        are emitted with label "1".
        """
        t = self._read_all_data()
        groups = {}
        for sentence, label in zip(t["sentence"], t["label"]):
            groups.setdefault(self._normalize_label(label), []).append(sentence)
        out_s1, out_s2 = [], []
        for sentences in groups.values():
            for first, second in zip(sentences, sentences[1:]):
                out_s1.append(first)
                out_s2.append(second)

        df = pd.DataFrame({
            "s1": out_s1,
            "s2": out_s2,
            "label": ["1"] * len(out_s1),
        })
        os.makedirs(self.out_path, exist_ok=True)
        df.to_csv(self.sim_path, index=None, sep="\t")

    def merge_sim_neg_data(self, seed=None):
        """Shuffle positive and negative pairs and split them into test/dev/train.

        The first ~5% of the shuffled rows go to test, the next ~5% to dev and
        the rest to train (tab-separated, with a header row).

        :param seed: optional ``random_state`` for a reproducible shuffle.
        """
        df1 = pd.read_csv(self.sim_path, sep="\t", dtype=str)
        df2 = pd.read_csv(self.neg_data, sep="\t", dtype=str)
        df = pd.concat([df1, df2]).sample(frac=1.0, random_state=seed)
        n_rows = df.shape[0]
        cut_idx_1 = int(round(0.05 * n_rows))
        cut_idx_2 = int(round(0.1 * n_rows))
        print(cut_idx_1, cut_idx_2)
        df_test, df_dev, df_train = df.iloc[:cut_idx_1], df.iloc[cut_idx_1:cut_idx_2], df.iloc[cut_idx_2:]
        os.makedirs(self.final, exist_ok=True)
        df_test.to_csv(self.test_sim, index=False, sep='\t')
        df_dev.to_csv(self.dev_sim, index=False, sep='\t')
        df_train.to_csv(self.train_sim, index=False, sep='\t')

if __name__ == '__main__':
    # Only the final split step runs here. Its inputs (sim_data.txt and
    # negtive_data.txt) are produced by get_all(), generator_sim_data() and
    # generator_negtive_data(); run those first on a fresh checkout.
    merger = DataMerge()
    merger.merge_sim_neg_data()

三、root_path

import os
__all__ = ["root"]

# Name of the project's top-level directory. Everything in this file's
# absolute path before the first occurrence of this name is the root's parent.
_PROJECT_DIR = "sentence_bert"

_parent_path = os.path.split(os.path.realpath(__file__))[0]
_idx = _parent_path.find(_PROJECT_DIR)
if _idx == -1:
    # str.find returns -1 on a miss; slicing with it would silently drop the
    # last character of the path and produce a bogus root directory.
    raise RuntimeError(
        "cannot locate project root: %r does not occur in %r"
        % (_PROJECT_DIR, _parent_path)
    )
_root = _parent_path[:_idx]
root = os.path.join(_root, _PROJECT_DIR)

四、训练

from torch.utils.data import DataLoader
import math
from sentence_transformers import SentenceTransformer,  LoggingHandler, losses, models, util
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.readers import InputExample
import logging
from datetime import datetime
import os
from common.root_path import root
import pandas as pd

class MySentenceBert():
    """Fine-tune a sentence-transformers model on labelled sentence pairs.

    Reads ``data/final_data/{train,dev,test}.txt``. These are tab-separated
    files with a header row and the columns ``s1``, ``s2`` and ``label``; the
    label is converted with ``float``. Training uses ``CosineSimilarityLoss``
    with the dev split as evaluator. The model is written to
    ``chkpt/sentence_bert_model_<YYYY_MM_DD_HH_MM>`` and re-loaded from there
    for the final test-split evaluation.
    """
    # Executed once, when the class body is evaluated (at import time):
    # configures root logging at INFO through sentence-transformers'
    # LoggingHandler.
    logging.basicConfig(format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO,
                        handlers=[LoggingHandler()])

    def __init__(self):
        self.train_batch_size = 16
        self.num_epochs = 4
        data_path = os.path.join(root, "data", "final_data")
        self.train_data = self._read_split(os.path.join(data_path, "train.txt"))
        # The data-cleaning step writes the validation split as "dev.txt";
        # reading "val.txt" failed with FileNotFoundError. Keep "val.txt" as
        # a fallback for data prepared under the old name.
        dev_path = os.path.join(data_path, "dev.txt")
        if not os.path.exists(dev_path):
            dev_path = os.path.join(data_path, "val.txt")
        self.val_data = self._read_split(dev_path)
        self.test_data = self._read_split(os.path.join(data_path, "test.txt"))
        self.model_save_path = os.path.join(root, "chkpt", "sentence_bert_model" +
                                            datetime.now().strftime("_%Y_%m_%d_%H_%M"))

    @staticmethod
    def _read_split(path):
        """Load one split file.

        The text columns are forced to ``str`` so that numeric-looking
        sentences are not parsed as numbers before tokenization.
        """
        return pd.read_csv(path, sep="\t", dtype={"s1": str, "s2": str})

    @staticmethod
    def _to_examples(df):
        """Convert an ``s1``/``s2``/``label`` frame into a list of ``InputExample``."""
        return [InputExample(texts=[s1, s2], label=float(l))
                for s1, s2, l in zip(df["s1"], df["s2"], df["label"])]

    def data_generator(self):
        """Return ``(train, dev, test)`` lists of ``InputExample``."""
        logging.info("generator dataset")
        return (self._to_examples(self.train_data),
                self._to_examples(self.val_data),
                self._to_examples(self.test_data))

    def train(self, train_datas, dev_datas, model):
        """Fit *model* with cosine-similarity loss, evaluating on *dev_datas*.

        Warmup covers 10% of all training steps (batches per epoch times
        epochs). The dev set is evaluated every 1000 steps, and the model is
        saved to ``self.model_save_path``.
        """
        train_dataloader = DataLoader(train_datas, shuffle=True, batch_size=self.train_batch_size)
        train_loss = losses.CosineSimilarityLoss(model=model)
        evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_datas, name='sts-dev')
        warmup_steps = math.ceil(len(train_dataloader) * self.num_epochs * 0.1)
        logging.info("Warmup-steps: {}".format(warmup_steps))
        model.fit(train_objectives=[(train_dataloader, train_loss)],
                  evaluator=evaluator,
                  epochs=self.num_epochs,
                  evaluation_steps=1000,
                  warmup_steps=warmup_steps,
                  output_path=self.model_save_path)

    def test(self, test_samples):
        """Reload the saved model and evaluate it on *test_samples*.

        Results are written under ``self.model_save_path``.
        """
        model = SentenceTransformer(self.model_save_path)
        test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='sts-test')
        test_evaluator(model, output_path=self.model_save_path)

    def main(self):
        """Build the model (transformer + mean pooling), then train and test it."""
        train_datas, dev_datas, test_datas = self.data_generator()

        model_name = os.path.join(root, "chkpt", "distiluse-base-multilingual-cased")
        word_embedding_model = models.Transformer(model_name)
        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                                       pooling_mode_mean_tokens=True,
                                       pooling_mode_cls_token=False,
                                       pooling_mode_max_tokens=False)
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
        self.train(train_datas, dev_datas, model)
        self.test(test_datas)

if __name__ == '__main__':
    # Train on final_data, then evaluate the saved checkpoint on the test split.
    trainer = MySentenceBert()
    trainer.main()

五、向量检索

from sentence_transformers import SentenceTransformer, util
import os
import csv
import pickle
import time
from root_path import root
import json

class SemanticSearch():
    """Nearest-neighbour lookup of a query against a pre-encoded sentence corpus.

    On construction the class does the following:
    - loads a fine-tuned SentenceTransformer checkpoint;
    - loads the JSON mapping in ``config/code_to_label.json``;
    - obtains the corpus embeddings. If ``semantic_search_embedding.pkl``
      exists, they are loaded from it; otherwise
      ``data/bert_data/index.txt`` (``sentence<TAB>code`` per line) is encoded
      and the cache is written.

    NOTE(review): the cache path is relative (resolved against the current
    working directory) and is not tied to the model or index.txt. Delete it
    after changing either, otherwise stale embeddings are reused.
    """

    def __init__(self):
        model_name = os.path.join(root, "chkpt", "sentence_bert_model_2021_08_05_18_16")
        self.model = SentenceTransformer(model_name)
        embedding_cache_path = 'semantic_search_embedding.pkl'
        dataset_path = os.path.join(root, "data", "bert_data", "index.txt")
        with open(os.path.join(root, "config", "code_to_label.json"), "r", encoding="utf8") as f:
            self.d = json.load(f)
        self.sentences = list()
        self.code = list()
        if not os.path.exists(embedding_cache_path):
            with open(dataset_path, encoding='utf8') as fIn:
                for read_line in fIn:
                    parsed = self._parse_index_line(read_line)
                    if parsed is None:  # blank line
                        continue
                    self.sentences.append(parsed[0])
                    self.code.append(parsed[1])
            print("Encode the corpus. This might take a while")
            self.embeddings = self.model.encode(self.sentences, show_progress_bar=True, convert_to_tensor=True)
            print("Store file on disc")
            with open(embedding_cache_path, "wb") as fOut:
                pickle.dump({'sentences': self.sentences, 'embeddings': self.embeddings, "code": self.code}, fOut)
        else:
            print("Load pre-computed embeddings from disc")
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file. Only load a cache that this program wrote itself.
            with open(embedding_cache_path, "rb") as fIn:
                cache_data = pickle.load(fIn)
            self.sentences = cache_data['sentences']
            self.embeddings = cache_data['embeddings']
            self.code = cache_data["code"]

    @staticmethod
    def _parse_index_line(line):
        """Split one ``sentence<TAB>code`` line of index.txt.

        The trailing ``\\n`` / ``\\r\\n`` is removed. Previously a ``\\r`` was
        left on the code, so lookups in code_to_label.json missed for files
        with Windows line endings.

        :return: ``(sentence, code)``, or None for a blank line.
        :raises ValueError: if the line has no tab separator.
        """
        line = line.rstrip("\r\n")
        if not line.strip():
            return None
        fields = line.split("\t")
        if len(fields) < 2:
            raise ValueError("malformed index line (expected sentence<TAB>code): %r" % line)
        return fields[0], fields[1]

    def query(self, query):
        """Find the corpus sentence most similar to *query*.

        :return: ``(label, score, text)``. ``text`` is the best-matching corpus
            sentence and ``score`` its similarity from ``util.semantic_search``.
            ``label`` is element ``[1]`` of that sentence's code entry in
            code_to_label.json (presumably the human-readable label; verify
            against the JSON file).
        :raises ValueError: if the search returns no hit (empty corpus).
        """
        question_embedding = self.model.encode(query, convert_to_tensor=True)
        # Only the best hit is used, so do not compute the default top-10.
        hits = util.semantic_search(question_embedding, self.embeddings, top_k=1)
        if not hits or not hits[0]:
            raise ValueError("semantic search returned no hits (empty corpus?)")
        hit = hits[0][0]  # best hit for the first (only) query
        score = hit['score']
        text = self.sentences[hit['corpus_id']]
        kh_code = self.code[hit['corpus_id']]
        label = self.d[kh_code][1]
        return label, score, text

    def main(self):
        """Demo: look up a sample query, print and return ``(label, score, text)``."""
        result = self.query("你好")
        print(result)
        return result


if __name__ == '__main__':
    # Build the index (or load the cached embeddings), then run the demo query.
    searcher = SemanticSearch()
    searcher.main()

 六、参考

https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/sts/training_stsbenchmark.py

https://github.com/UKPLab/sentence-transformers/blob/master/examples/applications/semantic-search/semantic_search_quora_pytorch.py

 

posted @ 2022-02-18 14:40  jasonzhangxianrong  阅读(924)  评论(1)  编辑  收藏  举报