学习笔记-字典树
字典树一般有两个作用(我学到的),一个是查询单词的出现,一个是计算最大异或值。
字典树的ch数组该如何理解?
其实ch[p][j]指的是从p是否有一条值为为j的边到下一个点,如果ch[p][j]为0,就是没有。
例题1
luogu P2580
https://www.luogu.com.cn/problem/P2580
这题就是存字串的裸题,唯一要处理的细节就是点名两次该如何处理,我这里用了一个有点蠢的方法,记录了每个字符串点的次数,以及是否点过,如果次数大于2并且再一次点到就是重复点名。还有这里的cnt数组记录的是每个字符串结尾的次数。有些题种记录的是每个字符串出现的次数。
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int mod = 998244353, INF = 1 << 30;
const int N = 5e5 + 10;
int ch[N][26], idx, cnt[N];
map<string, bool> vis;
map<string, int> hv;
void insert(string &s)
{
int len = s.size();
int p = 0;
for (int i = 0; i < len; i ++){
int j = s[i] - 'a';
if (!ch[p][j]) ch[p][j] = ++ idx;
p = ch[p][j];
}
cnt[p] ++;
}
int query(string &s)
{
int len = s.size();
int p = 0;
for (int i = 0; i < len; i ++){
int j = s[i] - 'a';
if (!ch[p][j]) return 0;
p = ch[p][j];
}
return cnt[p];
}
int main()
{
cin.tie(nullptr);
ios::sync_with_stdio(false);
int n;
cin >> n;
for (int i = 1; i <= n; i ++){
string s;
cin >> s;
hv[s] ++;
}
int q;
cin >> q;
while (q --){
string s;
cin >> s;
if (hv[s] && !vis[s]) {cout << "OK" << endl; vis[s] = 1;}
else if (hv[s] && vis[s]) cout << "REPEAT" << endl;
else cout << "WRONG" << endl;
}
return 0;
}
思考: 我们有插入操作,那如何实现删除操作?
其实我们可以用懒删除,用一个cnt计数,这里的cnt和例1种的cnt不太一样,这个是每个节点都要计数一次,例1中是每个单词的尾部节点计数一次,删除时将每个节点的cnt减去一次,查询的时候如果cnt为0就说明节点已经删除掉了。
例题2
luogu P4551
https://www.luogu.com.cn/problem/P4551
这题是求最大的异或值,只是转移到树上,我们假设异或值最大的路径是u -> v,根据异或前缀和的思想,我们设点x到根节点的异或路径和是sum[x], u, v到最近公共祖先的异或值为s[u], s[v], s[u] ^ s[v]就是这题的答案。
s[u] ^ s[v] = s[u] ^ s[v] ^ sum[lca[u, v]] ^ sum[lca[u, v]] = (s[u] ^ sum[lca(u, v)]) ^ (s[v] ^ sum[lca(u, v)]) = sum[u] ^ sum[v]
所以我们只需要一次dfs就能算出所有节点到根节点的异或路径和,并查询所有异或路径和的最大异或值,用01trie维护。
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int mod = 998244353, INF = 1 << 30;
const int N = 1e5 + 10;
vector<PII> e[N];
int ch[N * 31][2], idx, sum[N];
int ans;
void insert(int x)
{
int p = 0;
for (int i = 30; i >= 0; i --){
int j = (x >> i) & 1;
if (!ch[p][j]) ch[p][j] = ++ idx;
p = ch[p][j];
}
}
int query(int x)
{
int p = 0, res = 0;
for (int i = 30; i >= 0; i --){
int j = (x >> i) & 1;
if (ch[p][!j]){
res += (1 << i);
p = ch[p][!j];
}
else{
p = ch[p][j];
}
}
return res;
}
void dfs(int u, int fa)
{
for (auto x: e[u]){
int v = x.first, w = x.second;
if (v == fa) continue;
sum[v] = sum[u] ^ w;
dfs(v, u);
}
}
int main()
{
cin.tie(nullptr);
ios::sync_with_stdio(false);
int n;
cin >> n;
for (int i = 1; i < n; i ++) {
int u, v, w;
cin >> u >> v >> w;
e[u].push_back({v, w});
e[v].push_back({u, w});
}
dfs(1, -1);
for (int i = 2; i <= n; i ++) insert(sum[i]);
for (int i = 1; i <= n; i ++) ans = max(query(sum[i]), ans);
cout << ans << endl;
return 0;
}
例题3
CF1895D XOR Construction
https://codeforces.com/contest/1895/problem/D
题意:给出一个n-1长度的数组a,要构造一个b数组,满足要求:1.数组b是0 ~ n-1的排列 2.b[i] ^ b[i + 1] = a[i]
手推一下,可以发现 \(a_{1} \oplus a_{2} \oplus .... \oplus a_{n - 1} = b_{1} \oplus b_{2} \oplus b_{2} \oplus b_{3} \oplus .... \oplus b_{n - 1} \oplus b_{n}\)
即\(a_{1} \oplus a_{2} \oplus .... \oplus a_{n - 1} = b_{1} \oplus b_{n}\),可以看出我们只要求出\(b_{1}\)的值,b就能求出来。
我们设a的异或前缀和为sum[n];
我们发现题目中说保证有解,而且解的个数可能不为1。
用a的异或前缀和建一颗字典树,我们的所有值都是0 ~ n - 1的。
所以我们枚举\(a_{1}\) 从[0, n - 1], 如果与字典树中的最大异或和刚好是n - 1,这个\(a_{1}\)就是符合要求的。
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int mod = 998244353, INF = 1 << 30;
const int N = 2e5 + 10;
int a[N], sum[N];
int ch[N * 31][2], idx;
void insert(int x)
{
int p = 0;
for (int i = 30; i >= 0; i --){
int j = (x >> i) & 1;
if (!ch[p][j]) ch[p][j] = ++ idx;
p = ch[p][j];
}
}
int query(int x)
{
int p = 0, res = 0;
for (int i = 30; i >= 0; i --){
int j = (x >> i) & 1;
if (ch[p][!j]){
res += (1 << i);
p = ch[p][!j];
}
else{
p = ch[p][j];
}
}
return res;
}
int main()
{
cin.tie(nullptr);
ios::sync_with_stdio(false);
int n;
cin >> n;
for (int i = 1; i < n; i ++){
cin >> a[i];
sum[i] = a[i];
sum[i] ^= sum[i - 1];
}
for (int i = 1; i < n; i ++) insert(sum[i]);
int fst = -1;
for (int i = 0; i < n; i ++){
if (query(i) == n - 1) {
fst = i;
break;
}
}
if (fst == -1) fst = n - 1;
cout << fst << ' ';
for (int i = 1; i < n; i ++){
cout << (fst ^ sum[i]) << ' ';
}
return 0;
}