【题解】P4592 [TJOI2018]异或(可持久化 01Trie,LCA,倍增)
【题解】P4592 [TJOI2018]异或
题目链接
题意概述
现在有一颗以 \(1\) 为根节点的由 \(n\) 个节点组成的树,节点从 \(1\) 至 \(n\) 编号。树上每个节点上都有一个权值 \(v_i\)。现在有 \(q\) 次操作,操作如下:
-
\(1~x~z\):查询节点 \(x\) 的子树中的节点权值与 \(z\) 异或结果的最大值。
-
\(2~x~y~z\):查询节点 \(x\) 到节点 \(y\) 的简单路径上的节点的权值与 \(z\) 异或结果最大值。
思路分析
考虑从两个角度建立两颗可持久化 01Trie 来求解。
对于操作 1:
如果把树上所有节点按照 dfs 序编号,
那么一个点的子树内的点的编号是连续的,就可以用区间来求。
显然树上编号最小的节点就是根节点,编号最大的节点是 dfs 序最大的节点(最后一个叶子节点)。
我们可以记录每个节点 dfs 序编号 \(dfn_x\),以及每个节点的子树大小 \(siz_x\),然后按照 dfs 序建立一棵可持久化 01Trie。那么字典树上其编号最大的儿子编号是 \(dfn_x+siz_x\)。
所以对于操作 1,答案就是 query1(z,rt1[dfn[x]+sz[x]-1],rt1[dfn[x]-1])
。
(query1 表示的是在第一棵字典树上查询,rt1 是第一棵字典树编号)
对于操作 2:
可以按照根到节点的路径编号新建一棵字典树。
那么 \(x\) 到 \(y\) 的路径,就可以分解为两条链:\(x\) 到 \(lca(x,y)\) 的路径,\(y\) 到 \(lca(x,y)\) 的路径。
对于 \(lca(x,y)\),可以用传统的倍增方法来求解。
对于 \(x\) 到 \(lca(x,y)\) 的路径,答案是 query2(z,rt2[x],rt2[fa[lca][0]])
;
对于 \(y\) 到 \(lca(x,y)\) 的路径,答案是 query2(z,rt2[y],rt2[fa[lca][0]])
。
两者取 \(\max\) 即可。
易错点
-
假如在 query 和 ins 函数时,写的是
query/ins(x,pre,now)
,则主程序调用时,第二个括号里应该是编号小的,第三个括号里是编号大的,因为是参照 \(pre\) 这个版本,然后建 \(now\) 这个版本,所以千万不要弄反。 -
关于数组问题:此题中,\(n\) 的范围是 \(10^5\),按理说一般情况下,开 \(maxn=1e5\),然后再定义 \(son[maxn*32][2]\),但是若 01Trie 的大小本来就预定的是 31 位,那么开 32 倍可能会 RE,要开 33 倍,也就是说,保险起见,若 01Trie 位数为 \(n\) 位,则开 \((n+2) \times maxn\) 是最为稳妥的。
对于此题,实际上并不需要 31 位,30 位就够了,因为题目中说:\(1 \leq v_i, z \lt 2^{30}\)。
代码实现
//luoguP4592
#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
using namespace std;
const int maxn=1e5+10;
int v[maxn],num;
int son1[maxn*32][2],son2[maxn*32][2],vis1[maxn*32],vis2[maxn*32];
int rt1[maxn*32],rt2[maxn*32],tot1,tot2;
int dfn[maxn*32],sz[maxn*32],dep[maxn*32],fa[maxn*32][32];
int sum1[maxn*32],sum2[maxn*32];
basic_string<int>edge[maxn<<1];
inline int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar(); }
while(ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar(); }
return x*f;
}
void ins1(int x,int pre,int now)
{
// cout<<tot1<<endl;
for(int i=30;i>=0;i--)
{
sum1[now]=sum1[pre]+1;
int k=(x>>i)&1;
if(!son1[now][k])son1[now][k]=++tot1;
son1[now][k^1]=son1[pre][k^1];
now=son1[now][k];
pre=son1[pre][k];
}
sum1[now]=sum1[pre]+1;
// cout<<tot1<<endl;
return ;
}
void ins2(int x,int pre,int now)
{
// cout<<tot2<<endl;
for(int i=30;i>=0;i--)
{
sum2[now]=sum2[pre]+1;
int k=(x>>i)&1;
if(!son2[now][k])son2[now][k]=++tot2;//bug 2:tot1
son2[now][k^1]=son2[pre][k^1];
// cout<<"ex "<<sum2[now]<<endl;
now=son2[now][k];
pre=son2[pre][k];
}
sum2[now]=sum2[pre]+1;
// cout<<tot2<<endl;
return ;
}
int query1(int x,int pre,int now)
{
int ret=0;
for(int i=30;i>=0;i--)
{
int k=(x>>i)&1;
if(sum1[son1[pre][k^1]]-sum1[son1[now][k^1]]>=1)
{
ret|=(1ll<<i);
pre=son1[pre][k^1];
now=son1[now][k^1];
}
else
{
pre=son1[pre][k];
now=son1[now][k];
}
}
return ret;
}
int query2(int x,int pre,int now)
{
int ret=0;
for(int i=30;i>=0;i--)
{
int k=(x>>i)&1;
// cout<<k<<" "<<(k^1)<<" "<<pre<<" "<<now<<endl;
// cout<<i<<" "<<son2[pre][k^1]<<" "<<son2[now][k^1]<<" "<<sum2[son2[pre][k^1]]<<" "<<sum2[son2[now][k^1]]<<endl;
if(sum2[son2[pre][k^1]]-sum2[son2[now][k^1]]>=1)
{
ret|=(1ll<<i);
pre=son2[pre][k^1];
now=son2[now][k^1];
}
else
{
pre=son2[pre][k];
now=son2[now][k];
}
// cout<<ret<<endl;
}
return ret;
}
void dfs(int x,int fath)
{
dfn[x]=++num;
rt1[num]=++tot1;
ins1(v[x],rt1[num-1],rt1[num]);//bug 1:把 num-1 和 num 高反了。
rt2[x]=++tot2;
ins2(v[x],rt2[fath],rt2[x]);
dep[x]=dep[fath]+1;
fa[x][0]=fath;
sz[x]=1;
for(int i=1;i<=30;i++)fa[x][i]=fa[fa[x][i-1]][i-1];
for(int y:edge[x])
{
if(y==fath)continue;
dfs(y,x);
sz[x]+=sz[y];
}
}
int LCA(int x,int y)
{
if(dep[x]<dep[y])swap(x,y);
for(int i=0;i<=20;i++)
{
if((dep[x]-dep[y])&(1<<i))x=fa[x][i];
}
if(x==y)return x;
for(int i=20;i>=0;i--)
{
if(fa[x][i]!=fa[y][i]){x=fa[x][i];y=fa[y][i];}
}
return fa[x][0];
}
int main()
{
int n,q;
n=read();q=read();
for(int i=1;i<=n;i++)v[i]=read();
for(int i=1;i<n;i++)
{
int u,v;
u=read();v=read();
edge[u]+=v;edge[v]+=u;
}
dfs(1,0);
// for(int i=1;i<=n;i++)cout<<fa[i][0]<<endl;
// cout<<tot1<<" "<<tot2<<endl;
// for(int i=1;i<=tot2;i++)cout<<sum2[i]<<endl;
while(q--)
{
int opt,x,y,z;
opt=read();
if(opt==1)
{
x=read();z=read();
// cout<<rt1[dfn[x]+sz[x]-1]<<endl;
// cout<<rt1[dfn[x]-1]<<endl;
cout<<query1(z,rt1[dfn[x]+sz[x]-1],rt1[dfn[x]-1])<<'\n';
}
else
{
x=read();y=read();z=read();
int lca=LCA(x,y);
// cout<<fa[lca][0]<<endl;
// cout<<rt2[fa[lca][0]]<<endl;
// cout<<rt2[x]<<" "<<rt2[y]<<endl;
int sum1=query2(z,rt2[x],rt2[fa[lca][0]]);
int sum2=query2(z,rt2[y],rt2[fa[lca][0]]);//bug两个括号里面写反了。
// cout<<sum1<<" "<<sum2<<endl;
cout<<max(sum1,sum2)<<'\n';
}
}
return 0;
}
/*感觉需要梳理一下思路:
如果把树上所有节点按照 dfs 序编号,
那么一个点的子树内的点的编号是连续的,那么就可以用区间来求。
显然树上编号最小的节点就是根节点,编号最大的节点是 dfs 序最大的节点(最后一个叶子节点)。
对于第二种操作,可以按照根到节点的路径编号新建一棵字典树。
那么 x 到 y 的路径,就可以分解为两条链:x 到 lca 的路径,y 到 lca 的路径。
分别计算:
x 到 lca:根到 x - 根到 fa[lca]
y 到 lca:根到 y - 根到 fa[lca]
lca 倍增求解即可。
*/