可持久化字典树
在介绍可持久化字典树之前,我们先要说一下01字典树。
01字典树的一个功能是:
- 向当前集合中插入一个正整数。
- 查询当前集合中异或上x后最大的那个数。
接下来介绍一下这两个操作是如何实现的:
- 插入x,我们把x看做一个二进制数,从高位向低位遍历,根据当前位是0或者1,把Tire的枝叶伸向不同的方向。
- 根据异或的性质,我们同样是从高位到低位遍历x的二进制串,若x当前位为1且存在0的枝叶,那就去0那一条枝叶,总之就是尽量与x的当前位取反,这样可以贪心的得到正解。
那如果我想查询区间 \([l, r]\) 的数字与x异或后最大的那个数,怎么办?求异或的这个问题,是通过字典树来解决的,所以区间询问这个问题,就是典中典的可持久化数据结构应用问题了。如果你真正的学会了可持久化线段树的话,那么YY一下也就能改出可持久化字典树的代码了。
来道典题BZOJ4546. CodeChef XRQRS
#include <bits/stdc++.h>
using namespace std;
const int N = 500005, M = 24;
int m;
int rt[N], tot = 0, cnt = 0; //rt:版本数组,tot动态开点,cnt是版本编号。
struct Trie {
int nex[N * M][2];
int sum[N * M];
void insert(int &p, int old, int d, int x) //p当前版本编号,d在枚举二进制。
{
p = ++ tot;
memcpy(nex[p], nex[old], sizeof nex[p]);
sum[p] = sum[old] + 1;
if(d < 0)
return;
if(x >> (d - 1) & 1)
insert(nex[p][1], nex[old][1], d - 1, x);
else
insert(nex[p][0], nex[old][0], d - 1, x);
}
int query_max(int v1, int v2, int d, int x)
{
if(d < 0)
return 0;
int flag = (x >> (d - 1) & 1);
if(sum[nex[v2][flag ^ 1]] - sum[nex[v1][flag ^ 1]] > 0)
return (1 << (d - 1)) + query_max(nex[v1][flag ^ 1], nex[v2][flag ^ 1], d - 1, x);
else
return query_max(nex[v1][flag], nex[v2][flag], d - 1, x);
}
int query_num(int v1, int v2, int d, int now, int x)
{
if(d < 0)
return 0;
if(now + (1 << (d - 1)) <= x)
return sum[nex[v2][0]] - sum[nex[v1][0]] + query_num(nex[v1][1], nex[v2][1], d - 1, now + (1 << (d - 1)), x);
else
return query_num(nex[v1][0], nex[v2][0], d - 1, now, x);
}
int query_kth(int v1, int v2, int d, int k)
{
if(d < 0)
return 0;
int sz = sum[nex[v2][0]] - sum[nex[v1][0]];
if(sz >= k)
return query_kth(nex[v1][0], nex[v2][0], d - 1, k);
else
return (1 << (d - 1)) + query_kth(nex[v1][1], nex[v2][1], d - 1, k - sz);
}
};
Trie T1;
int main()
{
ios::sync_with_stdio(false);cin.tie(0);
int m;
cin >> m;
while(m --)
{
int opt;
cin >> opt;
if(opt == 1)
{
int x; cin >> x;
cnt ++;
T1.insert(rt[cnt], rt[cnt - 1], 19, x);
}
else if(opt == 2)
{
int l, r, x;
cin >> l >> r >> x;
int s = T1.query_max(rt[l - 1], rt[r], 19, x);
cout << (s ^ x) << '\n';
}
else if(opt == 3)
{
int k;
cin >> k;
cnt -= k;
tot = rt[cnt + 1] - 1;
}
else if(opt == 4)
{
int l, r, x;
cin >> l >> r >> x;
cout << T1.query_num(rt[l - 1], rt[r], 19, 0, x) << '\n';
}
else
{
int l, r, k;
cin >> l >> r >> k;
cout << T1.query_kth(rt[l - 1], rt[r], 19, k) << '\n';
}
}
}