题意
后两个求和符号代表的是有多少异或值为0的路径。
前两个符号代表有多少路径包含异或值为0的路径。
即每个权值为0的路径对答案的贡献为 有多少路径包含当前路径 ,所有的贡献加起来就是答案。
思路
点分治,权值太大,桶只能用map了。
x与y之间的路径(x与y间权值异或为0)对答案的贡献是
(x可以扩展的点数)**(y可以扩展的点数)。
如上图,设2与3之间的权值异或为0,那么这条边对答案的贡献为2*3(2可以扩展到2、4两个点,3可以扩展到3、5、6这三个点)。
桶里面加上的不是个数了,而是对于 当前重心 到 x 的 这一路径 ,x可以扩展到几个点。如果知道对于每条路径上面的端点可以扩展到多少点这题就可做了。
预处理:让a作为根,遍历一遍,统计每个点子树的大小siz[x],与每个点的上一个点init_pre[x]。
现在要求rt->…..->pre->x这一路径x可以扩展多少:
1、init_pre[x]==pre,那么x可以扩展的点就是siz[x].
2、init_pre[x]!=pre,那么x可以扩展的点是n-siz[pre]
可以o(1)求出来。
解题思路:
树上路径计数问题,想当然就是点分治了,但是这道题并没有那么简单。
以这个图为例:
首先是路径两端点数的计算问题:
假设我们现在点分治的过程中以3为重心,3-4和2-3都为2,因此2-4这个路径的异或和为0,我们期望用2号节点及其上面的点的数量乘以4号节点以及其下面的点的数量来得到2-4这条路径的贡献。但是点分治里的sz数组是不固定且不正确的,会随着重心的不同发生变化,所以我们用最初第一次以1为根获得的siz进行处理,在第一次getrt的过程中记录每个节点的前驱节点,如果在这次遍历过程中前驱节点与第一次getrt过程中的前驱节点相同的话,代表当前方向与第一次getrt时的方向相同,就可以直接用第一次getrt处理到的siz,否则说明方向相反,就要用总的点数 n 减去当前节点前驱节点的siz获得结果,比如2号节点及其上方的点数就可以用总点数5减去2的前驱3及其3下方点的数量得到。
第二个问题是当前重心到某一点的异或和直接为零的情况,需要特殊处理:
假设图中4-5的权值为2,我们依然假设3为当前重心,这样3-5的异或和为0,我们要单独去判断一下重心另一侧的点的数量再相乘(因为这一部分再遍历map的时候是没有处理的)
#include<bits/stdc++.h> #include<> using namespace std; #define ll long long #define maxn 100100 #define inf 0x3f3f3f3f #define mod 1000000007 struct node { ll v,w,to; } edge[maxn*2]; struct data { ll w,sum; }temp[maxn]; bool vis[maxn]; int head[maxn],cnt,n,rt,a,tot; ll sum,sz[maxn],maxx[maxn],ans,b,sz_rt; ll dis[maxn],siz[maxn],init_pre[maxn]; unordered_map<ll,ll>mp; void init() { memset(head,-1,sizeof(head)); memset(vis,0,sizeof(vis)); cnt=ans=0; } void add(int u,int v,ll w) { edge[cnt]={v,w,head[u]}; head[u]=cnt++; edge[cnt]={u,w,head[v]}; head[v]=cnt++; } void dfs(int x,int pre) { siz[x]=1; init_pre[x]=pre; for(int i=head[x];i!=-1;i=edge[i].to) { int v=edge[i].v; if(v!=pre) { dfs(v,x); siz[x]+=siz[v]; } } } void getrt(int x,int pre) { sz[x]=1; maxx[x]=0; for(int i=head[x];i!=-1;i=edge[i].to) { int v=edge[i].v; if(v!=pre&&!vis[v]) { getrt(v,x); sz[x]+=sz[v]; maxx[x]=max(maxx[x],sz[v]); } } maxx[x]=max(maxx[x],sum-sz[x]); if(maxx[x]<maxx[rt])rt=x; } void getdis(int x,int pre) { if(pre==init_pre[x]) { ans=(ans+siz[x]*mp[dis[x]])%mod; temp[++tot]= {dis[x],siz[x]}; if(dis[x]==0) ans=(ans+siz[x]*sz_rt)%mod; } else { ans=(ans+1ll*(n-siz[pre])*mp[dis[x]])%mod; temp[++tot]= {dis[x],n-siz[pre]}; if(dis[x]==0) ans=(ans+1ll*(n-siz[pre])*sz_rt)%mod; } for(int i=head[x];i!=-1;i=edge[i].to) { int v=edge[i].v; ll w=edge[i].w; if(v==pre||vis[v])continue; dis[v]=dis[x]^w; getdis(v,x); } } void cal(int x) { dis[x]=0; for(int i=head[x];i!=-1;i=edge[i].to) { int v=edge[i].v; ll w=edge[i].w; if(vis[v])continue; if(init_pre[v]==x)sz_rt=n-siz[v]; else sz_rt=siz[rt]; dis[v]=w; tot=0; getdis(v,x); for(int j=1;j<=tot;j++) { mp[temp[j].w]+=temp[j].sum; mp[temp[j].w]%=mod; } } mp.clear(); } void solve(int x) { vis[x]=1; cal(x); for(int i=head[x];i!=-1;i=edge[i].to) { int v=edge[i].v; if(vis[v])continue; maxx[rt=0]=inf; sum=sz[v]; getrt(v,0); solve(rt); } } int main() { init(); scanf("%d",&n); for(int i=2;i<=n;i++) { scanf("%d%lld",&a,&b); add(i,a,b); } dfs(1,0); maxx[rt=0]=inf; sum=n; getrt(1,0); solve(rt); printf("%lld\n",ans%mod); return 0; }