【题解】P4592 [TJOI2018]异或(可持久化 01Trie,LCA,倍增)

【题解】P4592 [TJOI2018]异或

题目链接

P4592 [TJOI2018]异或 - 洛谷

题意概述

现在有一颗以 \(1\) 为根节点的由 \(n\) 个节点组成的树,节点从 \(1\) 至 \(n\) 编号。树上每个节点上都有一个权值 \(v_i\)。现在有 \(q\) 次操作,操作如下:

  • \(1~x~z\):查询节点 \(x\) 的子树中的节点权值与 \(z\) 异或结果的最大值。

  • \(2~x~y~z\):查询节点 \(x\) 到节点 \(y\) 的简单路径上的节点的权值与 \(z\) 异或结果最大值。

思路分析

考虑从两个角度建立两颗可持久化 01Trie 来求解。

对于操作 1:

如果把树上所有节点按照 dfs 序编号,

那么一个点的子树内的点的编号是连续的,就可以用区间来求。

显然树上编号最小的节点就是根节点,编号最大的节点是 dfs 序最大的节点(最后一个叶子节点)。

我们可以记录每个节点 dfs 序编号 \(dfn_x\),以及每个节点的子树大小 \(siz_x\),然后按照 dfs 序建立一棵可持久化 01Trie。那么字典树上其编号最大的儿子编号是 \(dfn_x+siz_x\)

所以对于操作 1,答案就是 query1(z,rt1[dfn[x]+sz[x]-1],rt1[dfn[x]-1])

(query1 表示的是在第一棵字典树上查询,rt1 是第一棵字典树编号)

对于操作 2:

可以按照根到节点的路径编号新建一棵字典树。

那么 \(x\)\(y\) 的路径,就可以分解为两条链:\(x\)\(lca(x,y)\) 的路径,\(y\)\(lca(x,y)\) 的路径。

对于 \(lca(x,y)\),可以用传统的倍增方法来求解。

对于 \(x\)\(lca(x,y)\) 的路径,答案是 query2(z,rt2[x],rt2[fa[lca][0]])

对于 \(y\)\(lca(x,y)\) 的路径,答案是 query2(z,rt2[y],rt2[fa[lca][0]])

两者取 \(\max\) 即可。

易错点

  • 假如在 query 和 ins 函数时,写的是 query/ins(x,pre,now),则主程序调用时,第二个括号里应该是编号小的,第三个括号里是编号大的,因为是参照 \(pre\) 这个版本,然后建 \(now\) 这个版本,所以千万不要弄反。

  • 关于数组问题:此题中,\(n\) 的范围是 \(10^5\),按理说一般情况下,开 \(maxn=1e5\),然后再定义 \(son[maxn*32][2]\),但是若 01Trie 的大小本来就预定的是 31 位,那么开 32 倍可能会 RE,要开 33 倍,也就是说,保险起见,若 01Trie 位数为 \(n\) 位,则开 \((n+2) \times maxn\) 是最为稳妥的。

    对于此题,实际上并不需要 31 位,30 位就够了,因为题目中说:\(1 \leq v_i, z \lt 2^{30}\)

代码实现

//luoguP4592
#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
using namespace std;
const int maxn=1e5+10;
int v[maxn],num;
int son1[maxn*32][2],son2[maxn*32][2],vis1[maxn*32],vis2[maxn*32];
int rt1[maxn*32],rt2[maxn*32],tot1,tot2;
int dfn[maxn*32],sz[maxn*32],dep[maxn*32],fa[maxn*32][32];
int sum1[maxn*32],sum2[maxn*32];

basic_string<int>edge[maxn<<1];

inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar(); }
    while(ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar(); }
    return x*f;
 } 

void ins1(int x,int pre,int now)
{
//    cout<<tot1<<endl;
    for(int i=30;i>=0;i--)
    {
        sum1[now]=sum1[pre]+1;
        int k=(x>>i)&1;
        if(!son1[now][k])son1[now][k]=++tot1;
        son1[now][k^1]=son1[pre][k^1];
        now=son1[now][k];
        pre=son1[pre][k];
     } 
    sum1[now]=sum1[pre]+1;
//    cout<<tot1<<endl;
    return ;
}

void ins2(int x,int pre,int now)
{
//    cout<<tot2<<endl;
    for(int i=30;i>=0;i--)
    {
        sum2[now]=sum2[pre]+1;
        int k=(x>>i)&1;
        if(!son2[now][k])son2[now][k]=++tot2;//bug 2:tot1
        son2[now][k^1]=son2[pre][k^1];
//        cout<<"ex "<<sum2[now]<<endl; 
        now=son2[now][k];
        pre=son2[pre][k];
     } 
    sum2[now]=sum2[pre]+1;
//    cout<<tot2<<endl;
    return ;
}

int query1(int x,int pre,int now)
{
    int ret=0;
    for(int i=30;i>=0;i--)
    {
        int k=(x>>i)&1;
        if(sum1[son1[pre][k^1]]-sum1[son1[now][k^1]]>=1)
        {
            ret|=(1ll<<i);
            pre=son1[pre][k^1];
            now=son1[now][k^1];
        }
        else 
        {
            pre=son1[pre][k];
            now=son1[now][k];
        }
    }
    return ret;
}

int query2(int x,int pre,int now)
{
    int ret=0;
    for(int i=30;i>=0;i--)
    {
        int k=(x>>i)&1;
//        cout<<k<<" "<<(k^1)<<" "<<pre<<" "<<now<<endl;
//        cout<<i<<" "<<son2[pre][k^1]<<" "<<son2[now][k^1]<<" "<<sum2[son2[pre][k^1]]<<" "<<sum2[son2[now][k^1]]<<endl;
        if(sum2[son2[pre][k^1]]-sum2[son2[now][k^1]]>=1)
        {
            ret|=(1ll<<i);
            pre=son2[pre][k^1];
            now=son2[now][k^1];
        }
        else 
        {
            pre=son2[pre][k];
            now=son2[now][k];
        }
//        cout<<ret<<endl;
    }
    return ret;
}


void dfs(int x,int fath)
{
    dfn[x]=++num;
    rt1[num]=++tot1;
    ins1(v[x],rt1[num-1],rt1[num]);//bug 1:把 num-1 和 num 高反了。 
    rt2[x]=++tot2;
    ins2(v[x],rt2[fath],rt2[x]);
    dep[x]=dep[fath]+1;
    fa[x][0]=fath;
    sz[x]=1;
    for(int i=1;i<=30;i++)fa[x][i]=fa[fa[x][i-1]][i-1];
    for(int y:edge[x])
    {
        if(y==fath)continue;
        dfs(y,x);
        sz[x]+=sz[y];
    }
}

int LCA(int x,int y)
{
    if(dep[x]<dep[y])swap(x,y);
    for(int i=0;i<=20;i++)
    {
        if((dep[x]-dep[y])&(1<<i))x=fa[x][i];
    }
    if(x==y)return x;
    for(int i=20;i>=0;i--)
    {
        if(fa[x][i]!=fa[y][i]){x=fa[x][i];y=fa[y][i];}
    }
    return fa[x][0]; 
}

int main()
{
    int n,q;
    n=read();q=read(); 
    for(int i=1;i<=n;i++)v[i]=read();
    for(int i=1;i<n;i++)
    {
        int u,v;
        u=read();v=read();
        edge[u]+=v;edge[v]+=u;
    }
    dfs(1,0);
//    for(int i=1;i<=n;i++)cout<<fa[i][0]<<endl;
//    cout<<tot1<<" "<<tot2<<endl;
//    for(int i=1;i<=tot2;i++)cout<<sum2[i]<<endl;
    while(q--)
    {
        int opt,x,y,z;
        opt=read();
        if(opt==1)
        {
            x=read();z=read();
//            cout<<rt1[dfn[x]+sz[x]-1]<<endl;
//            cout<<rt1[dfn[x]-1]<<endl;
            cout<<query1(z,rt1[dfn[x]+sz[x]-1],rt1[dfn[x]-1])<<'\n';
        }
        else
        {
            x=read();y=read();z=read();
            int lca=LCA(x,y);
//            cout<<fa[lca][0]<<endl;
//            cout<<rt2[fa[lca][0]]<<endl;
//            cout<<rt2[x]<<" "<<rt2[y]<<endl;
            int sum1=query2(z,rt2[x],rt2[fa[lca][0]]);
            int sum2=query2(z,rt2[y],rt2[fa[lca][0]]);//bug两个括号里面写反了。 
//            cout<<sum1<<" "<<sum2<<endl;
            cout<<max(sum1,sum2)<<'\n';
        }
    }
    return 0;
 } 
/*感觉需要梳理一下思路:
如果把树上所有节点按照 dfs 序编号,
那么一个点的子树内的点的编号是连续的,那么就可以用区间来求。
显然树上编号最小的节点就是根节点,编号最大的节点是 dfs 序最大的节点(最后一个叶子节点)。
对于第二种操作,可以按照根到节点的路径编号新建一棵字典树。
那么 x 到 y 的路径,就可以分解为两条链:x 到 lca 的路径,y 到 lca 的路径。
分别计算:
x 到 lca:根到 x - 根到 fa[lca]
y 到 lca:根到 y - 根到 fa[lca]
lca 倍增求解即可。 
*/ 
posted @ 2022-06-25 17:47  向日葵Reta  阅读(33)  评论(0编辑  收藏  举报