算法笔记--字典树(trie 树)&& ac自动机 && 可持久化trie

字典树

简介:字典树,又称单词查找树,Trie树,是一种树形结构,是哈希树的变种。

优点:利用字符串的公共前缀来减少查询时间,最大限度地减少无谓的字符串比较。

性质:根节点不包含字符,除根节点外每一个节点都只包含一个字符; 从根节点到某一节点,路径上经过的字符连接起来,为该节点对应的字符串; 每个节点的所有子节点包含的字符都不相同。

操作:

记 trie[i][j] 表示第 i 个节点的第 j 个儿子(j = 字符 - 'a')的节点编号,tot 为已分配的节点个数。根节点编号为 0,新节点从 1 开始编号,所以 trie[i][j] == 0 表示这个儿子还不存在。

插入:

// Inserts the global C string `s` (lowercase letters only) into the array trie.
// Node 0 is the root, a child slot holding 0 means "no child yet", and new
// nodes are numbered 1, 2, ... through the global counter `tot`.
void insert() {
    const int n = strlen(s);
    int cur = 0;
    for (int k = 0; k < n; ++k) {
        const int c = s[k] - 'a';
        if (trie[cur][c] == 0) {
            trie[cur][c] = ++tot;   // allocate a fresh node for this edge
        }
        cur = trie[cur][c];
        // sum[cur]++;   // enable to count how many words pass through this prefix
    }
    // vis[cur] = true;  // enable to mark that a complete word ends here
}

查询:

①如果查询的是某个单词是否出现,可以开一个bool类型的数组vis[i]表示第i个节点是否为单词的结尾

②如果查询的是某个前缀出现了多少次,可以开一个int类型的数组sum[i]表示以第i个节点为结尾的前缀出现的次数

// Follows the global string `s` down the trie. Returns 0 as soon as an edge is
// missing; otherwise returns sum[] of the node reached (use vis[] instead to
// ask "was exactly this word inserted?").
int find() {
    const int n = strlen(s);
    int cur = 0;
    for (int k = 0; k < n; ++k) {
        const int child = trie[cur][s[k] - 'a'];
        if (child == 0) return 0;
        cur = child;
    }
    return sum[cur];   // or vis[cur]
}

清空:每次新开一个节点时,将它的儿子都清空(适用于多组数据、各组规模之和不超过某个值的情况,此时不能每组都 memset 整个数组)。

指针版:

// Pointer-based trie node. `cnt` is bumped by Insert() on every node along a
// word's path; next[c] is the child for letter 'a' + c (nullptr if absent).
struct node {
    int cnt;
    node *next[26];
    node() : cnt(0) {
        for (int c = 0; c < 26; ++c) next[c] = nullptr;
    }
};
node *rt;
void init() {
    rt = new node();
}
void Insert() {
    int len = strlen(s);
    node *p = rt;
    for (int i = 0; i < len; i++) {
        int id = s[i] - 'a';
        if(p -> next[id] == NULL) p -> next[id] = new node();
        p = p -> next[id];
        p -> cnt ++;
    }
}
int Find() {
    int len = strlen(s);
    node *p = rt;
    for (int i = 0; i < len; i++) {
        int id = s[i] - 'a';
        if(p -> next[id] == NULL) return 0;
        p = p -> next[id];
    }
    return p -> cnt;
}

例题1:hdu 1251

代码:

#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define pb push_back
#define mem(a, b) memset(a, b, sizeof(a))

// Array trie for HDU 1251. Node 0 is the root, trie[u][c] == 0 means
// "no child", and sum[u] = number of inserted words passing through node u.
const int N = 5e5 + 5;      // node capacity
int trie[N][26];
int sum[N];
int tot = 0;                // index of the last allocated node
char s[55];                 // current line; main() fills it without a terminator

// Inserts the first `len` characters of `s` (lowercase letters).
void Insert(int len) {
    int u = 0;
    for (int k = 0; k < len; ++k) {
        int &child = trie[u][s[k] - 'a'];
        if (!child) child = ++tot;
        u = child;
        ++sum[u];
    }
}

// Returns how many inserted words start with the first `len` characters of `s`.
int Find(int len) {
    int u = 0;
    for (int k = 0; k < len; ++k) {
        u = trie[u][s[k] - 'a'];
        if (u == 0) return 0;   // the prefix leaves the trie
    }
    return sum[u];
}
// Reads the dictionary (one word per line, ended by an empty line), then
// answers one prefix query per line with the number of words having it.
// Each line is collected into `s` without a terminator and passed by length.
int main() {
    // Length of the line being read. Deliberately not named `tot`: the global
    // `tot` is the trie's node counter used by Insert().
    int len = 0;
    int c;
    // Phase 1: dictionary words, terminated by the first empty line.
    while ((c = getchar()) != EOF) {
        if (c == '\r') continue;                    // tolerate CRLF line endings
        if (c != '\n') {
            // never write past the end of s[]; overly long lines are truncated
            if (len < static_cast<int>(sizeof(s))) s[len++] = static_cast<char>(c);
            continue;
        }
        if (len == 0) break;
        Insert(len);
        len = 0;
    }
    // Phase 2: queries, one per line.
    len = 0;
    while ((c = getchar()) != EOF) {
        if (c == '\r') continue;
        if (c != '\n') {
            if (len < static_cast<int>(sizeof(s))) s[len++] = static_cast<char>(c);
            continue;
        }
        printf("%d\n", Find(len));
        len = 0;
    }
    // The last query may not be followed by a newline; answer it as well.
    if (len > 0) printf("%d\n", Find(len));
    return 0;
}
View Code
#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define pb push_back
#define mem(a,b) memset(a,b,sizeof(a))

// Line-based HDU 1251 variant: same counting trie, but `s` holds a
// NUL-terminated line, so insert()/find() walk it up to the terminator.
const int N=5e5+5;
int trie[N][26];
int sum[N];     // sum[u]: number of inserted words whose path passes through u
int tot=0;      // last allocated node index; node 0 is the root
char s[15];     // current line

// Adds the word in `s`, counting it on every prefix node.
void insert(){
    int cur=0;
    for(const char *p=s;*p;++p){
        const int c=*p-'a';
        if(trie[cur][c]==0) trie[cur][c]=++tot;
        cur=trie[cur][c];
        sum[cur]+=1;
    }
}

// Returns how many inserted words have `s` as a prefix.
int find(){
    int cur=0;
    for(const char *p=s;*p;++p){
        cur=trie[cur][*p-'a'];
        if(!cur) return 0;
    }
    return sum[cur];
}
// Reads one line of stdin into the global buffer `s`, without its line ending
// ("\n" or "\r\n"). Returns false at end of input. If the line does not fit,
// the remaining characters are discarded so the next call starts on the next
// line. This replaces gets(), which was removed in C++14 because it cannot
// bound the write into `s`.
static bool read_line() {
    if (fgets(s, sizeof(s), stdin) == NULL) return false;
    if (strchr(s, '\n') == NULL) {      // over-long line, or last line at EOF
        int c;
        while ((c = getchar()) != '\n' && c != EOF) {}
    }
    s[strcspn(s, "\r\n")] = '\0';
    return true;
}

// Dictionary words come first, ended by an empty line; each following line is
// a prefix query answered with the number of matching words.
int main(){
    while(read_line()){
        if(s[0]=='\0')break;
        insert();
    }
    while(read_line()){
        printf("%d\n",find());
    }
    return 0;
}
View Code

例题2:SPOJ Phone List

代码:

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb push_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pii pair<int, int>
#define piii pair<pii, int>
#define mem(a, b) memset(a, b, sizeof(a))
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
// Local-testing helper: redirect stdin/stdout to in.txt/out.txt.
// The second stream must be `stdout`; a misspelled stream name only fails to
// compile once the macro is actually used. Note that this macro takes over
// the name `fopen` for the rest of the file.
#define fopen freopen("in.txt", "r", stdin);freopen("out.txt", "w", stdout);
//head

const int N = 1e4 + 5;
// Digit trie for the current test case (sized N*10, presumably assuming at
// most 10 digits per number). sum[u] counts the numbers that END at node u.
int trie[N*10][10], sum[N*10], tot = 0;
string s[N];   // phone numbers of the current test case (1-indexed)

// Inserts one phone number (digits '0'..'9') and reports whether the list
// stops being prefix-free:
//  - every digit followed an existing edge -> `num` is a prefix of (or equal
//    to) an earlier number;
//  - some node on the path already ends a number -> an earlier number is a
//    prefix of `num`.
// Returns true if either happens. Taken by const reference to avoid copying
// each string (and to stop shadowing the global array `s`).
bool insert(const string& num) {
    int rt = 0;
    bool allExisting = true, hitsEnd = false;
    for (size_t i = 0; i < num.size(); i++) {
        int id = num[i] - '0';
        if(!trie[rt][id]) trie[rt][id] = ++tot, allExisting = false;
        rt = trie[rt][id];
        if(sum[rt]) hitsEnd = true;
    }
    sum[rt] ++;
    return allExisting || hitsEnd;
}
// For each test case, reads n phone numbers and prints YES if no number is a
// prefix of another, NO otherwise.
int main() {
    fio;
    int T, n;
    cin >> T;
    while(T--) {
        cin >> n;
        for (int i = 1; i <= n; i++) cin >> s[i];
        // Clear only the nodes the previous test case allocated (0..tot)
        // instead of memset-ing the whole arrays every time; nodes above tot
        // are still zero.
        for (int u = 0; u <= tot; u++) {
            memset(trie[u], 0, sizeof(trie[u]));
            sum[u] = 0;
        }
        tot = 0;
        bool bad = false;
        for (int i = 1; i <= n && !bad; i++) bad = insert(s[i]);
        // Output through cout, matching the iostream setup done by `fio`.
        cout << (bad ? "NO" : "YES") << '\n';
    }
    return 0;
}
View Code

例题3:Codeforces 948 D - Perfect Security 

思路:构建01字典树求最小异或值

代码:

#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define pb push_back
#define mem(a, b) memset(a, b, sizeof(a))

const int N = 3e5 + 5;
int a[N], p[N], tot = 0;          // a: message values, p: key values (1-indexed)
// Binary trie over bits 30..0; cnt[u] = number of still-unused keys below u.
int trie[N*30][2], cnt[N*30];

// Adds x to the multiset stored in the trie.
void insert(int x) {
    int u = 0;
    for (int b = 30; b >= 0; --b) {
        int &child = trie[u][(x >> b) & 1];
        if (!child) child = ++tot;
        ++cnt[child];
        u = child;
    }
}

// Removes and returns the stored value v minimising v ^ x: at each bit follow
// x's own bit while unused keys remain there, otherwise take the other branch.
int find(int x) {
    int u = 0, value = 0;
    for (int b = 30; b >= 0; --b) {
        int bit = (x >> b) & 1;
        if (cnt[trie[u][bit]] < 1) bit ^= 1;   // preferred branch exhausted
        u = trie[u][bit];
        --cnt[u];
        value |= bit << b;
    }
    return value;
}
// Reads n message values and n key values, stores the keys in the trie, then
// for each message value (in order) consumes the key giving the smallest xor
// and prints that xor, space-separated on one line.
int main() {
    int n;
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) scanf("%d", a + i);
    for (int i = 1; i <= n; ++i) {
        scanf("%d", p + i);
        insert(p[i]);
    }
    for (int i = 1; i <= n; ++i) {
        const int out = find(a[i]) ^ a[i];
        printf(i == n ? "%d\n" : "%d ", out);
    }
    return 0;
}
View Code

例题4: HDU 4825 Xor Sum

思路:构建01字典树求最大异或值

代码:

#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define pb push_back
#define mem(a, b) memset(a, b, sizeof(a))

// HDU 4825 Xor Sum: for every query S print the stored value K maximising K ^ S.
// NOTE(review): the problem is understood to allow values up to 2^32 (confirm
// against the statement). Such values overflow a 32-bit int and need bits
// 31/32 in the trie, so values are handled as long long and the trie keeps
// bits 32..0.
const int N = 1e5 + 5;
const int HIGH_BIT = 32;                  // most significant bit stored in the trie
int trie[N*(HIGH_BIT+1)][2];              // at most one node per bit per inserted value
long long a[N];
int tot = 0;                              // last allocated node; node 0 is the root

// Inserts x into the binary trie, most significant bit first.
void insert(long long x) {
    int rt = 0;
    for (int i = HIGH_BIT; i >= 0; i--) {
        int id = (x>>i)&1;
        if(!trie[rt][id]) trie[rt][id] = ++tot;
        rt = trie[rt][id];
    }
}

// Returns max(x ^ k) over all inserted k: at each bit take the branch
// opposite to x's bit whenever it exists.
long long find(long long x){
    int rt = 0;
    long long res = 0;
    for (int i = HIGH_BIT; i >= 0; i--) {
        int id = (x>>i)&1;
        if(trie[rt][!id]) res |= 1LL<<i, rt = trie[rt][!id];
        else rt = trie[rt][id];
    }
    return res;
}

int main() {
    int T, n, m;
    long long t;
    scanf("%d", &T);
    for (int cas = 1; cas <= T; cas++) {
        printf("Case #%d:\n", cas);
        scanf("%d%d", &n, &m);
        // Clear only the nodes the previous case used instead of memset-ing
        // the whole multi-megabyte trie for every test case.
        for (int u = 0; u <= tot; u++) trie[u][0] = trie[u][1] = 0;
        tot = 0;
        for (int i = 1; i <= n; i++) scanf("%lld", &a[i]), insert(a[i]);
        // find(t) is the best xor, so find(t) ^ t recovers the chosen element.
        while(m--) scanf("%lld", &t), printf("%lld\n", find(t)^t);
    }
    return 0;
}
View Code

例题5:Codeforces 706 D - Vasiliy's Multiset

思路:构建01字典树求最大异或值

代码:

#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define pb push_back
#define mem(a, b) memset(a, b, sizeof(a))

const int N = 2e5 + 5;
// Binary trie over bits 30..0 for a multiset; cnt[u] = number of stored values
// whose path passes through node u (drops back to 0 once they are deleted).
int trie[N*31][2], cnt[N*31], tot = 0;

// Adds one copy of x.
void insert(int x) {
    int u = 0;
    for (int b = 30; b >= 0; --b) {
        int &child = trie[u][(x >> b) & 1];
        if (child == 0) child = ++tot;
        cnt[child] += 1;
        u = child;
    }
}

// Removes one copy of x; its path is assumed to exist (x currently stored).
void delet(int x) {
    for (int b = 30, u = 0; b >= 0; --b) {
        u = trie[u][(x >> b) & 1];
        cnt[u] -= 1;
    }
}

// Returns max(x ^ y) over the values y currently stored, choosing at each bit
// the branch opposite to x's bit whenever it still holds live values.
int find(int x) {
    int u = 0, best = 0;
    for (int b = 30; b >= 0; --b) {
        const int want = ((x >> b) & 1) ^ 1;
        if (cnt[trie[u][want]] > 0) {
            best |= 1 << b;
            u = trie[u][want];
        } else {
            u = trie[u][want ^ 1];
        }
    }
    return best;
}
// Processes q operations "+ x", "- x", "? x" on the multiset; 0 is inserted
// before any operation (presumably because the multiset initially holds 0).
int main() {
    int q, t;
    char s[5];
    scanf("%d", &q);
    insert(0);
    while (q--) {
        scanf("%s %d", s, &t);
        switch (s[0]) {
        case '+': insert(t); break;
        case '-': delet(t); break;
        case '?': printf("%d\n", find(t)); break;
        }
    }
    return 0;
}
View Code

参考:

http://www.cnblogs.com/TheRoadToTheGold/p/6290732.html

ac自动机

KMP 是单模式串字符串匹配,ac自动机是多模式串字符串匹配

ac自动机是建立在字典树的基础上的,采用的是KMP的思想,

与KMP不同的是,AC自动机在字典树上构建失配指针fail来实现跳跃:节点u的fail指针指向这样一个节点——它代表的字符串是u所代表字符串的最长真后缀,并且在字典树中存在(不一定是某个完整的模式串)。

1.建树与字典树相同

2.构建失配指针fail

在求当前节点 u 的失配指针时,要保证深度比 u 小的节点的 fail 指针都已经求出来了(按 BFS 的顺序处理即可保证),

假设当前节点的父亲为p, p连向u的边的编号是i

那么看tree[fail[p]][i]存在否?

如果存在,那么fail[u] = tree[fail[p]][i]

如果不存在,那么看tree[fail[fail[p]]][i]存在否?以此类推,一直往上找,直到找到根节点结束。(在用bfs构建失配指针时,这一步可以采用递推路径压缩优化成O(1),不用一直往上找。)

3.查找,一开始查找到当前位置的最大后缀,统计模式串的个数,然后跳到上一个后缀,统计个数,以此类推,直到为空为止。

模板:

struct AC_automation {
    int tree[N][26], tot = 0;
    int cnt[N];
    int fail[N];
    void init() {
        for (int i = 0; i <= tot; i++) {
            for (int j = 0; j < 26; j++) {
                tree[i][j] = 0;
            }
            cnt[i] = 0;
            fail[i] = 0;
        }
        tot = 0;
    }
    void insert(char s[]) {
        int rt = 0;
        for (int i = 0; s[i]; i++) {
            int id = s[i]-'a';
            if(!tree[rt][id]) tree[rt][id] = ++tot;
            rt = tree[rt][id];
        }
        cnt[rt]++;
    }
    void build() {
        queue<int> q;
        for (int i = 0; i < 26; i++) if(tree[0][i]) q.push(tree[0][i]);
        while(!q.empty()) {
            int u = q.front();
            q.pop();
            for (int i = 0; i < 26; i++) {
                if(tree[u][i]) {
                    fail[tree[u][i]] = tree[fail[u]][i];
                    q.push(tree[u][i]);
                }
                else tree[u][i] = tree[fail[u]][i];
            }
        }
    }
    int query(char t[]) {
        int rt = 0, res = 0;
        for (int i = 0; t[i]; i++) {
            int id = t[i] - 'a';
            rt = tree[rt][id];
            for (int j = rt; j && ~cnt[j]; j = fail[j]) res += cnt[j], cnt[j] = -1;
        }
        return res;
    }
};

参考:https://www.luogu.org/blog/42196/qiang-shi-tu-xie-ac-zi-dong-ji

可持久化trie

和可持久化线段树(主席树)差不多:每次插入只新建从根到叶子的一条路径,其余节点都与旧版本共享。

01可持久化trie模板:

// Builds version `rt` on top of version `old` by inserting v (bits 24..0).
// Only a new root and one new node per bit are allocated; at every level the
// sibling subtree is shared with `old`. sz[u] = number of values in u's subtree.
void update(int old, int &rt, int v) {
    int now = rt = ++tot;
    for (int b = 24; b >= 0; --b) {
        const int bit = (v >> b) & 1;
        trie[now][bit ^ 1] = trie[old][bit ^ 1];   // share the untouched side
        trie[now][bit] = ++tot;                     // fresh node on v's path
        now = trie[now][bit];
        old = trie[old][bit];
        sz[now] = sz[old] + 1;
    }
}
// Max of x ^ v over the values v inserted in versions (l, r]. l and r are
// version roots (0 = empty version); sz differences isolate exactly those values.
int query(int l, int r, int x) {
    int best = 0;
    for (int b = 24; b >= 0; --b) {
        const int want = ((x >> b) & 1) ^ 1;   // the bit that makes this xor bit 1
        const bool available = sz[trie[r][want]] - sz[trie[l][want]] != 0;
        const int go = available ? want : (want ^ 1);
        if (available) best |= 1 << b;
        l = trie[l][go];
        r = trie[r][go];
    }
    return best;
}
//查询区间中某个数和x亦或第k小 
int query(int l, int r, int k, int x) {
    int ans = 0;
    for (int i = 24; i >= 0; --i) {
        int id = (x>>i)&1;
        int num = cnt[trie[r][id]] - cnt[trie[l][id]];
        if(k <= num) l = trie[l][id], r = trie[r][id];
        else ans |= 1<<i,  l = trie[l][id^1], r = trie[r][id^1], k -= num;
    }
    return ans;
}

3261: 最大异或和

注意 root[0] 版本中要先插入一个值 0(即前缀异或和 s[0] = 0),否则 p = 1 的情况会被漏掉。

代码:

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
#define y1 y11
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb push_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define pdd pair<double, double>
#define mem(a, b) memset(a, b, sizeof(a))
#define debug(x) cerr << #x << " = " << x << "\n";
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//head

const int N = 3e5 + 5, M = 2e7 + 5;
// Persistent binary trie (bits 24..0) with subtree sizes sz[]; root[i] is the
// version holding the prefix xors a[0..i]. root/a are sized 2*N, presumably for
// the initial numbers plus the appended ones.
int trie[M][2], sz[M], root[N*2], a[N*2], tot = 0, n, m, l, r, x;
char op[10];

// Creates version `rt` = version `old` plus v: a new root and one new node per
// bit, with the sibling subtree at every level shared with `old`.
void update(int old, int &rt, int v) {
    int now = rt = ++tot;
    for (int b = 24; b >= 0; --b) {
        const int bit = (v >> b) & 1;
        trie[now][bit ^ 1] = trie[old][bit ^ 1];
        trie[now][bit] = ++tot;
        now = trie[now][bit];
        old = trie[old][bit];
        sz[now] = sz[old] + 1;
    }
}

// Max of x ^ v over the values v added in versions (l, r]; l and r are version
// roots (0 = the empty version), and sz differences count only those values.
int query(int l, int r, int x) {
    int best = 0;
    for (int b = 24; b >= 0; --b) {
        const int want = ((x >> b) & 1) ^ 1;
        const bool available = sz[trie[r][want]] - sz[trie[l][want]] != 0;
        const int go = available ? want : (want ^ 1);
        if (available) best |= 1 << b;
        l = trie[l][go];
        r = trie[r][go];
    }
    return best;
}
// "A x" appends x. Any other operation reads l r x and prints the maximum of
// s[j] ^ (s[n] ^ x) for j in [l-1, r-1], where s[i] = a[1] ^ ... ^ a[i] is
// the prefix xor stored in a[i].
int main() {
    scanf("%d %d", &n, &m);
    for (int i = 1; i <= n; ++i) {
        scanf("%d", &a[i]);
        a[i] ^= a[i-1];                     // a[i] now holds the prefix xor s[i]
    }
    update(0, root[0], 0);                  // version 0 holds s[0] = 0
    for (int i = 1; i <= n; ++i) update(root[i-1], root[i], a[i]);
    for (int step = 0; step < m; ++step) {
        scanf("%s", op);
        if (op[0] != 'A') {
            scanf("%d %d %d", &l, &r, &x);
            --l, --r;                       // j now ranges over [l, r]
            const int lo = l ? root[l-1] : 0;
            printf("%d\n", query(lo, root[r], x ^ a[n]));
            continue;
        }
        scanf("%d", &x);
        ++n;
        a[n] = a[n-1] ^ x;
        update(root[n-1], root[n], a[n]);
    }
    return 0;
}
View Code

PS:可持久化字典树还可以用来求区间字典序第k小的字符串,方法和主席树差不多

HZNUOJ Little Sub and Sequence

思路:求第k大,如果只有Xor,那么建可持久化trie树,查询时根据Xor这一位来决定先往左儿子走还是先往右儿子走

这道题则要把 Or 和 And 的影响转移到 Xor 上面,转移方法:

Or 产生影响的时候是某一位为 1,这时候将所有数字的这一位强制变为 0,Xor 的这一位变成 1,这样异或以后就是 1

And 产生影响的时候是某一位为 0,这时候将所有数字的这一位强制变为 0,Xor 的这一位变成 0,这样异或以后就是 0

所以每一位最多被强制变成 0 一次;强制变 0 时暴力重建整棵可持久化 trie。代码处理的是第 0~30 位,所以最多重建 31 次

代码:

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
#define y1 y11
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb push_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define pdd pair<double, double>
#define mem(a, b) memset(a, b, sizeof(a))
#define debug(x) cerr << #x << " = " << x << "\n";
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//head

const int N = 5e4 + 5, M = 2e6 + 10;
const int up = INT_MAX;    // bits 0..30 all set: the "nothing to clear" mask
// a[]: stored values (bits already forced to 0 are cleared); root[i]: trie
// version holding a[1..i]; cnt[u]: values in u's subtree. Xor: lazy mask applied
// at query time; brute: bits reset() must clear from a[]; vis[b]: bit b has
// already been forced to 0.
int a[N], trie[M][2], root[N], cnt[M], tot = 0, Xor = 0, brute = up, n, q, x, l, r, k;
char s[10];
bool vis[35];

// Builds version `rt` = version `old` plus x (bits 30..0), sharing untouched subtrees.
void update(int old, int &rt, int x) {
    int now = rt = ++tot;
    for (int b = 30; b >= 0; --b) {
        const int bit = (x >> b) & 1;
        trie[now][bit ^ 1] = trie[old][bit ^ 1];
        trie[now][bit] = ++tot;
        now = trie[now][bit];
        old = trie[old][bit];
        cnt[now] = cnt[old] + 1;
    }
}

// Returns the k-th smallest (1-based) of (v ^ Xor) over the values v in
// versions (l, r], looking at bits d..0. At each bit, the child equal to Xor's
// bit gives 0 there, so it holds the smaller results and is tried first.
int query(int l, int r, int d, int k) {
    int ans = 0;
    for (; d >= 0; --d) {
        const int low = (Xor >> d) & 1;
        const int num = cnt[trie[r][low]] - cnt[trie[l][low]];
        if (k <= num) {
            l = trie[l][low];
            r = trie[r][low];
        } else {
            ans += 1 << d;
            k -= num;
            l = trie[l][low ^ 1];
            r = trie[r][low ^ 1];
        }
    }
    return ans;
}

// Applies the mask `brute` to every a[i] and rebuilds all versions from
// scratch; restarting tot at 0 simply overwrites the old nodes.
void reset() {
    for (int i = 1; i <= n; ++i) a[i] &= brute;
    tot = 0;
    for (int i = 1; i <= n; ++i) update(root[i-1], root[i], a[i]);
}
// Reads the array, builds the persistent trie, then processes q operations:
//   "X..." (Xor x): fold x into the lazy Xor mask; no rebuild needed.
//   "O..." (Or x) / "An..." (And x): for each affected bit, force that bit of
//       every a[i] to 0 (first time only, tracked by vis) and fix the bit in
//       Xor; rebuild with reset() when a bit is forced for the first time.
//   otherwise (Ask l r k): print the k-th smallest (a[i] ^ Xor) for i in [l, r].
int main() {
    scanf("%d %d", &n, &q);
    for (int i = 1; i <= n; ++i) scanf("%d", &a[i]);
    reset();
    for (int i = 1; i <= q; ++i) {
        scanf("%s", s);
        if(s[0] == 'X') {
            // Xor only changes the lazy mask; stored values are untouched.
            scanf("%d", &x);
            Xor ^= x;
        }
        else if(s[0] == 'O') {
            // Or: each set bit of x becomes 1 in every result.
            brute = up;
            scanf("%d", &x);
            for (int i = 30; i >= 0; --i) {
                if(x&(1<<i)) {
                    if(!vis[i]) {
                        vis[i] = true;
                        brute ^= 1<<i;
                    }
                    Xor |= 1<<i;
                 }
            }
            // Rebuild only if some bit was forced to 0 for the first time.
            if(brute != up) reset();
        }
        else if(s[1] == 'n') {
            // And: each clear bit of x becomes 0 in every result.
            brute = up;
            scanf("%d", &x);
            for (int i = 30; i >= 0; --i) {
                if(!(x&(1<<i))) {
                    if(!vis[i]) {
                        vis[i] = true;
                        brute ^= 1<<i;
                    }
                    if(Xor&(1<<i)) Xor ^= 1<<i;/// must stay outside the !vis[i] check: Xor's bit is cleared even if the bit was already forced
                }
            }
            if(brute != up) reset();
        }
        else {
            // Query over versions (l-1, r], i.e. a[l..r].
            scanf("%d %d %d", &l, &r, &k);
            printf("%d\n", query(root[l-1], root[r], 30, k));
        }
    }
    return 0;
}
/*
5 100
0 0 0 0 0
Ask 1 5 1
Or 1
Ask 1 5 1
And 0
Ask 1 5 1
*/
View Code

 

posted @ 2018-04-10 16:49  Wisdom+.+  阅读(345)  评论(0编辑  收藏  举报