可持久化Trie

更好的体验

trie 有两种写法,一种是迭代,一种是递推,我比较喜欢递推,比较好想,代码也短。

带注释的代码

/*
    开始理解可持久化, 这里因为是acwing打卡, 可以放图片了
    有可能会用图片, 尽量打字
    
    可持久化trie, 就是一个trie树但是可以通过不同的开头(root), 变成每个历史状态
    这里就用到上面的图片了, 每次更新trie树, 这条新加入的链一定要, 开成新的, 即使前面
    有相同的的也不行(新加入的不能共用, 不然后面新的部分将更改前面的历史版本, 具体见上面图2)
    
    p是上一版本, q是这一版本, 可持久化, 只针对于上一版本变化
    总体思路就是, 每新加入一个东西, 看看这一层里面和它不同的, 向所有不同的连边(即tr[q][v] == tr[p][v]), 
    相当于复制了一边, 从这个根节点就可以, 通过这条边, 走到这些可以共用的
    对于重复的, 比如cab和cat的c和a, 我们不向它们连边, 新开一个点, 继续进行上面的操作
    
    说实话, 上来整01trie其实并不好, 但可持久化trie的题比较少, 后面我再打一个模版, 对于存单词的
    
    这就是可持久化trie的思路
    
    而对于这题, 有y的讲解, 很易懂, 我这里就粗略讲一下
    
    我们先预处理出来异或前缀和, 然后你能发现一个神奇的性质, 我从l ^ (l + 1) ^ ...  ^ r, 就等于sum[r] ^ sum[l - 1]
    因为 5 ^ 3 ^ 5 == 3对吧, 相当于[1, r] ^ [1, l - 1] = [l, r]; // 这里的[i, j], 是 i ^ (i + 1) ^ ... ^ j得到的数
    而我们是找一个再[l, r](这里指范围)里面的位置p, 然后求[p, n] ^ x, x是题目给出的某一个数
    这玩意就等价于, sum[n] ^ sum[p - 1] ^ x, 然后这里面固定的值是, sum[n] ^ x 我们把这个数设为C
    那么答案就是C ^ sum[p - 1]的最大值, p在[l, r]里面, 那么p - 1就在[l - 1, r - 1]里面;
    也就是说在[l - 1, r - 1]里面选一个异或前缀和, 使得sum[i] ^ C最大
    如果做过01trie的话就会发现, 这个玩意max(sum[i] ^ C)可以直接用01trie求出来(注意题目是让求最大值)
    这就把这个问题转化为了, 01trie, 
    
    这时候就看看这个范围怎么整[l - 1, r - 1](下面简写为[l, r]), 
    r好弄, 我们可以发现, 可以用可持久化trie, 在第r个版本里面搜, 就可以限制住右边界(因为没有r之后的数)
    现在就看左边界, 也好搞, 假设C这个数目前我拆分出来了1, 那么我最好选trie中的0
    这样可以保留这个1, 那么就看看0这棵树里面有没有下标大于l的节点, 而有没有下标大于等于l的节点, 
    就等价于这颗子树里的最大下标是否大于等于l, 就每个数记录一个max_id, 如果这个大于等于l, 就可以选这个0, 
    就可以进到0树里面, 并保留这个1, 否则就走1树
    
    这里求最大值可以像我一样保留1那样组合出这个数, 也可以通过找到那个点max_id[p]用sum[max_id[p]]去异或C
    
    至此此题结束
*/

#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

typedef long long LL;

const int N = 600101, M = 25 * N;

int tr[M][2], n, m;
int max_id[M], root[N], idx;
LL sum[N];

void insert(int i, int k, int p, int q) // 这里可以写成迭代的形式, 但是为了求max_id方便, 所以用递归的形式
{
    if (k < 0) // 说明到最后一个点
    {
        max_id[q] = i; // 那么最大下标就是它自己
        return ;
    }
    
    int v = sum[i] >> k & 1;
    if (p) tr[q][v ^ 1] = tr[p][v ^ 1]; // 连向另一边, 注意是连非当前点(v)的点的边, 并且这个点上一个版本必须有(没有的点我怎么共用, 凭空产生?)
    
    tr[q][v] = ++ idx; // 新加点
    insert(i, k - 1, tr[p][v], tr[q][v]); // 进入下一个点
    max_id[q] = max(max_id[tr[q][0]], max_id[tr[q][1]]); // 在儿子里面求最大下标
}

int query(int root, int C, int L)
{
    int p = root, res = 0; // res是组合出的这个数
    for (int i = 23; i >= 0; i -- )
    {
        int v = C >> i & 1;
        if (max_id[tr[p][!v]] >= L) p = tr[p][!v], res += 1 << i; 
        else p = tr[p][v];
    }

    return res; //  也可以写成 return C ^ sum[max_id[p]]; 因为最后p一定是异或最大的一条线里面的最后的一个点, 那么max_id[p]一定是异或C最大的
}

int main()
{
    cin >> n >> m;
    
    sum[0] = 0; // 开始的0, 因为可能会用到sum[0]; 
    max_id[0] = -1; // 必须小于0, 因为为了防止l - 1 == 0时这个0号历史方案(实际上什么都没有)被选上了, max_id[tr[p][!v]] >= L, 如果小于0的话就可以防止这种情况
    root[0] = ++ idx; // 这就是开始的0点
    insert(0, 23, 0, root[0]); 
    
    for (int i = 1; i <= n; i ++ )
    {
        int w;
        cin >> w;
        sum[i] = sum[i - 1] ^ w;
        root[i] = ++ idx;
        insert(i, 23, root[i - 1], root[i]);
    }
    
    while (m -- )
    {
        char op[2];
        int l, r, x;
        scanf("%s", op);
        if (*op == 'A') 
        {
            scanf("%d", &x);
            root[ ++ n] = ++ idx; // 注意这里 ++ n了长度已经变化了
            sum[n] = sum[n - 1] ^ x;
            insert(n, 23, root[n - 1], root[n]); 
        }
        else
        {
            scanf("%d%d%d", &l, &r, &x);
            printf("%d\n", query(root[r - 1], sum[n] ^ x, l - 1)); // 注意是l - 1, 和 r - 1
        }
    }
    
    return 0;
}

updata 2024,6,26 自己想出的做法

#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>

using namespace std;

const int N = 300010 * 2, M = N * 24;

int tr[M][2], idx;
int n, m;
int sum[N], max_id[M];
int root[N];

void insert(int x, int p, int q, int L)
{
    for (int i = 24; i >= 0; i -- )
    {
        int s = x >> i & 1;
        if (p) tr[q][!s] = tr[p][!s];
        
        q = tr[q][s] = ++ idx;
        max_id[q] = max(max_id[p], L);
        p = tr[p][s];
    }
}


int query(int x, int p, int L)
{
    int res = 0;
    for (int i = 24; i >= 0; i -- )
    {
        int s = x >> i & 1;
        if (tr[p][!s] && max_id[tr[p][!s]] >= L) 
        {
            res += 1 << i;
            p = tr[p][!s];
        }
        else p = tr[p][s];
    }
    return res;
}

int main()
{
    cin >> n >> m;
    
    root[0] = ++ idx;
    insert(0, root[0], root[0], 0);
    
    for (int i = 1; i <= n; i ++ ) 
    {
        cin >> sum[i];
        sum[i] ^= sum[i - 1];
        root[i] = ++ idx;
        insert(sum[i], root[i - 1], root[i], i);
    }
    
    
    while (m -- )
    {
        char op[2];
        int l, r, x;
        scanf("%s", op);
        
        if (op[0] == 'A')
        {
            scanf("%d", &x);
            sum[n + 1] = sum[n] ^ x;
            root[n + 1] = ++ idx;
            insert(sum[n + 1], root[n], root[n + 1], n + 1);
            n ++ ;
        }
        else 
        {
            scanf("%d%d%d", &l, &r, &x);
            x ^= sum[n];
            l -- , r -- ;
            printf("%d\n", query(x, root[r], l));
        }
    }
    
    
    return 0;
}

通用模版

/*
    通用的可持久化trie

    具体干什么
    给出n个字符串, 给出m个询问, 每个询问给出一个字符串s, 和一个限制r
    问在前r个字符串里面是否有字符串s
    如果有输出yes, 否则输出no

测试样例
3 4
a ab abc
ab 2
ab 1
abc 3
abc 2

yes
no
yes
no

*/
#include <iostream>
#include <cstring>
#include <algorithm>
#include <ctime>

using namespace std;

const int N = 10010;

int tr[N][27], idx;
int n, m, cnt[N];
int root[N];

void insert(string s, int p, int q) // 这是迭代版本的写法, 下面的扩展模版里面有递归写法
{
    for (int i = 0; s[i]; i ++ )
    {
        int v = s[i] - 'a';
        for (int j = 0; j < 26; j ++ ) // 复制可用边, 这里的复杂度可以以优化
            if (p && j != v) tr[q][j] = tr[p][j];
            
        // memcpy(tr[q], tr[p], sizeof tr[p]); // 相对高效的写法, 当然也快不了太多, 
        q = tr[q][v] = ++ idx; // 走向下一个/建立新点
        p = tr[p][v]; // 走向下一个
    }
    cnt[q] ++ ;
}

bool query(string s, int p)
{
    for (int i = 0; s[i]; i ++ )
    {
        int v = s[i] - 'a';
        if (!tr[p][v]) return false;
        p = tr[p][v];
    }

    return cnt[p];
}

int main()
{
    cin >> n >> m;

    root[0] =  ++ idx; // 也要初始化
	insert("", root[0], root[0]);
    for (int i = 1; i <= n; i ++ )
    {
        string s;
        cin >> s;
        // cout << s << endl;
        root[i] = ++ idx;  
        insert(s, root[i - 1], root[i]);
    }

    while (m -- )
    {
        string s;
        int r;
        cin >> s >> r;
        if (query(s, root[r])) puts("yes");
        else puts("no");
    }
    cout << clock() << endl;
    return 0;
}

扩展的通用模版

/*
    通用的可持久化trie

    具体干什么
    给出n个字符串, 给出m个询问, 每个询问给出一个字符串s, 和限制[l, r]
    问在[l, r]内的字符串里面是否有字符串s
    如果有输出yes, 否则输出no
    
    trie的每层的数量越多, 这玩意的消耗越大, 每次都得乘上以大常数, 所以一般都是01可持久化trie
    这样就常数小
    
测试样例
3 4
a cbc ab
ab 2 3
cb 1 3
aaaa 1 3
ab 1 2

输出样例
yes
no
no
no

*/
#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

const int N = 10010;

int tr[N][27], idx;
int n, m, cnt[N];
int root[N], max_id[N];

// void insert(string s, int p, int q)
// {
//     for (int i = 0; s[i]; i ++ )
//     {
//         int v = s[i] - 'a';
//         for (int j = 0; j < 26; j ++ ) // 复制可用边, 这里的复杂度可以以优化
//             if (p && j != v) tr[q][j] = tr[p][j];
//         q = tr[q][v] = ++ idx; // 走向下一个/建立新点
//         p = tr[p][v]; // 走向下一个
//     }
//     cnt[q] ++ ;
// }

void insert(int i, string s, int k, int p, int q) // 递归版本, 加上求 max_id; 当前的下标, 加入的串, 当前第几个字母, 上一个版本的这层, 这个版本的这层
{
    if (!s[k])
    {
        max_id[q] = i;
        cnt[q] ++ ;
        return ;
    }
    
    int v = s[k] - 'a';
    for (int j = 0; j < 26; j ++ )
        if (p && j != v) tr[q][j] = tr[p][j];  
    tr[q][v] = ++ idx;
    
    insert(i, s, k + 1, tr[p][v], tr[q][v]);
    int maxv = -1;
    for (int j = 0; j < 26; j ++ )
        maxv = max(maxv, max_id[tr[q][j]]);
    max_id[q] = maxv;
}

bool query(int l, string s, int p) // 递归的加上求max_id, 求区间的
{
    for (int i = 0; s[i]; i ++ )
    {
        int v = s[i] - 'a';
        if (!tr[p][v] || max_id[tr[p][v]] < l) return false;
        p = tr[p][v];
    }
    
    return cnt[p];
}

// bool query(string s, int p) // 普通, 
// {
//     for (int i = 0; s[i]; i ++ )
//     {
//         int v = s[i] - 'a';
//         if (!tr[p][v]) return false;
//         p = tr[p][v];
//     }

//     return cnt[p];
// }

int main()
{
    cin >> n >> m;

    root[0] =  ++ idx;
    
    for (int i = 1; i <= n; i ++ )
    {
        string s;
        cin >> s;
        root[i] = ++ idx;  
        insert(i, s, 0, root[i - 1], root[i]);
    }

    while (m -- )
    {
        string s;
        int l, r;
        cin >> s >> l >> r;
        if (query(l, s, root[r])) puts("yes");
        else puts("no");
    }

    return 0;
}

图片

图一
屏幕截图 2023-10-29 205807.png

图二
WIN_20231030_19_38_32_Pr0.jpg

图三
屏幕截图 2023-10-30 143556.png

图四
屏幕截图 2023-10-30 214644.png

posted @ 2024-06-22 16:51  blind5883  阅读(4)  评论(0编辑  收藏  举报