通过撰写代码理解向量计算
embeded模型基于m3e。
一、原生向量代码,自己计算距离
import numpy as np from numpy import dot from numpy.linalg import norm from sentence_transformers import SentenceTransformer model = SentenceTransformer('/home/helu/milvus/m3e-base') ###functions && classes#### def cos_sim(a, b): '''余弦距离 -- 越大越相似''' return dot(a, b)/(norm(a)*norm(b)) def l2(a, b): '''欧式距离 -- 越小越相似''' x = np.asarray(a)-np.asarray(b) return norm(x) ###需要换成本地接口### def get_embeddings(texts): #data = embedding.create(input=texts).data embeddings = model.encode(texts) #return [x.embedding for x in data] return embeddings test_query = ["测试文本"] vec = get_embeddings(test_query)[0] print(vec[:10]) print(len(vec)) #query = "体育" # 且能支持跨语言 query = "sports" documents = [ "联合国就苏丹达尔富尔地区大规模暴力事件发出警告", "土耳其、芬兰、瑞典与北约代表将继续就瑞典“入约”问题进行谈判", "日本岐阜市陆上自卫队射击场内发生枪击事件 3人受伤", "国家游泳中心(水立方):恢复游泳、嬉水乐园等水上项目运营", "我国首次在空间站开展舱外辐射生物学暴露实验", ] query_vec = get_embeddings([query])[0] doc_vecs = get_embeddings(documents) print("Cosine distance:") print(cos_sim(query_vec, query_vec)) for vec in doc_vecs: print(cos_sim(query_vec, vec)) print("\nEuclidean distance:") print(l2(query_vec, query_vec)) for vec in doc_vecs: print(l2(query_vec, vec)) #基于以上结果,按照cos/l2方法建一个mix模型 print("mix distance:") for vec in doc_vecs: print(cos_sim(query_vec, vec)/l2(query_vec, vec))
二、引入向量检索工具Faiss,帮助计算距离
import numpy as np import faiss from numpy import dot from numpy.linalg import norm from sentence_transformers import SentenceTransformer model = SentenceTransformer('/home/helu/milvus/m3e-base') ###functions && classes#### def get_datas_embedding(datas): return model.encode(datas) # 构建索引,FlatL2为例 def create_index(datas_embedding): index = faiss.IndexFlatL2(datas_embedding.shape[1]) # 这里必须传入一个向量的维度,创建一个空的索引 index.add(datas_embedding) # 把向量数据加入索引 return index # 查询索引 def data_recall(faiss_index, query, top_k): query_embedding = model.encode([query]) Distance, Index = faiss_index.search(query_embedding, top_k) return Index ############################### #query = "体育" # 且能支持跨语言 query = "sports" documents = [ "联合国就苏丹达尔富尔地区大规模暴力事件发出警告", "土耳其、芬兰、瑞典与北约代表将继续就瑞典“入约”问题进行谈判", "日本岐阜市陆上自卫队射击场内发生枪击事件 3人受伤", "国家游泳中心(水立方):恢复游泳、嬉水乐园等水上项目运营", "我国首次在空间站开展舱外辐射生物学暴露实验", ]
datas_embedding = get_datas_embedding(documents)
faiss_index = create_index(datas_embedding)
sim_data_Index = data_recall(faiss_index,query, 3)
print("相似的top3数据是:")
for index in sim_data_Index[0]:
print(documents[int(index)] + "\n")