Infiniti

   :: 首页  :: 新随笔  ::  ::  :: 管理

题意

后两个求和符号代表的是有多少异或值为0的路径。
前两个符号代表有多少路径包含异或值为0的路径。
即每个权值为0的路径对答案的贡献为 有多少路径包含当前路径 ,所有的贡献加起来就是答案。


思路

点分治,权值太大,桶只能用map了。
x与y之间的路径(x与y间权值异或为0)对答案的贡献是
(x可以扩展的点数)**(y可以扩展的点数)。
1
如上图,设2与3之间的权值异或为0,那么这条边对答案的贡献为2*3(2可以扩展到2、4两个点,3可以扩展到3、5、6这三个点)。

桶里面加上的不是个数了,而是对于 当前重心 到 x 的 这一路径 ,x可以扩展到几个点。如果知道对于每条路径上面的端点可以扩展到多少点这题就可做了。

预处理:让a作为根,遍历一遍,统计每个点子树的大小siz[x],与每个点的上一个点init_pre[x]。
现在要求rt->…..->pre->x这一路径x可以扩展多少:
1、init_pre[x]==pre,那么x可以扩展的点就是siz[x].
2、init_pre[x]!=pre,那么x可以扩展的点是n-siz[pre]
可以o(1)求出来。

解题思路:

树上路径计数问题,想当然就是点分治了,但是这道题并没有那么简单。

以这个图为例:

首先是路径两端点数的计算问题:

假设我们现在点分治的过程中以3为重心,3-4和2-3都为2,因此2-4这个路径的异或和为0,我们期望用2号节点及其上面的点的数量乘以4号节点以及其下面的点的数量来得到2-4这条路径的贡献。但是点分治里的sz数组是不固定且不正确的,会随着重心的不同发生变化,所以我们用最初第一次以1为根获得的siz进行处理,在第一次getrt的过程中记录每个节点的前驱节点,如果在这次遍历过程中前驱节点与第一次getrt过程中的前驱节点相同的话,代表当前方向与第一次getrt时的方向相同,就可以直接用第一次getrt处理到的siz,否则说明方向相反,就要用总的点数 n 减去当前节点前驱节点的siz获得结果,比如2号节点及其上方的点数就可以用总点数5减去2的前驱3及其3下方点的数量得到。

第二个问题是当前重心到某一点的异或和直接为零的情况,需要特殊处理:

假设图中4-5的权值为2,我们依然假设3为当前重心,这样3-5的异或和为0,我们要单独去判断一下重心另一侧的点的数量再相乘(因为这一部分再遍历map的时候是没有处理的)

#include<bits/stdc++.h>
#include<>
using namespace std;
#define ll long long
#define maxn 100100
#define inf 0x3f3f3f3f
#define mod 1000000007
struct node
{
    ll v,w,to;
} edge[maxn*2];
struct data
{
    ll w,sum;
}temp[maxn];
bool vis[maxn];
int head[maxn],cnt,n,rt,a,tot;
ll sum,sz[maxn],maxx[maxn],ans,b,sz_rt;
ll dis[maxn],siz[maxn],init_pre[maxn];
unordered_map<ll,ll>mp;
void init()
{
    memset(head,-1,sizeof(head));
    memset(vis,0,sizeof(vis));
    cnt=ans=0;
}
void add(int u,int v,ll w)
{
    edge[cnt]={v,w,head[u]};
    head[u]=cnt++;
    edge[cnt]={u,w,head[v]};
    head[v]=cnt++;
}
void dfs(int x,int pre)
{
    siz[x]=1;
    init_pre[x]=pre;
    for(int i=head[x];i!=-1;i=edge[i].to)
    {
        int v=edge[i].v;
        if(v!=pre)
        {
            dfs(v,x);
            siz[x]+=siz[v];
        }
    }
}
void getrt(int x,int pre)
{
    sz[x]=1;
    maxx[x]=0;
    for(int i=head[x];i!=-1;i=edge[i].to)
    {
        int v=edge[i].v;
        if(v!=pre&&!vis[v])
        {
            getrt(v,x);
            sz[x]+=sz[v];
            maxx[x]=max(maxx[x],sz[v]);
        }
    }
    maxx[x]=max(maxx[x],sum-sz[x]);
    if(maxx[x]<maxx[rt])rt=x;
}
void getdis(int x,int pre)
{
    if(pre==init_pre[x])
    {
        ans=(ans+siz[x]*mp[dis[x]])%mod;
        temp[++tot]= {dis[x],siz[x]};
        if(dis[x]==0)
        ans=(ans+siz[x]*sz_rt)%mod;
    }
    else
    {
        ans=(ans+1ll*(n-siz[pre])*mp[dis[x]])%mod;
        temp[++tot]= {dis[x],n-siz[pre]};
        if(dis[x]==0)
        ans=(ans+1ll*(n-siz[pre])*sz_rt)%mod;
    }
    for(int i=head[x];i!=-1;i=edge[i].to)
    {
        int v=edge[i].v;
        ll w=edge[i].w;
        if(v==pre||vis[v])continue;
        dis[v]=dis[x]^w;
        getdis(v,x);
    }
}
void cal(int x)
{
    dis[x]=0;
    for(int i=head[x];i!=-1;i=edge[i].to)
    {
        int v=edge[i].v;
        ll w=edge[i].w;
        if(vis[v])continue;
        if(init_pre[v]==x)sz_rt=n-siz[v];
        else sz_rt=siz[rt];
        dis[v]=w;
        tot=0;
        getdis(v,x);
        for(int j=1;j<=tot;j++)
        {
            mp[temp[j].w]+=temp[j].sum;
            mp[temp[j].w]%=mod;
        }
    }
    mp.clear();
}
void solve(int x)
{
    vis[x]=1;
    cal(x);
    for(int i=head[x];i!=-1;i=edge[i].to)
    {
        int v=edge[i].v;
        if(vis[v])continue;
        maxx[rt=0]=inf;
        sum=sz[v];
        getrt(v,0);
        solve(rt);
    }
}
int main()
{
    init();
    scanf("%d",&n);
    for(int i=2;i<=n;i++)
    {
        scanf("%d%lld",&a,&b);
        add(i,a,b);
    }
    dfs(1,0);
    maxx[rt=0]=inf;
    sum=n;
    getrt(1,0);
    solve(rt);
    printf("%lld\n",ans%mod);
    return 0;
}

  

posted on 2019-05-28 11:37  自由缚  阅读(197)  评论(0编辑  收藏  举报