Tree
题目描述
一棵树树有 n 个节点,n − 1 条边。
树上的节点有两种:黑,白节点。
Tyk想断掉一些边把树分成很多部分。
他想要保证每个部分里面有且仅有一个黑节点。
请问他一共有多少种的方案?
输入
第一行一个数字 n,表示树的节点个数。
第二行一共 n − 1 个数字 p0, p1, p2, p3, ..., pn−2,pi 表示第 i + 1 个节点和 pi 节
点之间有一条边。注意,点的编号是 0 到 n − 1。
第三行一共 n 个数字 x0, x1, x2, x3, ..., xn−1。如果 xi 是 1,表示 i 号节点是黑的
如果 xi 是 0,表示 i 号节点是白的。
输出
输出一个数字,表示总方案数。答案对 10^9 + 7 取模。
样例输入
3 0 0 0 1 1 6
样例输出
2
提示
数据范围
对于 30% 的数据,1 ≤ n ≤ 10。
对于 60% 的数据,1 ≤ n ≤ 100。
对于 80% 的数据,1 ≤ n ≤ 1000。
对于 100% 的数据,1 ≤ n ≤ 10^5。
对于所有数据点,都有 0 ≤ pi ≤ n − 1,xi = 0 或 xi = 1。
特别地,60% 中、80% 中、100% 中各有一个点,树的形态是一条链。
设f[i][0]表示以i为根的子树中,i所在的部分中没有黑点。
f[i][1]表示以i为根的子树中,i所在的部分中有黑点。
显然可以$O(n^2)$的跑这个dp,我们发现f[i][1]的转移中只有一项不同。
我们预处理出前缀积和后缀积,就可以通过此题。
1 #include <bits/stdc++.h> 2 using namespace std; 3 #define M 100010 4 #define MOD 1000000007 5 struct Edge{ 6 int u, v, Next; 7 } G[M]; 8 int head[M], tot; 9 inline void add(int u, int v) { 10 G[++ tot] = (Edge){u, v, head[u]}; 11 head[u] = tot; 12 } 13 int f[M][2]; 14 int a[M]; 15 int b[M]; 16 int fr[M], to[M]; 17 inline void dfs(int x) { 18 for(int i = head[x]; i != -1; i = G[i].Next) { 19 dfs(G[i].v); 20 } 21 int tt = 0; 22 for(int i = head[x]; i != -1; i = G[i].Next) { 23 b[++ tt] = f[G[i].v][0] + f[G[i].v][1]; 24 if(b[tt] >= MOD) b[tt] -= MOD; 25 } 26 if(a[x] == 1) { 27 if(head[x] == -1) return (void)(f[x][0] = 1); 28 f[x][0] = 1; 29 for(int i = head[x]; i != -1; i = G[i].Next) { 30 f[x][0] = 1ll * f[x][0] * (f[G[i].v][0] + f[G[i].v][1]) % MOD; 31 } 32 /*for(int i = 1; i <= tt; ++ i) { 33 f[x][0] = 1ll * f[x][0] * b[i] % MOD; 34 }*/ 35 } 36 else { 37 if(head[x] == -1) return (void)(f[x][1] = 1); 38 f[x][1] = 1; 39 for(int i = head[x]; i != -1; i = G[i].Next) { 40 f[x][1] = 1ll * f[x][1] * (f[G[i].v][0] + f[G[i].v][1]) % MOD; 41 } 42 fr[0] = 1; to[tt + 1] = 1; 43 for(int i = 1; i <= tt; ++ i) { 44 fr[i] = 1ll * fr[i - 1] * b[i] % MOD; 45 } 46 for(int i = tt; i >= 1; -- i) { 47 to[i] = 1ll * to[i + 1] * b[i] % MOD; 48 } 49 f[x][0] = 0; 50 for(int i = 1, j = head[x]; i <= tt; j = G[j].Next, ++ i) { 51 f[x][0] += 1ll * f[G[j].v][0] * fr[i - 1] % MOD * to[i + 1] % MOD; 52 if(f[x][0] >= MOD) f[x][0] -= MOD; 53 } 54 } 55 } 56 int main() { 57 int n; 58 scanf("%d", &n); 59 memset(head, -1, sizeof(head)); 60 for(int i = 2; i <= n; ++ i) { 61 int p; scanf("%d", &p); 62 ++ p; 63 add(p, i); 64 } 65 for(int i = 1; i <= n; ++ i) { 66 scanf("%d", &a[i]); 67 } 68 dfs(1); 69 //printf("%d\n", f[13][0]); 70 printf("%d\n", f[1][0]); 71 }