Codeforces 1156D 0-1-Tree(树形dp)
传送:http://codeforces.com/contest/1156/problem/D
题意:有一棵$n$($n\leq200000$)个结点的树,$n-1$条边,每条边有一个值$(0,1)$,对于从$x$到$y$的唯一路径不能从0边到1边,问有多少点对符合要求。
分析:
考虑这样一个dp方程,$dp[i][0/1]$,$dp[i][0]$代表从结点$i$出发的权值为0的边,可以到达点的个数,同理:$dp[i][1]$代表从结点$i$出发的权值为1的边,可以到达点的个数。
那么就是说,我做两边树的遍历:
第一遍可以先处理出儿子继承父亲的答案;
第二遍处理出父亲“继承”儿子的答案(同时需要去除掉本身儿子继承父亲的答案)。
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int maxn=2e5+10; 5 struct node{ 6 int to,w,nxt; 7 }e[maxn*2]; 8 int head[maxn],tot; 9 ll ans; 10 ll dp[maxn][2]; 11 void add(int x,int y,int w){ 12 e[tot]={y,w,head[x]}; 13 head[x]=tot++; 14 } 15 void dfs(int x,int fa){ 16 for (int i=head[x];i!=-1;i=e[i].nxt){ 17 if (e[i].to==fa) continue; 18 dfs(e[i].to,x); 19 if (e[i].w==0) dp[x][0]+=dp[e[i].to][0]; 20 else dp[x][1]+=dp[e[i].to][0]+dp[e[i].to][1]; 21 } 22 } 23 void dfs2(int x,int fa){ 24 for (int i=head[x];i!=-1;i=e[i].nxt){ 25 if (e[i].to==fa) continue; 26 if (e[i].w==0) dp[e[i].to][0]+=(dp[x][0]-dp[e[i].to][0]); 27 else dp[e[i].to][1]+=(dp[x][0]-dp[e[i].to][0])+(dp[x][1]-dp[e[i].to][1]); 28 dfs2(e[i].to,x); 29 } 30 } 31 int main(){ 32 int n,x,y,z; scanf("%d",&n); 33 tot=0; 34 for (int i=1;i<=n;i++) head[i]=-1; 35 for (int i=0;i<n-1;i++){ 36 scanf("%d%d%d",&x,&y,&z); 37 add(x,y,z); 38 add(y,x,z); 39 dp[x][z]++; 40 } 41 ans=0; 42 dfs(1,0); 43 dfs2(1,0); 44 for (int i=1;i<=n;i++){ 45 cout << dp[i][0] << " " << dp[i][1] << endl; 46 ans+=(dp[i][0]+dp[i][1]); 47 } 48 printf("%lld\n",ans); 49 return 0; 50 }
来自学妹的并查集做法:
维护两个块,一个全为1,一个全为0。全为1或者全为0的块内答案为num*(num-1);
同时如果一个点可以连接起全0块或全1块,那么答案为(num1-1)*(num2-1)。
1 #include<bits/stdc++.h> 2 #define ll long long 3 using namespace std; 4 const int maxn=2e5+10; 5 int pre[2][maxn],num[2][maxn]; 6 int n,a,b,c; 7 int find(int mark,int x) 8 { 9 return x==pre[mark][x]?x:pre[mark][x]=find(mark,pre[mark][x]); 10 } 11 void merge(int mark,int x,int y) 12 { 13 int fx=find(mark,x),fy=find(mark,y); 14 if (fx!=fy) 15 { 16 pre[mark][fx]=fy; 17 num[mark][fy]+=num[mark][fx]; 18 } 19 } 20 int main() 21 { 22 scanf("%d",&n); 23 for (int i=0;i<=n;i++) pre[0][i]=pre[1][i]=i,num[0][i]=num[1][i]=1; 24 for (int i=1;i<n;i++) 25 { 26 scanf("%d%d%d",&a,&b,&c); 27 merge(c,a,b); 28 } 29 ll ans=0; 30 for (int i=1;i<=n;i++) 31 { 32 if (pre[0][i]==i) ans+=1ll*num[0][i]*(num[0][i]-1); 33 if (pre[1][i]==i) ans+=1ll*num[1][i]*(num[1][i]-1); 34 int xx=find(0,i),yy=find(1,i); 35 ans+=1ll*(num[0][xx]-1)*(num[1][yy]-1); 36 } 37 printf("%lld\n",ans); 38 return 0; 39 }