#1479 : 三等分(树形DP)
http://hihocoder.com/problemset/problem/1479
#1479 : 三等分
时间限制:10000ms
单点时限:1000ms
内存限制:256MB
描述
小Hi最近参加了一场比赛,这场比赛中小Hi被要求将一棵树拆成3份,使得每一份中所有节点的权值和相等。
比赛结束后,小Hi发现虽然大家得到的树几乎一模一样,但是每个人的方法都有所不同。于是小Hi希望知道,对于一棵给定的有根树,在选取其中2个非根节点并将它们与它们的父亲节点分开后,所形成的三棵子树的节点权值之和能够两两相等的方案有多少种。
两种方案被看做不同的方案,当且仅当形成方案的2个节点不完全相同。
输入
每个输入文件包含多组输入,在输入的第一行为一个整数T,表示数据的组数。
每组输入的第一行为一个整数N,表示给出的这棵树的节点数。
接下来N行,依次描述结点1~N,其中第i行为两个整数Vi和Pi,分别描述这个节点的权值和其父亲节点的编号。
父亲节点编号为0的节点为这棵树的根节点。
对于30%的数据,满足3<=N<=100
对于100%的数据,满足3<=N<=100000, |Vi|<=100, T<=10
输出
对于每组输入,输出一行Ans,表示方案的数量。
- 样例输入
-
2 3 1 0 1 1 1 2 4 1 0 1 1 1 2 1 3
- 样例输出
-
1 0
参考博客:http://blog.csdn.net/viphong/article/details/61958631
需要两个dfs,第一个dfs从父节点开始递归遍历,求出以每个节点为根的子树的权值和。
第二个dfs就是开始统计个数: 若某一节点正好是总数的1/3,那么该节点很有可能和另一个节点符合题目要求,那么另一个节点就是另一个1/3节点,或者另一个节点是该节点祖先节点,这个祖先节点是总数的2/31 #include <iostream> 2 #include <cstdio> 3 #include <algorithm> 4 #include <cstring> 5 #include <vector> 6 using namespace std; 7 const int Max = 100000 + 10; 8 int v[Max], sum[Max]; 9 long long all, first, second, cnt, root; 10 vector<int> mp[Max]; 11 //统计每个子树的权值和 12 void dfs(int node, int fa) 13 { 14 sum[node] = v[node]; 15 for (int i = 0; i < (int)mp[node].size(); i++) 16 { 17 int son = mp[node][i]; 18 if (son != fa) 19 { 20 dfs(son, node); 21 sum[node] += sum[son]; 22 } 23 } 24 } 25 // 核心 26 void dfs2(int node, int fa) 27 { 28 //找到一个1/3节点 29 if (sum[node] == all) 30 cnt += first + second; 31 // 因为可能有负数,所有要继续往下递归 32 if (sum[node] == all * 2 && node != root) 33 second++; 34 35 for (int i = 0; i < (int)mp[node].size(); i++) 36 { 37 int son = mp[node][i]; 38 if (son != fa) 39 { 40 dfs2(son, node); 41 } 42 } 43 //每一个1/3的节点只会和另一个不同分支的1/3节点满足条件 44 //每一个1/3的节点只会和它祖先是2/3的满足条件 45 if (sum[node] == all) 46 first++; 47 //以node为根节点满足2/3,全都遍历完毕,所以再不存在以node为根与一个1/3节点满足条件,故删除该2/3节点 48 if (sum[node] == all * 2 && node != root) 49 second--; 50 } 51 int main() 52 { 53 int n, t, fa; 54 scanf("%d", &t); 55 while (t--) 56 { 57 scanf("%d", &n); 58 //清空 59 for (int i = 1; i <= n; i++) 60 mp[i].clear(); 61 memset(sum, 0, sizeof(sum)); 62 cnt = all = first = second = 0; 63 for (int i = 1; i <= n; i++) 64 { 65 scanf("%d%d", &v[i], &fa); 66 all += v[i]; 67 if (fa == 0) 68 root = i; 69 mp[fa].push_back(i); 70 } 71 if (all % 3) 72 { 73 printf("0\n"); 74 continue; 75 } 76 all /= 3; 77 dfs(root, 0); 78 dfs2(root, 0); 79 printf("%lld\n", cnt); 80 } 81 return 0; 82 }