Trie

Trie,又称单词查找树,Trie 树,是一种树形结构,是一种哈希树的变种。典型应
用是用于统计,排序和保存大量的字符串(但不仅限于字符串),所以经常被搜索
引擎系统用于文本词频统计。它的优点是:利用字符串的公共前缀来减少查询时
间,最大限度地减少无谓的字符串比较,查询效率比哈希树高。 ——百度百科


Trie 树是一种性能优异的哈希树,是一种常用的树状结构,常用于字符串保存、查找等操作,由于其在树上查找字符串的方式与查字典相似,所以常被称为字典树。它使用不同字符串的公共前缀减少查询时间和存储空间,减少字符串比较,可以快速的在 \(O(n)\) 的时间内于任意多的字符串中保存或是查找字符串。

想到我们查英语字典的时候,我们对于想查的单词(如 \(animal\)),我们会先在整本词典中查找它的第一位字母 \(a\),再在其第一位字母\(a\)以下的区域查找第二位字母 \(n\) 所在位置……再在第五位字母 \(a\) 以下的区域查找最后一位字母 \(l\),就可以在词典几千个单词中查找单词长度 \(6\) 次找出需要查询的单词信息。

这就是 Trie 查找的原理,写入的方式也比较相似。

如果我们将一些的字符串拉成链,全部挂在一个点上,那么就可以形成一棵庞大的树。然后若我们将能合并的结点都合并(如 \(abcd\)\(abce\) 同时挂在根节点下,我们可以考虑将两序列都拥有的 \(abc\) 三节点合并,在 \(c\) 结点下挂 \(d\)\(e\) 两子节点),最后我们可以让这课庞大的树的占用空间大幅度下降,并且可以像上述查字典一样的方式一层一层向下查找来找到任意一个开始时放入的字符串。这种改进后的数据结构就是Trie

如果你没听懂,没有关系,请看下面的解释。

在最开始的时候我们现在图中找到一个起点作为树的起点(这里记作 \(0\) 号点)。

pict-1

如果现在我们加入第一个单词 \(he\),就该单词拉成一条链挂在 \(0\) 号点下。

pict-2

再加入 \(she\),由于 \(she\)\(he\) 没有共同前缀,所以 \(she\) 的处理方法与 \(he\) 相同。

pict-3

如果加入 \(hi\),从根节点开始向下查找,发现根节点已拥有 \(h\) 结点作为孩子,那么通过 \(h\) 向下查找,发现 \(h\) 并没有 \(i\) 子节点,所以在 \(h\) 下面挂上一个 \(i\) 节点,那么 \(hi\) 就与 \(he\) 共用一个 \(h\) 的前缀,如下图所示。

pict-4

再插入 \(sha\)\(sad\) 作示范,方法与上面相同

pict-5

插入时只会注意前缀相同的部分,后面即使有相同的字母也不会产生影响

可以发现,从根节点开始(不包括根节点),任意选择一条通往叶子节点的路径,路径上经过的字符来连起来可以组成输入的一个字符串。同时每一个节点不可能出现两个拥有相同字符的孩子节点,且每个字符串在树上只有一种表达方式。


细节

如果此时我需要插入一个 \(her\),树会变成这样:

pict-6

那么如何确定这个树上到底有没有 \(he\) 这个单词呢?

添加结束标记

我们对每一个点加入一个布尔标记,记录其是否为单词结尾。为 \(True\) 表示从根节点到这里的路径表示的是一个单词,如果为 \(False\) 表示这不是一个单词。

下图将被记为单词结尾的结点标记成红色。那么树就变成这样的了。

pict-7

这样我们就可以区分出树上的每一个字符串了。


代码

接下来通过一些代码来讲解 Trie 上执行操作的方法。

定义

struct node
{
    bool tail;
    int visit;
    int child[26];
};
std::vector<node> trie;

这里的结构体代表的是 Trie 上每一个结点的类型。这里的 \(tail\) 存放的布尔类型表示该节点为几个字符串的结尾,\(visit\) 表示该节点表示的字符串被访问过几次(如果该节点表示的字符串多于一个,那么 \(visit\) 将成倍增加),\(child[c]\) 表示该节点的 \(c\) 孩子的数组下标(如果不存在该子节点,指向\(0\))。最后的 \(vector\) 就是 Trie 树的表达方式了,用 \(vector\) 存放结点可以更有效地节省空间。

加入字符串

void add(std::string s)
{
    int p=0;
    int size=s.size();
    for(int sp=0;sp<size;sp++)
    {
        int c=s[sp]-'a';
        if(trie[p].child[c]==0)
        {
            trie.push_back({});
            trie[p].child[c]=trie.size()-1;
        }
        p=trie[p].child[c];
    }
    trie[p].tail++;
    return;
}

形参\(s\)就是我们需要加入到 Trie 中的字符串,我们用 \(sp\) 遍历字符串,对于每一个 \(s[sp]\) 都会有一个 \(s[sp+1]\) 的子节点。我们从根节点开始向下搜索,如果当前结点具有我们正在匹配的这位字符,则遍历到对应子节点,否则新建子节点,并沿该新建节点继续向下遍历直至整个字符串的字符全部匹配完,在最后一个节点上将结尾标记加 \(1\)

查找字符串

int find(std::string s)
{
    int point=0;
    int sp=0;
    int size=s.size();
    while(sp<size)
    {
        int c=s[sp]-'a';
        if(trie[point].child[c]!=0)
        {
            point=trie[point].child[c];
            trie[point].visit+=trie[point].tail;
            sp++;
        }
        else return -1;
    }
    if(trie[point].tail==0) return -1;
    return trie[point].visit;
}

从根节点开始向下匹配字符串,在根节点子节点中找出 \(s[1]\) 对应的子节点,再沿该子节点向下找出 \(s[2]\) 对应的子节点……直至匹配完,返回最后的节点上的访问值(如果最后的这个结点并非一个单词的结尾,返回 \(-1\) 表示没有找到该字符串)。如果中途发现在某一处匹配子节点失败,则返回 \(-1\) 表示没有找到这个串。


代码背景

P2580

#include <iostream>
#include <vector>
#include <string>
#include <queue>
#include <map>

struct node
{
    int tail;
    int visit;
    int child[26];
};
std::vector<node> trie;

void add(std::string s)
{
    int p=0;
    int size=s.size();
    for(int sp=0;sp<size;sp++)
    {
        int c=s[sp]-'a';
        if(trie[p].child[c]==0)
        {
            trie.push_back({});
            trie[p].child[c]=trie.size()-1;
        }
        p=trie[p].child[c];
    }
    trie[p].tail++;
    return;
}

int find(std::string s)
{
    int point=0;
    int sp=0;
    int size=s.size();
    while(sp<size)
    {
        int c=s[sp]-'a';
        if(trie[point].child[c]!=0)
        {
            point=trie[point].child[c];
            trie[point].visit+=trie[point].tail;
            sp++;
        }
        else return -1;
    }
    if(trie[point].tail==0) return -1;
    return trie[point].visit;
}

int cost[305];

int main()
{
    std::ios::sync_with_stdio(false);
    int n,m;
    std::cin>>n;
    trie.resize(1);
    for(int i=0;i<n;i++)
    {
        std::string s;
        std::cin>>s;
        add(s);
    }
    std::cin>>m;
    for(int i=0;i<m;i++)
    {
        std::string s;
        std::cin>>s;
        int t=find(s);
        if(t==-1) printf("WRONG\n");
        else if(t==1) printf("OK\n");
        else printf("REPEAT\n");
    }
    return 0;
}
posted @ 2020-01-20 20:40  Macesuted  阅读(1651)  评论(0编辑  收藏  举报