python利用Trie(前缀树)实现搜索引擎中关键字输入提示(学习Hash Trie和Double-array Trie)

python利用Trie(前缀树)实现搜索引擎中关键字输入提示(学习Hash Trie和Double-array Trie)

 主要包括两部分内容:
(1)利用python中的dict实现Trie;
(2)按照darts-java的方法做python的实现Double-array Trie

比较:
(1)的实现相对简单,但在词典较大时,时间复杂度较高
(2)Double-array Trie是Trie高效实现,时间复杂度达到O(n),但是实现相对较难

 

最近遇到一个问题,希望对地名检索时,根据用户的输入,实时推荐用户可能检索的候选地名,并根据实时热度进行排序。这可以以归纳为一个Trie(前缀树)问题。



Trie在自然语言处理中非常常用,可以实现文本的快速分词、词频统计、字符串查询和模糊匹配、字符串排序、关键输入提示、关键字纠错等场景中。

这些问题都可以在单词树/前缀树/Trie来解决,关于Trie的介绍看【小白详解 Trie 树】这篇文章就够了

一、Hash实现Trie(python中的dict)

   github上有Trie实现关键字,实现Trie树的新增、删除、查找,并根据热度CACHED_THREHOLD在node节点对后缀进行缓存,以便提高对高频词的检索效率。本人在其代码上做了注解。
   并对其进行了测试,测试的数据包括了两列,包括关键词和频次。
【code】

#!/usr/bin/env python
# encoding: utf-8
"""
@date:    20131001
@version: 0.2
@author:  wklken@yeah.net
@desc:    搜索下拉提示,基于后台提供数据,建立数据结构(前缀树),用户输入query前缀时,可以提示对应query前缀补全

@update:
    20131001 基本结构,新增,搜索等基本功能
    20131005 增加缓存功能,当缓存打开,用户搜索某个前缀超过一定次数时,进行缓存,减少搜索时间
    20140309 修改代码,降低内存占用

@TODO:
    test case
    加入拼音的话,导致内存占用翻倍增长,要考虑下如何优化节点,共用内存

"""
#这是实现cache的一种方式,也可以使用redis/memcached在外部做缓存
#https://github.com/wklken/suggestion/blob/master/easymap/suggest.py
#一旦打开,search时会对每个节点做cache,当增加删除节点时,其路径上的cache会被清除,搜索时间降低了一个数量级
#代价:内存消耗, 不需要时可以关闭,或者通过CACHED_THREHOLD调整缓存数量

#开启
#CACHED = True
#关闭
CACHED = False

#注意,CACHED_SIZE >= search中的limit,保证search从缓存能获取到足够多的结果
CACHED_SIZE = 10
#被搜索超过多少次后才加入缓存
CACHED_THREHOLD = 10


############### start ######################

class Node(dict):
    def __init__(self, key, is_leaf=False, weight=0, kwargs=None):
        """
        @param key: 节点字符
        @param is_leaf: 是否叶子节点
        @param weight: 节点权重, 某个词最后一个字节点代表其权重,其余中间节点权重为0,无意义
        @param kwargs: 可传入其他任意参数,用于某些特殊用途
        """
        self.key = key
        self.is_leaf = is_leaf
        self.weight = weight

        #缓存,存的是node指针
        self.cache = []
        #节点前缀搜索次数,可以用于搜索query数据分析
        self.search_count = 0

        #其他节点无关仅和内容相关的参数
        if kwargs:
            for key, value in kwargs.iteritems():
                setattr(self, key, value)


    def __str__(self):
        return '<Node key:%s is_leaf:%s weight:%s Subnodes: %s>' % (self.key, self.is_leaf, self.weight, self.items())


    def add_subnode(self, node):
        """
        添加子节点

        :param node: 子节点对象
        """
        self.update({node.key: node})


    def get_subnode(self, key):
        """
        获取子节点

        :param key: 子节点key
        :return: Node对象
        """
        return self.get(key)


    def has_subnode(self):
        """
        判断是否存在子节点

        :return: bool
        """
        return len(self) > 0


    def get_top_node(self, prefix):
        """
        获取一个前缀的最后一个节点(补全所有后缀的顶部节点)

        :param prefix: 字符转前缀
        :return: Node对象
        """
        top = self

        for k in prefix:
            top = top.get_subnode(k)
            if top is None:
                return None
        return top


def depth_walk(node):
    """
    递归,深度优先遍历一个节点,返回每个节点所代表的key以及所有关键字节点(叶节点)

    @param node: Node对象
    """
    result = []
    if node.is_leaf:
        #result.append(('', node))    
        if len(node) >0:#修改,避免该前缀刚好是关键字时搜索不到
            result.append((node.key[:-1], node))
            node.is_leaf=False
            depth_walk(node)
        else:
            return [('', node)]

    if node.has_subnode():
        for k in node.iterkeys():
            s = depth_walk(node.get(k))
            #print k , s[0][0]
            result.extend([(k + subkey, snode) for subkey, snode in s])
        return result
    #else:
        #print node.key
        #return [('', node)]


def search(node, prefix, limit=None, is_case_sensitive=False):
    """
    搜索一个前缀下的所有单词列表 递归

    @param node: 根节点
    @param prefix: 前缀
    @param limit: 返回提示的数量
    @param is_case_sensitive: 是否大小写敏感
    @return: [(key, node)], 包含提示关键字和对应叶子节点的元组列表
    """
    if not is_case_sensitive:
        prefix = prefix.lower()

    node = node.get_top_node(prefix)
    #print 'len(node):' ,len(node)
    #如果找不到前缀节点,代表匹配失败,返回空
    if node is None:
        return []

    #搜索次数递增
    node.search_count += 1

    if CACHED and node.cache:
        return node.cache[:limit] if limit is not None else node.cache
    #print depth_walk(node)
    result = [(prefix + subkey, pnode) for subkey, pnode in depth_walk(node)]

    result.sort(key=lambda x: x[1].weight, reverse=True)

    if CACHED and node.search_count >= CACHED_THREHOLD:
        node.cache = result[:CACHED_SIZE]
    #print len(result)
    return result[:limit] if limit is not None else result

#TODO: 做成可以传递任意参数的,不需要每次都改    2013-10-13 done
def add(node, keyword, weight=0, **kwargs):
    """
    加入一个单词到树

    @param node: 根节点
    @param keyword: 关键词,前缀
    @param weight: 权重
    @param kwargs: 其他任意存储属性
    """
    one_node = node

    index = 0
    last_index = len(keyword) - 1
    for c in keyword:
        if c not in one_node:
            if index != last_index:
                one_node.add_subnode(Node(c, weight=weight))
            else:
                one_node.add_subnode(Node(c, is_leaf=True, weight=weight, kwargs=kwargs))
            one_node = one_node.get_subnode(c) 
        else:
            one_node = one_node.get_subnode(c)

            if CACHED:
                one_node.cache = []

            if index == last_index:
                one_node.is_leaf = True
                one_node.weight = weight
                for key, value in kwargs:
                    setattr(one_node, key, value)
        index += 1

def delete(node, keyword, judge_leaf=False):
    """
    从树中删除一个单词

    @param node: 根节点
    @param keyword: 关键词,前缀
    @param judge_leaf: 是否判定叶节点,递归用,外部调用使用默认值
    """

    # 空关键词,传入参数有问题,或者递归调用到了根节点,直接返回
    if not keyword:
        return

    top_node = node.get_top_node(keyword)
    if top_node is None:
        return

    #清理缓存
    if CACHED:
        top_node.cache = []

    #递归往上,遇到节点是某个关键词节点时,要退出
    if judge_leaf:
        if top_node.is_leaf:
            return
    #非递归,调用delete
    else:
        if not top_node.is_leaf:
            return

    if top_node.has_subnode():
        #存在子节点,去除其标志 done
        top_node.is_leaf = False
        return
    else:
        #不存在子节点,逐层检查删除节点
        this_node = top_node

        prefix = keyword[:-1]
        top_node = node.get_top_node(prefix)
        del top_node[this_node.key]
        delete(node, prefix, judge_leaf=True)


##############################
#  增补功能 读数据文件建立树 #
##############################

def build(file_path, is_case_sensitive=False):
    """
    从文件构建数据结构, 文件必须utf-8编码,可变更

    @param file_path: 数据文件路径,数据文件默认两列,格式“关键词\t权重"
    @param is_case_sensitive: 是否大小写敏感
    """
    node = Node("")
    f = open(file_path)
    for line in f:
        line = line.strip()
        if not isinstance(line,unicode):
            line = line.decode('utf-8')
        parts = line.split('\t')
        name = parts[0]
        if not is_case_sensitive:
            name = name.lower()
        add(node, name, int(parts[1]))
    f.close()
    return node

import time
if __name__ == '__main__':
    #print '============ test1 ==============='
    #n = Node("")
    #default weight=0, 后面的参数可以任意加,搜索返回结果再从node中将放入对应的值取出,这里放入一个othervalue值
    #add(n, u'he', othervalue="v-he")
    #add(n, u'her', weight=0, othervalue="v-her")
    #add(n, u'hero', weight=10, othervalue="v-hero")
    #add(n, u'hera', weight=3, othervalue="v-hera")

    #delete(n, u'hero')

    #print "search h: "
    #for key, node in search(n, u'h'):
        #print key, node, node.othervalue, id(node)
        #print key, node.weight

    #print "serch her: "
    #for key, node in search(n, u'her'):
        #print key, node, node.othervalue, id(node)
        #print key, node.weight
    start= time.clock()
    print '============ test2 ==============='
    tree = build("./shanxinpoi.txt", is_case_sensitive=False)
    print len(tree),'time:',time.clock()-start
    startline=time.clock()
    print u'search 秦岭'
    for key, node in search(tree, u'秦岭', limit=10):
        print key, node.weight
    print time.clock()-startline

二、Trie的Double-array Trie实现

Trie的Double-array Trie的实现参考【小白详解 Trie 树】和【双数组Trie树(DoubleArrayTrie)Java实现】

在看代码之前提醒几点:
(1)Comero有根据komiya-atsushi/darts-java,进行了Double-array Trie的python实现,komiya-atsushi的实现巧妙使用了文字的的编码,以文字的编码(一个汉字三个字符,每个字符0-256)作为【小白详解 Trie 树】中的字符编码。

(2)代码中不需要构造真正的Trie树,直接用字符串,构造对应node,因为words是排过序的,这样避免Trie树在构建过程中频繁从根节点开始重构

(3)实现中使用了了base[s]+c=t & check[t]=base[s],而非【小白详解 Trie 树】中的base[s]+c=t & check[t]=s

(4)komiya-atsushi实现Trie的构建、从词典文件创建,以及对构建Trie的本地化(保存base和check,下次打开不用再重新构建)

(5)本文就改了Comero中的bug,并对代码进行了注解。并参照dingyaguang117/DoubleArrayTrie(java)中的代码实现了输入提示FindAllWords方法。

(6)本文实现的FindAllWords输入提示方法没有用到词频信息,但是实现也不难

【code】

# -*- coding:utf-8 -*-

# base
# https://linux.thai.net/~thep/datrie/datrie.html
# http://jorbe.sinaapp.com/2014/05/11/datrie/
# http://www.hankcs.com/program/java/%E5%8F%8C%E6%95%B0%E7%BB%84trie%E6%A0%91doublearraytriejava%E5%AE%9E%E7%8E%B0.html 
# (komiya-atsushi/darts-java | 先建立Trie树,再构造DAT,为siblings先找到合适的空间)
# https://blog.csdn.net/kissmile/article/details/47417277
# http://nark.cc/p/?p=1480
#https://github.com/midnight2104/midnight2104.github.io/blob/58b5664b3e16968dd24ac5b1b3f99dc21133b8c4/_posts/2018-8-8-%E5%8F%8C%E6%95%B0%E7%BB%84Trie%E6%A0%91(DoubleArrayTrie).md

# 不需要构造真正的Trie树,直接用字符串,构造对应node,因为words是排过序的
# todo : error info
# todo : performance test
# todo : resize
# warning: code=0表示叶子节点可能会有隐患(正常词汇的情况下是ok的)
# 修正: 由于想要回溯字符串的效果,叶子节点和base不能重合(这样叶子节点可以继续记录其他值比如频率),叶子节点code: 0->-1
# 但是如此的话,叶子节点可能会与正常节点冲突? 找begin的使用应该是考虑到的?
#from __future__ import print_function
class DATrie(object):

    class Node(object):

        def __init__(self, code, depth, left, right):
            self.code = code
            self.depth = depth
            self.left = left
            self.right = right

    def __init__(self):
        self.MAX_SIZE = 2097152  # 65536 * 32
        self.base = [0] * self.MAX_SIZE
        self.check = [-1] * self.MAX_SIZE  # -1 表示空
        self.used = [False] * self.MAX_SIZE
        self.nextCheckPos = 0  # 详细 见后面->当数组某段使用率达到某个值时记录下可用点,以便下次不再使用
        self.size = 0  # 记录总共用到的空间

    # 需要改变size的时候调用,这里只能用于build之前。cuz没有打算复制数据.
    def resize(self, size):
        self.MAX_SIZE = size
        self.base = [0] * self.MAX_SIZE
        self.check = [-1] * self.MAX_SIZE
        self.used = [False] * self.MAX_SIZE

    # 先决条件是self.words ordered 且没有重复
    # siblings至少会有一个
    def fetch(self, parent):   ###获取parent的孩子,存放在siblings中,并记录下其左右截至
        depth = parent.depth

        siblings = []  # size == parent.right-parent.left
        i = parent.left 
        while i < parent.right: #遍历所有子节点,right-left+1个单词
            s = self.words[i][depth:]  #词的后半部分
            if s == '':
                siblings.append(
                    self.Node(code=-1, depth=depth+1, left=i, right=i+1)) # 叶子节点
            else:
                c = ord(s[0])  #字符串中每个汉字占用3个字符(code,实际也就当成符码),将每个字符转为数字 ,树实际是用这些数字构建的
                #print type(s[0]),c
                if siblings == [] or siblings[-1].code != c:
                    siblings.append(
                        self.Node(code=c, depth=depth+1, left=i, right=i+1)) # 新建节点
                else:  # siblings[-1].code == c
                    siblings[-1].right += 1   #已经是排过序的可以直接计数+1
            i += 1
        # siblings
        return siblings


    # 在insert之前,认为可以先排序词汇,对base的分配检查应该是有利的
    # 先构建树,再构建DAT,再销毁树
    def build(self, words):
        words = sorted(list(set(words)))  # 去重排序
        #for word in words:print word.decode('utf-8')
        self.words = words
        # todo: 销毁_root
        _root = self.Node(code=0, depth=0, left=0, right=len(self.words))  #增加第一个节点
        self.base[0] = 1
        siblings = self.fetch(_root)
        #for ii in  words: print ii.decode('utf-8')
        #print 'siblings len',len(siblings)
        #for i in siblings: print i.code
        self.insert(siblings, 0)  #插入根节点的第一层孩子
        # while False:  # 利用队列来实现非递归构造
            # pass
        del self.words
        print("DATrie builded.")


    def insert(self, siblings, parent_base_idx):
        """ parent_base_idx为父节点base index, siblings为其子节点们 """
        # 暂时按komiya-atsushi/darts-java的方案
        # 总的来讲是从0开始分配beigin]
        #self.used[parent_base_idx] = True

        begin = 0
        pos = max(siblings[0].code + 1, self.nextCheckPos) - 1 #从第一个孩子的字符码位置开始找,因为排过序,前面的都已经使用
        nonzero_num = 0  # 非零统计
        first = 0  

        begin_ok_flag = False  # 找合适的begin
        while not begin_ok_flag:
            pos += 1
            if pos >= self.MAX_SIZE:
                raise Exception("no room, may be resize it.")
            if self.check[pos] != -1 or self.used[pos]:   # check——check数组,used——占用标记,表明pos位置已经占用
                nonzero_num += 1  # 已被使用
                continue
            elif first == 0:
                self.nextCheckPos = pos  # 第一个可以使用的位置,记录?仅执行一遍
                first = 1

            begin = pos - siblings[0].code  # 第一个孩子节点对应的begin

            if begin + siblings[-1].code >= self.MAX_SIZE:
                raise Exception("no room, may be resize it.")

            if self.used[begin]:    #该位置已经占用
                continue

            if len(siblings) == 1:  #只有一个节点
                begin_ok_flag = True
                break

            for sibling in siblings[1:]:
                if self.check[begin + sibling.code] == -1 and self.used[begin + sibling.code] is False: #对于sibling,begin位置可用
                    begin_ok_flag = True
                else:
                    begin_ok_flag = False  #用一个不可用,则begin不可用
                    break

        # 得到合适的begin

        # -- Simple heuristics --
        # if the percentage of non-empty contents in check between the
        # index 'next_check_pos' and 'check' is greater than some constant value
        # (e.g. 0.9), new 'next_check_pos' index is written by 'check'.
        
        #从位置 next_check_pos 开始到 pos 间,如果已占用的空间在95%以上,下次插入节点时,直接从 pos 位置处开始查找成功获得这一层节点的begin之后得到,影响下一次执行insert时的查找效率
        if (nonzero_num / (pos - self.nextCheckPos + 1)) >= 0.95:
            self.nextCheckPos = pos

        self.used[begin] = True

        # base[begin] 记录 parent chr  -- 这样就可以从节点回溯得到字符串 
        # 想要可以回溯的话,就不能在字符串末尾节点记录值了,或者给叶子节点找个0以外的值? 0->-1
        #self.base[begin] = parent_base_idx     #【*】
        #print 'begin:',begin,self.base[begin]

        if self.size < begin + siblings[-1].code + 1:
            self.size = begin + siblings[-1].code + 1
        
        for sibling in siblings: #更新所有子节点的check     base[s]+c=t & check[t]=s
            self.check[begin + sibling.code] = begin

        for sibling in siblings:  # 由于是递归的情况,需要先处理完check
            # darts-java 还考虑到叶子节点有值的情况,暂时不考虑(需要记录的话,记录在叶子节点上)
            if sibling.code == -1:
                self.base[begin + sibling.code] = -1 * sibling.left - 1
            else:
                new_sibings = self.fetch(sibling)
                h = self.insert(new_sibings, begin + sibling.code) #插入孙子节点,begin + sibling.code为子节点的位置
                self.base[begin + sibling.code] = h #更新base所有子节点位置的转移基数为[其孩子最合适的begin]

        return begin


    def search(self, word):
        """ 查找单词是否存在 """
        p = 0  # root
        if word == '':
            return False
        for c in word:
            c = ord(c)
            next = abs(self.base[p]) + c
            # print(c, next, self.base[next], self.check[next])
            if next > self.MAX_SIZE:  # 一定不存在
                return False
            # print(self.base[self.base[p]])
            if self.check[next] != abs(self.base[p]):
                return False
            p = next
        
        # print('*'*10+'\n', 0, p, self.base[self.base[p]], self.check[self.base[p]])
        # 由于code=0,实际上是base[leaf_node->base+leaf_node.code],这个负的值本身没什么用
        # 修正:left code = -1
        if self.base[self.base[p] - 1] < 0 and self.base[p] == self.check[self.base[p] - 1] :  
            #print p
            return True
        else:  # 不是词尾
            return False


    def common_prefix_search(self, content):
        """ 公共前缀匹配 """
        # 用了 darts-java 写法,再仔细看一下
        result = []
        b = self.base[0]  # 从root开始
        p = 0
        n = 0
        tmp_str = ""
        for c in content:
            c = ord(c)
            p = b
            n = self.base[p - 1]      # for iden leaf

            if b == self.check[p - 1] and n < 0:
                result.append(tmp_str)

            tmp_str += chr(c)
            #print(tmp_str )
            p = b + c   # cur node
            
            if b == self.check[p]:
                b = self.base[p]  # next base
            else:                 # no next node
                return result

        # 判断最后一个node
        p = b
        n = self.base[p - 1]

        if b == self.check[p - 1] and n < 0:
            result.append(tmp_str)

        return result

    def Find_Last_Base_index(self, word):
        b = self.base[0]  # 从root开始
        p = 0
        #n = 0
        #print len(word)
        tmp_str = ""
        for c in word:
            c = ord(c)
            p = b
            p = b + c   # cur node, p is new base position, b is the old
            
            if b == self.check[p]:
                tmp_str += chr(c)
                b = self.base[p]  # next base
            else:                 # no next node
                return -1
        #print '====', p, self.base[p], tmp_str.decode('utf-8')
        return p

    def GetAllChildWord(self,index):
        result = []
        #result.append("")
       # print self.base[self.base[index]-1],'++++'
        if self.base[self.base[index]-1] <= 0 and self.base[index] == self.check[self.base[index] - 1]:
            result.append("")
            #return result
        for i in range(0,256):
            #print(chr(i))
            if self.check[self.base[index]+i]==self.base[index]:
                #print self.base[index],(chr(i)),i
                for s in self.GetAllChildWord(self.base[index]+i):
                    #print s
                    result.append( chr(i)+s)
        return result

    def FindAllWords(self, word):
        result = []
        last_index=self.Find_Last_Base_index(word)
        if last_index==-1:
            return result
        for end in self.GetAllChildWord(last_index):
            result.append(word+end)
        return result

    def get_string(self, chr_id):
        """ 从某个节点返回整个字符串, todo:改为私有 """
        if self.check[chr_id] == -1:
            raise Exception("不存在该字符。")
        child = chr_id
        s = []
        while 0 != child:
            base = self.check[child]
            print(base, child)
            label = chr(child - base)
            s.append(label)
            print(label)
            child = self.base[base]
        return "".join(s[::-1])


    def get_use_rate(self):
        """ 空间使用率 """
        return self.size / self.MAX_SIZE

if __name__ == '__main__':
    words = ["一举","一举一动",'11',
            "一举成名",
            "一举成名天下知","洛阳市西工区中州中路","人民东路2号","中州东",
            "洛阳市","洛阳","洛神1","洛神赋","万科","万达3","万科翡翠","万达广场",
            "洛川","洛川苹果","商洛","商洛市","商朝","商业","商业模","商业模式",
            "万能",
            "万能胶"]

    #for word in words:print [word]  #一个汉字的占用3个字符,
    words=[]
    for line in open('1000.txt').readlines():
    #    #print line.strip().decode('utf-8')
        words.append(line.strip())
    
    datrie = DATrie()
    datrie.build(words)
    #for line in open('1000.txt').readlines():
    #    print(datrie.search(line.strip()),end=' ')
    #print('-'*10)
    #print(datrie.search("景华路"))
    #print('-'*10)
    #print(datrie.search("景华路号"))
    
    # print('-'*10)
    #for item in datrie.common_prefix_search("商业模式"): print(item.decode('utf-8'))
    #for item in datrie.common_prefix_search("商业模式"):print item.decode('utf-8')
    # print(datrie.common_prefix_search("一举成名天下知"))
    #print(datrie.base[:1000])
    # print('-'*10)
    # print(datrie.get_string(21520))
    #index=datrie.Find_Last_Base_index("商业")
    #print(index),'-=-=-='
    #print datrie.search("商业"),datrie.search("商业"),datrie.search("商业模式")
    #print index, datrie.check[datrie.base[index]+230],datrie.base[index]
    for ii in  datrie.FindAllWords('中州中路'):print ii.decode('utf-8')
    #print(datrie.Find_Last_Base_index("一举")[2].decode('utf-8'))
#print()

测试数据是洛阳地址1000.txt

最后欢迎参与讨论。



参考:


小白详解Trie树:https://segmentfault.com/a/1190000008877595

Hash实现Trie(python中的dict)(源码):https://github.com/wklken/suggestion/blob/master/easymap/suggest.py

双数组Trie树(DoubleArrayTrie)Java实现(主要理解):http://www.hankcs.com/program/java/%E5%8F%8C%E6%95%B0%E7%BB%84trie%E6%A0%91doublearraytriejava%E5%AE%9E%E7%8E%B0.html

Comero对DoubleArrayTrie的python实现(源码):https://github.com/helmz/toy_algorithms_in_python/blob/master/double_array_trie.py

DoubleArrayTrie树的Tail压缩,java实现(源码):https://github.com/dingyaguang117/DoubleArrayTrie/blob/master/src/DoubleArrayTrie.java#L348

搜索时的动态提示:https://mp.weixin.qq.com/s/fT2LJ-skNEdh89DnH9FRxw

posted on 2018-12-11 14:06  米仓山下  阅读(3792)  评论(5编辑  收藏  举报

导航