【字符串】Trie树
trie树
学习资料:OI Wiki
模板:
struct Tire{
int nxt[maxn][26],cnt;
bool exit[maxn];//以这个点结尾的字符串是否存在
void insert(char s[],int len)
{
int p=0;
for(int i=0;i<len;i++){
int c=s[i]-'a';
if(!nxt[p][c])nxt[p][c]=++cnt;
p=nxt[p][c];
}
exit[p]=1;
}
bool find(char s[],int len){
int p=0;
for(int i=0;i<len;i++){
int c=s[i]-'a';
if(!nxt[p][c])return false;
p=nxt[p][c];
}
return exit[p];
}
}tr;
例题和作用:
Trie中的节点表示的是某个模式串的前缀,也称状态。Tire的边就是状态的转移。
1,检索字符串,查找一个字符串是否出现过
例题1:luoguP2580
题意:
首先给一个整数 \(n\) 和 \(n\) 个字符串,之后一个整数 \(m\) 和 \(m\) 个字符串,要求对这 \(m\) 个字符串输出 \(m\) 行:
- 如果该字符串不存在于 \(n\) 个给定字符串中,则输出 "
WRONG
"; - 否则,如果该字符串第一次出现在 \(m\) 个字符串中,则输出 “
OK
”; - 否则,输出 "
REPEAT
"。
解:
hash可解,tire也可解。此处用tire
#include<bits/stdc++.h>
#define reg register
using namespace std;
typedef long long ll;
const int maxn=5e5+5;
struct Tire{
int nxt[maxn][27],cnt;
bool exit[maxn],vis[maxn];
void insert(char s[],int len)
{
int p=0;
for(int i=0;i<len;i++){
int c=s[i]-'a';
if(!nxt[p][c])nxt[p][c]=++cnt;
p=nxt[p][c];
}
exit[p]=1;
}
int find(char s[],int len){
int p=0;
for(int i=0;i<len;i++){
int c=s[i]-'a';
if(!nxt[p][c])return 0;
p=nxt[p][c];
}
return exit[p]?(vis[p]==1?-1:(vis[p]=1)):0;
}
}tr;
int main()
{
int n,m;char s[55];
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%s",s);
tr.insert(s,strlen(s));
}
scanf("%d",&m);
while(m--)
{
scanf("%s",s);
int res=tr.find(s,strlen(s));
if(!res)puts("WRONG");
else if(res==1)puts("OK");
else puts("REPEAT");
}
}
2,维护异或和
例题2:luoguP6018
题意:给定一棵 \(n\) 个节点的无向无根树,每个节点有权值 \(a_i\) ,之后 \(m\) 个操作,操作有三种:
- 1 x :将与节点 \(x\) 距离为 \(1\) 的节点(即,与 \(x\) 直接相连的节点)的权值增 \(1\) 。
- 2 x v:将节点 \(x\) 的权值减 \(v\) .
- 3 x :询问与节点 \(x\) 距离为 \(1\) 的节点的权值异或和。
对每个询问输出答案。保证任意时刻每个节点的权值非负。\(n\leq5\times10^5,m\leq5\times10^5,0\leq a_i\leq10^5,1\leq x\leq n\) 。
解:
对于 \(1\) 操作,每个点建一棵 01tire ,支持全局加1和维护异或和。
对于 \(2\) 操作,\(\mathcal{O(1)}\) 维护修改。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=5e5+5;
const int MAXN=2e7+5;
const int MA=23;
int nxt[MAXN][2],cnt,num[MAXN],dep[MAXN];
struct Tire{
int dval[MA];
void insert(int &rt,int x,int d)
{
if(d>=MA)return;
if(!rt)rt=++cnt;
if(x&1){
insert(nxt[rt][1],x>>1,d+1);
if(nxt[rt][1])num[nxt[rt][1]]++,dval[d]++;
}else{
insert(nxt[rt][0],x>>1,d+1);
if(nxt[rt][0])num[nxt[rt][0]]++;
}
dep[rt]=d;
}
void erase(int rt,int x,int d)
{
if(d>=MA)return;
if(x&1){
insert(nxt[rt][1],x>>1,d+1);
if(nxt[rt][1])num[nxt[rt][1]]--,dval[d]--;
}else{
insert(nxt[rt][0],x>>1,d+1);
if(nxt[rt][0])num[nxt[rt][0]]--;
}
}
void addone(int rt)
{
dval[dep[rt]]-=num[nxt[rt][1]];
swap(nxt[rt][1],nxt[rt][0]);
dval[dep[rt]]+=num[nxt[rt][1]];
if(nxt[rt][0])addone(nxt[rt][0]);
}
int getval(){
int res=0;
for(int i=0;i<MA;i++)res+=(dval[i]&1)<<i;
return res;
}
}tr[maxn];
int root[maxn];
int head[maxn],nxtt[maxn<<1],to[maxn<<1],pcnt;
inline void add(int u,int v){
to[++pcnt]=v;nxtt[pcnt]=head[u];head[u]=pcnt;
}
int fa[maxn];
void dfs(int u,int f){
fa[u]=f;
for(int i=head[u];i;i=nxtt[i])
if(to[i]!=f)dfs(to[i],u);
}
int a[maxn];
int ad[maxn],dv[maxn];
void update1(int x){
if(!fa[x])return;
int tmp=a[x]+ad[fa[x]]+dv[x];
tr[fa[x]].erase(root[fa[x]],tmp,0);
tr[fa[x]].insert(root[fa[x]],tmp+1,0);
}
void update2(int x,int v)
{
if(!fa[x])return;
int tmp=a[x]+ad[fa[x]]+dv[x];
tr[fa[x]].erase(root[fa[x]],tmp,0);
tr[fa[x]].insert(root[fa[x]],tmp-v,0);
}
int main()
{
int n,m,u,v;
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
add(u,v);add(v,u);
}
for(int i=1;i<=n;i++)scanf("%d",&a[i]);
dfs(1,0);
for(u=1;u<=n;u++)
for(int i=head[u];i;i=nxtt[i])
if(to[i]!=fa[u])tr[u].insert(root[u],a[to[i]],0);
int op;
while(m--)
{
scanf("%d%d",&op,&u);
if(op==1){
ad[u]++;tr[u].addone(root[u]);
if(fa[u])update1(fa[u]);dv[fa[u]]++;
}else if(op==2){
scanf("%d",&v);update2(u,v);dv[u]-=v;
}else{
printf("%d\n",(a[fa[u]]+dv[fa[u]]+ad[fa[fa[u]]])^tr[u].getval());
}
}
}