sentence-transformers(SBert)中文文本相似度预测(附代码)
sentence-transformers(SBert)中文文本相似度预测(附代码)
https://blog.csdn.net/weixin_54218079/article/details/128687878
https://gitee.com/liheng103/sbert-evaluate
https://www.sbert.net/
训练模型
创建网络:使用Sbert官方给出的预训练模型sentence_hfl_chinese-roberta-wwm-ext,先载入embedding层进行分词,再载入池化层并传入嵌入后的维度,对模型进行降维压缩,最后载入密集层,选择Than激活函数,输出维度大小为256维。
获取训练数据:构建出新模型后使用InputExample类存储训练数据,它接受文本对字符串列表和用于指示语义相似性的标签,用标准的Pytorch Dataloader包装train_examples,作用是打乱数据并生成特定大小的批次。
计算损失函数:对于每个句子对,通过网络传递句子A和句子B,从而产生嵌入u和v,使用余弦相似度计算相似性,并将结果与标准相似度得分进行比较。这样网络就能够进行微调,更好地识别句子的相似性。
模型调优:通过调用model.fit()来调优模型。向model.fit()中传递train_objective列表(由元组(dataloader, loss_function))组成。也可以传递多个元组,以便在具有不同损失函数的多个数据集上执行多任务学习。在训练过程需要使用sentence_transformers.evaluation评估表现是否有所改善,它包含各种可以传递给fit方法的evaluators。Evaluators会在训练期间定期运行,并且会返回分数,只有得分最高的模型才会存储在磁盘上。
首先运行preprocess.py获取数据,并划分训练集和测试集,之后运行train_sentence_bert.py,使用预训练模型, sbert将数据集用sbert训练相似度任务,得到训练好的模型,最后运行evaluate.py评估训练好的模型,将结果保存在predict.txt中,并输出预测结果。
这部分在详细代码里注释得很全。
后端部分
使用flask编写post接口,接收的数据格式为application/json,将前端传来的两个句子使用训练好的模型对其进行相似度预测,将得到的相似度类型从无法序列化存入json的tensor转成list,并将状态码,信息,数据返回给前端。
from sentence_transformers import SentenceTransformer, util
# 后端接口
from flask import Flask, jsonify, request
import re
# 用当前脚本名称实例化Flask对象,方便flask从该脚本文件中获取需要的内容
app = Flask(__name__)
# 使通过jsonify返回的中文显示正常,否则显示为ASCII码
app.config["JSON_AS_ASCII"] = False
model_path = 'D:/xxx模型路径/'
model = SentenceTransformer(model_path)
@app.route("/evaluate",methods=['POST'])
def evalute_sentence():
s1 = request.json.get("s1")
s2 = request.json.get("s2")
if s1 and s2:
embedding1 = model.encode(s1, convert_to_tensor=True)
embedding2 = model.encode(s2, convert_to_tensor=True)
similarity = util.cos_sim(embedding1, embedding2).tolist()
return jsonify({"code": 200, "msg": "预测成功", "data": similarity})
else:
return jsonify({"code": 400, "msg": "缺少字段"})
if __name__ == '__main__':
app.run(debug=True)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
前端部分
————————————————
版权声明:本文为CSDN博主「我先润了」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/weixin_54218079/article/details/128687878