BZOJ_3697_采药人的路径_点分治
BZOJ_3697_采药人的路径_点分治
Description
采药人的药田是一个树状结构,每条路径上都种植着同种药材。
采药人以自己对药材独到的见解,对每种药材进行了分类。大致分为两类,一种是阴性的,一种是阳性的。
采药人每天都要进行采药活动。他选择的路径是很有讲究的,他认为阴阳平衡是很重要的,所以他走的一定是两种药材数目相等的路径。采药工作是很辛苦的,所以他希望他选出的路径中有一个可以作为休息站的节点(不包括起点和终点),满足起点到休息站和休息站到终点的路径也是阴阳平衡的。他想知道他一共可以选择多少种不同的路径。
Input
第1行包含一个整数N。
接下来N-1行,每行包含三个整数a_i、b_i和t_i,表示这条路上药材的类型。
Output
输出符合采药人要求的路径数目。
Sample Input
7
1 2 0
3 1 1
2 4 0
5 2 0
6 3 1
5 7 1
1 2 0
3 1 1
2 4 0
5 2 0
6 3 1
5 7 1
Sample Output
1
HINT
对于100%的数据,N ≤ 100,000。
路径计数问题,很容易想到点分治。把0当成-1,那么路径长度为0的路径就是阴阳平衡的。
设f[i][0/1]表示到根的路径长度为i,且路径上没有/有阴阳平衡的路径的路径条数。
设g[i][0/1]表示到根的路径长度为-i,且路径上没有/有阴阳平衡的路径的路径条数。
对答案的贡献为f[i][0]*g[i][1]+f[i][1]*g[i][0]+f[i][1]*g[i][1]
然后发现向下找路径的时候长度一定是一个范围(因为边权为1或-1),我们记录这个范围就能求出这条路径上还有没有平衡的了。
其他细节比较多
代码:
#include <stdio.h> #include <string.h> #include <algorithm> using namespace std; #define N 200050 typedef long long ll; int head[N],to[N<<1],nxt[N<<1],val[N<<1],cnt; int n,fag[N],siz[N],tot,root,maxdeep; bool used[N]; ll ans,f[N][2],g[N][2]; inline void add(int u,int v,int w) { to[++cnt]=v; nxt[cnt]=head[u]; head[u]=cnt; val[cnt]=w; } void getroot(int x,int y) { fag[x]=0; siz[x]=1; int i; for(i=head[x];i;i=nxt[i]) { if(to[i]!=y&&!used[to[i]]) { getroot(to[i],x); siz[x]+=siz[to[i]]; fag[x]=max(fag[x],siz[to[i]]); } } fag[x]=max(fag[x],tot-siz[x]); if(fag[root]>fag[x]) root=x; } void calc(int x,int y,int now,int cnt) { int i; if(now==0) { if(cnt>=2) ans++; cnt++; } for(i=head[x];i;i=nxt[i]) { if(to[i]!=y&&!used[to[i]]) { calc(to[i],x,now+val[i],cnt); } } } void getdep(int x,int y,int now,int l,int r) { siz[x]=1; int i; if(now>=l&&now<=r) { if(now>=0) f[now][1]++; else g[-now][1]++; }else { if(now>=0) f[now][0]++; else g[-now][0]++; } l=min(l,now); r=max(r,now); maxdeep=max(maxdeep,max(-l,r)); for(i=head[x];i;i=nxt[i]) { if(to[i]!=y&&!used[to[i]]) { getdep(to[i],x,now+val[i],l,r); siz[x]+=siz[to[i]]; } } } void work(int x) { int i,j; used[x]=1; maxdeep=0; calc(x,0,0,0); getdep(x,0,0,1,-1); ans+=f[0][1]*(f[0][1]-1)/2; f[0][0]=f[0][1]=0; for(i=1;i<=maxdeep;i++) ans+=f[i][1]*g[i][1]+f[i][0]*g[i][1]+f[i][1]*g[i][0],f[i][0]=f[i][1]=g[i][0]=g[i][1]=0; for(i=head[x];i;i=nxt[i]) { if(!used[to[i]]) { maxdeep=0; getdep(to[i],0,val[i],0,0); ans-=f[0][1]*(f[0][1]-1)/2; f[0][0]=f[0][1]=0; for(j=1;j<=maxdeep;j++) ans-=f[j][1]*g[j][1]+f[j][0]*g[j][1]+f[j][1]*g[j][0],f[j][0]=f[j][1]=g[j][0]=g[j][1]=0; tot=siz[to[i]]; root=0; getroot(to[i],0); work(root); } } } int main() { scanf("%d",&n); int i,x,y,z; for(i=1;i<n;i++) { scanf("%d%d%d",&x,&y,&z); if(!z) z--; add(x,y,z); add(y,x,z); } tot=n; fag[0]=1<<30; getroot(1,0); work(root); printf("%lld\n",ans); }