多线程提速

对于需要等待请求响应的任务（I/O 密集型），可以使用多线程并发发送请求来提速。

"""
    Function: get similarity query
    Author: dengyx
    DateTime: 20201019
"""
import jieba
import time
import tqdm
import threading
import queue
import numpy as np
from gensim.models import KeyedVectors
import logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
from utils.soaring_vector.soaring_vector.soaring_vector import SoaringVectorClient, IndexType, FieldType, IndexData, SearchQuery, Vector, FloatVector

# Connect to the Soaring Vector service and print its health status.
# NOTE(review): the third constructor argument (1000) is presumably a timeout — confirm.
client = SoaringVectorClient("172.16.24.150", 8098, 1000)
print("health : ", client.health())

# Obtain a handle to the query index: create it (and register its alias) when
# the server does not have it yet, otherwise fetch the existing one.
index_name = "seo-query-v10dim"
if not client.exist(index_name):
    schema = dict(query=FieldType.STRING_TYPE, id=FieldType.STRING_TYPE)
    # NOTE(review): the positional 10 is presumably the vector dimension — confirm.
    api_index = client.create(
        index_name,
        "query search match",
        IndexType.FloatFlatIP,
        'query',
        schema,
        10,
        thread=12,
    )
    client.set_alias(index_name, "seo-phrase-match")
else:
    api_index = client.get(index_name)
    print("{} is exist".format(index_name))

print(api_index.info)


class QuerySimilarity(object):
    """Upload query vectors to the vector index and fetch similar queries.

    Each line of the query file is tab-separated; the first field is used as
    the query id and the last field as the query text.  Query vectors are the
    L2-normalised mean of the word2vec vectors of the jieba tokens.
    """

    # Dimensionality of the word vectors (and of the zero vector used for
    # out-of-vocabulary tokens).
    vector_dim = 10
    # Separator written between neighbour ids in the output files.
    # NOTE(review): the original code concatenated ids with an empty string and
    # then dropped the last character, which corrupts the last id; a single
    # space is assumed to be the intended separator — confirm with consumers.
    neighbor_sep = ' '
    # Serialises writes to the shared output file from the worker threads
    # (text-mode file objects are not guaranteed to be thread-safe).
    _write_lock = threading.Lock()

    def __init__(self,):
        """Load the word-vector model and build the query -> id mapping."""
        # self.query_path = r'data/test.txt'
        self.query_path = r'data/seo_search_word_copy.txt'
        self.w2c_path = r'resources_10dim/word2vec.model'
        self.query_features = r'resources/features.pkl'
        self.tables = r'resources/hashtables.pkl'
        self.table_num = 3
        self.Hashcode_fun = 6
        self.query2id = {}
        self.thread_num = 8

        print('加载词向量...')
        t1 = time.time()
        self.model = KeyedVectors.load(self.w2c_path, mmap='r')
        t2 = time.time()
        print('词向量加载时间:{:.2f}s'.format(t2-t1))
        with open(self.query_path, 'r', encoding='utf8') as fr:
            self.content = fr.readlines()

        for each in self.content:
            query_id, query = self._split_record(each)
            self.query2id[query] = query_id

    @staticmethod
    def _split_record(line):
        """Return ``(query_id, query)`` taken from the first and last tab field of *line*."""
        fields = line.strip().split('\t')
        return fields[0], fields[-1]

    def cosine_sim(self, x, y):
        """Return the cosine similarity of the numpy arrays *x* and *y*.

        When either input has zero norm the similarity is undefined; zero is
        returned instead of NaN.
        """
        num = x.dot(y.T)
        denom = np.linalg.norm(x) * np.linalg.norm(y)
        if denom == 0:
            return np.zeros_like(num, dtype=float)
        return num / denom

    def feature_extract(self, query):
        """Convert a query string to the mean of its token word vectors.

        :param query: raw query text; it is tokenised with ``jieba.lcut``.
        :return: 1-D numpy array of length ``vector_dim``.  Tokens missing from
            the model contribute a zero vector; a query that yields no tokens
            maps to the zero vector.
        """
        dim = self.vector_dim
        tokens = jieba.lcut(query)
        if not tokens:
            # np.mean over an empty array would give a scalar NaN.
            return np.zeros(dim)
        vec = np.array([
            self.model[word] if word in self.model else [0] * dim
            for word in tokens
        ])
        mean_vec = np.mean(vec, axis=0)
        if len(mean_vec) != dim:
            print('向量维度不是{}'.format(dim))
        return mean_vec

    def upload_data(self):
        """Vectorise every loaded query and put it into the index in batches.

        Batches are flushed once they hold more than 1000 entries; the running
        total of uploaded entries is kept in ``self.counter``.  Lines whose
        query text is empty are skipped.
        :return: None
        """
        self.counter = 0
        data_map_buffer = {}
        for each in self.content:
            query_id, query = self._split_record(each)
            if not query:
                continue
            current_feature = self.feature_extract(query)
            vector = self.l2_norm(current_feature).tolist()
            data = {'query': query, 'id': query_id}
            data_map_buffer[query] = IndexData(data, vector)
            if len(data_map_buffer) > 1000:
                self._flush(data_map_buffer)
                data_map_buffer = {}
        if data_map_buffer:
            self._flush(data_map_buffer)
        print('数据上传完成')

    def _flush(self, data_map_buffer):
        """Put one batch into the index and log the running upload count."""
        api_index.put(data_map_buffer)
        self.counter += len(data_map_buffer)
        logging.info('put ' + str(self.counter))

    def l2_norm(self, m):
        """Return *m* scaled to unit L2 norm along the last axis.

        The input is not modified.  Rows whose norm is zero are returned as
        zeros rather than NaN.
        """
        arr = np.asarray(m)
        if not np.issubdtype(arr.dtype, np.floating):
            arr = arr.astype(float)
        dist = np.sqrt((arr ** 2).sum(-1))[..., np.newaxis]
        # Dividing a zero row by 1 leaves it at zero and avoids 0/0.
        safe_dist = np.where(dist == 0, 1, dist)
        return arr / safe_dist

    def _neighbor_ids(self, phrase):
        """Return the ids of the queries most similar to *phrase*, joined by ``neighbor_sep``.

        The stored vector of *phrase* is read back from the index and used as
        the search vector.  For every result set the hits are sorted by score
        (highest first), the first hit is skipped and up to 30 following hits
        are kept.
        NOTE(review): the skipped first hit is presumably the phrase itself,
        and ``search(query, 0, 40)`` presumably means offset 0 / top 40 —
        confirm against the client API.
        :raises KeyError: if a hit key is missing from ``self.query2id``.
        """
        stored = api_index.get(phrase)
        api_vector = dict(stored.data.vector.vector).get(phrase).floatVector.values
        query = SearchQuery(vector=Vector(floatVector=FloatVector(values=api_vector)))
        res = api_index.search(query, 0, 40)
        ids = []
        for ret in res.result:
            items = sorted(ret.item, key=lambda v: v.score, reverse=True)
            ids.extend(self.query2id[item.key] for item in items[1:31])
        return self.neighbor_sep.join(ids)

    def download(self):
        """Single-threaded export: append the neighbour ids to every query line.

        Reads ``self.query_path`` and writes each stripped line followed by a
        tab and its neighbour ids to ``data/query_top30_20201021.txt``.
        """
        with open(self.query_path, 'r', encoding='utf8') as fr:
            content = fr.readlines()

        new_content = []
        for each in tqdm.tqdm(content):
            record = each.strip()
            phrase = record.split('\t')[-1]
            new_content.append(record + '\t' + self._neighbor_ids(phrase) + '\n')

        save_path = r'data/query_top30_20201021.txt'
        with open(save_path, 'w', encoding='utf8') as fw:
            fw.writelines(new_content)
        print('数据保存成功:{}'.format(save_path))

    def run(self, q, fw):
        """Worker loop: drain *q* and write one result line per sample to *fw*.

        :param q: ``queue.Queue`` of raw input lines.
        :param fw: writable text file shared by all workers.
        :return: None, once the queue is empty.

        A sample whose lookup fails is logged and skipped so that one bad
        record does not terminate the worker.
        """
        while True:
            # get_nowait avoids the race of "empty() then blocking get()":
            # another worker may take the last item in between, which would
            # leave this thread blocked forever.
            try:
                sample = q.get_nowait()
            except queue.Empty:
                return
            record = sample.strip()
            each_item = record.split('\t')
            try:
                neighbors = self._neighbor_ids(each_item[-1])
            except Exception:
                logging.exception('failed to process sample: %s', record)
                continue
            to_save = record + '\t' + neighbors + '\n'
            print(each_item[0])
            with self._write_lock:
                fw.write(to_save)

    def main(self, data_path):
        """Multi-threaded export of neighbour ids for every line of *data_path*.

        All lines are queued first, then ``self.thread_num`` worker threads run
        :meth:`run`; results are appended to ``data/query_top30_20201022.txt``
        and the throughput is printed at the end.
        :param data_path: path of the tab-separated query file to process.
        """
        q = queue.Queue()
        save_path = r'data/query_top30_20201022.txt'

        with open(data_path, 'r', encoding='utf8') as fr:
            content = fr.readlines()
        for d in tqdm.tqdm(content):
            q.put(d)
        print('数据放入队列完毕')

        t1 = time.time()
        print('数据预测中...')
        # The context manager guarantees the output is flushed and closed once
        # every worker has finished.
        with open(save_path, 'a', encoding='utf8') as fw:
            threads = [
                threading.Thread(target=self.run, args=(q, fw))
                for _ in range(self.thread_num)
            ]
            for t in threads:
                t.start()
            for t in threads:
                t.join()
        t2 = time.time()
        elapsed = t2 - t1
        # Guard against a zero interval on very small inputs.
        speed = len(content) / elapsed if elapsed > 0 else float('inf')
        print('处理速度:{:.4f}sample/s'.format(speed))
        print('数据写入完毕')


# Script entry point: run the multi-threaded neighbour export over the query file.
# The index must already hold the query vectors; QuerySimilarity.upload_data()
# populates it and is not invoked here.
if __name__ == '__main__':
    data_path = r'data/seo_search_word_copy.txt'
    qs = QuerySimilarity()
    qs.main(data_path)

 

posted @ 2020-10-22 12:35  今夜无风  阅读(242)  评论(0)  编辑  收藏  举报