P3401 洛谷树

P3401 洛谷树

分析

很有意思的题目,我们来从头分析。

查询操作

首先不难分析,为了要求一段路径的异或值,我们采用类似前缀和的方式,预处理出来,从根节点到每个节点的边权异或值。将该数组设为sum

这样,我们想求u,v之间路径的异或值时,只需要求sum[u] ^ sum[v] 就可以得到从u-v之间的路径异或值。

但是,题目要求的是求u-v路径中的所有子路径的异或值的和。这就很麻烦了。

我们不难想到第一步转化,因为我们预处理的从根节点到任意节点的异或值。因此u-v路径中的所有子路径的异或值的和转化为了u-v之间任意两点的sum的异或值的和。接下来我们思考如何快速的算出转化后的问题。此时我们注意到异或的重要特点,即每一位的独立性,每一位我们都可以进行独立的计算

因此,我们想到了第二步转化,将u-v之间任意两点的sum的异或值的和转化为枚举每一位,u-v任意两点之间sum的异或值该位为1的数量

此时几乎就已经解决了,接下来我们想想如何求转化完的最后问题。

因为异或值为1一定是由一个0与一个1组成,因此最后的问题,即为求枚举每一位,对于从u-v中的所有异或值,该位为1的数量,以及该位为0的数量,两者相乘即为该位为1的所有异或值数量

这样就解决了。我们可以用树剖+线段树解决。另外其中,用到了一个小技巧,在Can you answer these queries III,求区间最大子段和中用到过。我们返回的时候,需要返回时很多值,因此可以直接返回结构体。将push操作单独写出来即可。详情见代码。

另外,我需要特意强调一下免得有人跟我一样犯蠢的错误。

这题,我们虽然操作的是边权,并且也将边权化了点权。但树中每个节点利用的值是到该点的路径异或值,因此我们需要操作的就是每个节点,不需要考虑找到LCA后,还要将id[LCA]+1

修改操作

这个就比较简单了。我们设原值为w[u],要修改为的值为y。枚举每一位,如果对于该位w[u]与y不相同,则在u的子树下的所有值的sum的当前位都要取反

for(int i=0;i<10;i++)
    if(((w[u]^y)>>i)&1)
        modify(1,id[u],id[u]+sz[u]-1,i);

Ac_code

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 3e4 + 10;
struct Node
{
    int l,r,tag,num0,num1;
}tr[N<<2][12];
int h[N],e[N<<1],ne[N<<1],w[N<<1],idx;
int sz[N],son[N],fa[N],dep[N];
int top[N],id[N],nw[N],val[N],sum[N],ts;
int n,m;

void add(int a,int b,int c)
{
    e[idx] = b,ne[idx] = h[a],w[idx] = c,h[a] = idx++;
}

void dfs1(int u,int pa,int depth)
{
    sz[u] = 1,fa[u] = pa,dep[u] = depth;
    for(int i=h[u];~i;i=ne[i])
    {
        int j = e[i];
        if(j==pa) continue;
        sum[j] = sum[u]^w[i];
        val[j] = w[i];
        dfs1(j,u,depth+1);
        sz[u] += sz[j];
        if(sz[j]>sz[son[u]]) son[u] = j;
    }
}

void dfs2(int u,int tp)
{
    top[u] = tp,id[u] = ++ts;
    nw[ts] = sum[u];
    if(!son[u]) return ;
    dfs2(son[u],tp);
    for(int i=h[u];~i;i=ne[i])
    {
        int j = e[i];
        if(j==fa[u]||j==son[u]) continue;
        dfs2(j,j);
    }
}

void push(Node &u,Node l,Node r)
{
    u.num0 = l.num0 + r.num0;
    u.num1 = l.num1 + r.num1;
}

void pushup(int u,int k)
{
    push(tr[u][k],tr[u<<1][k],tr[u<<1|1][k]);
}

void pushdown(int u,int k)
{
    auto &root = tr[u][k],&left = tr[u<<1][k],&right = tr[u<<1|1][k];
    if(root.tag)
    {
        swap(left.num1,left.num0);
        left.tag ^= 1;
        swap(right.num1,right.num0);
        right.tag ^= 1;
        root.tag = 0;
    }
}

void build(int u,int l,int r,int k)
{
    tr[u][k] = {l,r,0,0,0};
    if(l==r)
    {
        if((nw[l]>>k)&1) tr[u][k].num1 = 1;
        else tr[u][k].num0 = 1;
        return ;
    }
    int mid = l + r >> 1;
    build(u<<1,l,mid,k),build(u<<1|1,mid+1,r,k);
    pushup(u,k);
}

Node query(int u,int l,int r,int k)
{
    if(l<=tr[u][k].l&&tr[u][k].r<=r) return tr[u][k];
    pushdown(u,k);
    int mid = tr[u][k].l + tr[u][k].r >> 1;
    Node res = {0,0,0,0,0};
    if(l<=mid) push(res,res,query(u<<1,l,r,k));
    if(r>mid) push(res,res,query(u<<1|1,l,r,k));
    return res;
}

void modify(int u,int l,int r,int k)
{
    if(l<=tr[u][k].l&&tr[u][k].r<=r) 
    {
        swap(tr[u][k].num1,tr[u][k].num0);
        tr[u][k].tag ^= 1;
        return ;
    }
    pushdown(u,k);
    int mid = tr[u][k].l + tr[u][k].r >> 1;
    if(l<=mid) modify(u<<1,l,r,k);
    if(r>mid) modify(u<<1|1,l,r,k);
    pushup(u,k);
    return ;
}

int main()
{
    scanf("%d%d",&n,&m);
    memset(h,-1,sizeof h);
    for(int i=0;i<n-1;i++) 
    {
        int a,b,c;scanf("%d%d%d",&a,&b,&c);
        add(a,b,c),add(b,a,c);
    }
    dfs1(1,-1,1);
    dfs2(1,1);
    for(int i=0;i<10;i++) build(1,1,n,i);
    while(m--)
    {
        int op,u,v,c;
        scanf("%d%d%d",&op,&u,&v);
        if(op==1)
        {
            LL res = 0;
            int ub = u,vb = v;
            for(int i=0;i<10;i++)
            {
                u = ub,v = vb;
                Node t = {0,0,0,0,0};
                while(top[u]!=top[v])
                {
                    if(dep[top[u]]<dep[top[v]]) swap(u,v);
                    push(t,t,query(1,id[top[u]],id[u],i));
                    u = fa[top[u]];
                }
                if(dep[u]<dep[v]) swap(u,v);
                push(t,t,query(1,id[v],id[u],i));
                res += (1ll<<i)*t.num1*t.num0;
            }
            printf("%lld\n",res);
        }
        else 
        {
            scanf("%d",&c);
            if(dep[u]<dep[v]) swap(u,v);
            for(int i=0;i<10;i++)
                if(((c^val[u])>>i)&1)
                    modify(1,id[u],id[u]+sz[u]-1,i);
            val[u] = c;
        }
    }
    return 0;
}
posted @ 2022-04-08 11:56  艾特玖  阅读(110)  评论(0编辑  收藏  举报