多线程提速
对于请求响应这类 I/O 密集的操作，使用多线程来提速（线程在等待网络响应时会释放 GIL，多个请求因此可以并发进行）
"""
Function: get similarity query
Author: dengyx
DateTime: 20201019

Embed each query as the mean of its word2vec token vectors, store the vectors
in a SoaringVector index, and for every query write out the ids of its most
similar queries. The per-query lookups are run on several threads because each
one is a request to the vector service rather than local CPU work.
"""
import jieba
import time
import tqdm
import threading
import queue
import numpy as np
from gensim.models import KeyedVectors
import logging

logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
from utils.soaring_vector.soaring_vector.soaring_vector import SoaringVectorClient, IndexType, FieldType, IndexData, SearchQuery, Vector, FloatVector

# Vector dimensionality this script assumes for the word2vec model
# (loaded from 'resources_10dim/').
VECTOR_DIM = 10
# Number of hits requested from the index for every search.
SEARCH_SIZE = 40
# Number of neighbour ids written per query (after skipping the top hit).
TOP_K = 30

# NOTE: importing this module contacts the vector service (health check and
# index lookup/creation below) as a side effect.
client = SoaringVectorClient("172.16.24.150", 8098, 1000)
print("health : ", client.health())

index_name = "seo-query-v10dim"
if client.exist(index_name):
    api_index = client.get(index_name)
    print(index_name + " is exist")
else:
    schema = {'query': FieldType.STRING_TYPE, 'id': FieldType.STRING_TYPE}
    # NOTE(review): the positional 10 is presumably the vector dimension and
    # must then match VECTOR_DIM — confirm against SoaringVectorClient.create.
    api_index = client.create(index_name, "query search match", IndexType.FloatFlatIP, 'query', schema, 10, thread=12)
    client.set_alias(index_name, "seo-phrase-match")
print(api_index.info)


class QuerySimilarity(object):
    """Upload query embeddings and look up the most similar queries for each one."""

    def __init__(self,):
        """Load the word2vec model and the query file, and build the query -> id map."""
        self.query_path = r'data/seo_search_word_copy.txt'
        self.w2c_path = r'resources_10dim/word2vec.model'
        # The next four attributes are not used by the methods below; they are
        # kept so external code that reads them keeps working.
        self.query_features = r'resources/features.pkl'
        self.tables = r'resources/hashtables.pkl'
        self.table_num = 3
        self.Hashcode_fun = 6
        self.query2id = {}
        self.thread_num = 8
        # io.TextIOWrapper is not thread-safe: worker threads share one output
        # file in main()/run(), so their writes are serialized with this lock.
        self._write_lock = threading.Lock()
        print('加载词向量...')
        t1 = time.time()
        # mmap='r' memory-maps the stored arrays read-only.
        self.model = KeyedVectors.load(self.w2c_path, mmap='r')
        t2 = time.time()
        print('词向量加载时间:{:.2f}s'.format(t2-t1))
        with open(self.query_path, 'r', encoding='utf8') as fr:
            self.content = fr.readlines()
        for each in self.content:
            query_id, query = self._split_line(each)
            self.query2id[query] = query_id

    @staticmethod
    def _split_line(line):
        """Split a tab-separated input line into (id, query): its first and last field."""
        item = line.strip().split('\t')
        return item[0], item[-1]

    def cosine_sim(self, x, y):
        """Return the cosine similarity of vectors ``x`` and ``y``.

        A zero-length vector makes the denominator 0 (numpy yields nan/inf).
        """
        num = x.dot(y.T)
        denom = np.linalg.norm(x) * np.linalg.norm(y)
        return num / denom

    def feature_extract(self, query):
        """Return the mean word2vec vector of the jieba tokens of ``query``.

        Tokens missing from the model contribute an all-zero vector. A query
        that produces no tokens (e.g. an empty line) yields the zero vector
        instead of failing on the mean of an empty array.
        """
        tokens = jieba.lcut(query)
        if not tokens:
            return np.zeros(VECTOR_DIM)
        vec = [self.model[word] if word in self.model else [0] * VECTOR_DIM
               for word in tokens]
        mean_vec = np.mean(np.array(vec), axis=0)
        if len(mean_vec) != VECTOR_DIM:
            print('向量维度不是{}'.format(VECTOR_DIM))
        return mean_vec

    def upload_data(self):
        """Embed every loaded query and upload it to the index in batches.

        A batch is flushed once it holds more than 1000 entries; the remainder
        is flushed at the end. ``self.counter`` holds the running total.
        """
        self.counter = 0
        data_map_buffer = {}
        for each in self.content:
            query_id, query = self._split_line(each)
            vector = self.l2_norm(self.feature_extract(query)).tolist()
            data_map_buffer[query] = IndexData({'query': query, 'id': query_id}, vector)
            if len(data_map_buffer) > 1000:
                self._flush(data_map_buffer)
                data_map_buffer = {}
        if data_map_buffer:
            self._flush(data_map_buffer)
        print('数据上传完成')

    def _flush(self, buffer):
        """Upload one batch to the index and log the running total."""
        api_index.put(buffer)
        self.counter += len(buffer)
        logging.info('put %s', self.counter)

    def l2_norm(self, m):
        """Scale ``m`` to unit L2 norm along its last axis, in place, and return it.

        All-zero rows are left unchanged instead of becoming NaN.
        """
        dist = np.sqrt((m ** 2).sum(-1))[..., np.newaxis]
        m /= np.where(dist == 0, 1, dist)
        return m

    def _similar_ids(self, phrase):
        """Return the comma-joined ids of the TOP_K most similar queries to ``phrase``.

        The phrase's own stored vector is used as the search vector. The
        highest-scoring hit is skipped because it is presumably the phrase
        itself. Raises KeyError if a hit is not in ``self.query2id``.
        """
        stored = dict(api_index.get(phrase).data.vector.vector).get(phrase)
        query = SearchQuery(vector=Vector(floatVector=FloatVector(values=stored.floatVector.values)))
        res = api_index.search(query, 0, SEARCH_SIZE)
        ids = []
        for ret in res.result:
            items = sorted(ret.item, key=lambda v: v.score, reverse=True)
            ids.extend(self.query2id[item.key] for item in items[1:TOP_K + 1])
        return ','.join(ids)

    def download(self):
        """Single-threaded variant: write each query line plus its similar ids to a file."""
        with open(self.query_path, 'r', encoding='utf8') as fr:
            content = fr.readlines()
        new_content = []
        for each in tqdm.tqdm(content):
            _, phrase = self._split_line(each)
            new_content.append(each.strip() + '\t' + self._similar_ids(phrase) + '\n')
        save_path = r'data/query_top30_20201021.txt'
        with open(save_path, 'w', encoding='utf8') as fw:
            fw.writelines(new_content)
        print('数据保存成功:{}'.format(save_path))

    def run(self, q, fw):
        """Worker loop: drain ``q`` and append one result line per sample to ``fw``.

        Uses ``get_nowait`` rather than checking ``empty()`` before ``get()``.
        With the check-then-get pattern, another worker can take the last item
        in between, and ``get()`` would then block forever. A sample that fails
        is logged and skipped so the worker keeps running.
        """
        while True:
            try:
                sample = q.get_nowait()
            except queue.Empty:
                return
            query_id, phrase = self._split_line(sample)
            try:
                line = self._similar_ids(phrase)
            except Exception:
                logging.exception('failed to process sample: %r', sample)
                continue
            to_save = sample.strip() + '\t' + line + '\n'
            print(query_id)
            with self._write_lock:
                fw.write(to_save)

    def main(self, data_path):
        """Look up similar queries for every line of ``data_path`` using worker threads.

        Results are appended (mode 'a') to data/query_top30_20201022.txt.
        """
        q = queue.Queue()
        save_path = r'data/query_top30_20201022.txt'
        with open(data_path, 'r', encoding='utf8') as fr:
            content = fr.readlines()
        for d in tqdm.tqdm(content):
            q.put(d)
        print('数据放入队列完毕')
        t1 = time.time()
        print('数据预测中...')
        # The file is closed (and flushed) only after every worker has finished.
        with open(save_path, 'a', encoding='utf8') as fw:
            threads = [threading.Thread(target=self.run, args=(q, fw))
                       for _ in range(self.thread_num)]
            for t in threads:
                t.start()
            for t in threads:
                t.join()
        t2 = time.time()
        # Guard against a zero elapsed time (e.g. empty input on a coarse clock).
        elapsed = max(t2 - t1, 1e-9)
        print('处理速度:{:.4f}sample/s'.format(len(content) / elapsed))
        print('数据写入完毕')


if __name__ == '__main__':
    data_path = r'data/seo_search_word_copy.txt'
    qs = QuerySimilarity()
    # Run qs.upload_data() first if the index has not been populated yet.
    qs.main(data_path)
时刻记住自己要成为什么样的人！