Codeforces 543D. Road Improvement (树dp + 乘法逆元)

题目链接:http://codeforces.com/contest/543/problem/D

给你一棵树,初始所有的边都是坏的,要你修复若干边。指定一个root,所有的点到root最多只有一个坏边。以每个点为root,问分别有多少种方案数。

dp[i]表示以i为子树的root的情况数,不考虑父节点,考虑子节点。   dp[i] = dp[i] * (dp[i->son] + 1)

up[i]表示以i为子树的root的情况数(倒着的),考虑父节点,不考虑子节点。  这里需要逆元。 注意(a/b)%mod中b%mod=0是错误的,所以要特殊判断。

 1 //#pragma comment(linker, "/STACK:102400000, 102400000")
 2 #include <algorithm>
 3 #include <iostream>
 4 #include <cstdlib>
 5 #include <cstring>
 6 #include <cstdio>
 7 #include <vector>
 8 #include <cmath>
 9 #include <ctime>
10 #include <list>
11 #include <set>
12 #include <map>
13 using namespace std;
14 typedef long long LL;
15 typedef pair <int, int> P;
16 const int N = 2e5 + 5;
17 LL dp[N], mod = 1e9 + 7, up[N];
18 vector <int> edge[N];
19 int cnt[N]; //子树(dp[i->son] + 1)%mod != 0的节点数
20 LL fuck[N]; //子树(dp[i-son] + 1)%mod != 0的方案数相乘
21 
22 LL fpow(LL a, LL n) {
23     LL res = 1;
24     while(n) {
25         if(n & 1)
26             res = res * a % mod;
27         a = a * a % mod;
28         n >>= 1;
29     }
30     return res;
31 }
32 
33 void dfs1(int u, int p) {
34     dp[u] = 1;
35     fuck[u] = 1;
36     for(int i = 0; i < edge[u].size(); ++i) {
37         int v = edge[u][i];
38         if(v == p)
39             continue;
40         dfs1(v, u);
41         if(dp[v] + 1 == mod)
42             cnt[u]++;
43         else
44             fuck[u] = (1 + dp[v]) % mod * fuck[u] % mod;
45         dp[u] = (1 + dp[v]) % mod * dp[u] % mod;
46     }
47 }
48 
49 void dfs2(int u, int p) {
50     for(int i = 0; i < edge[u].size(); ++i) {
51         int v = edge[u][i];
52         if(v == p)
53             continue;
54         //LL temp = dp[u] * fpow((dp[v] + 1) % mod, mod - 2) % mod; //error
55         LL temp = 0;
56         if(dp[v] + 1 == mod && up[u] && cnt[u] == 1) {   //特殊情况
57             temp = fuck[u];
58         } else {
59             temp = dp[u] * fpow((dp[v] + 1) % mod, mod - 2) % mod;
60         }
61         up[v] = (up[u] * temp % mod + 1) % mod;
62         dfs2(v, u);
63     }
64 }
65 
66 int main()
67 {
68     int n, u;
69     scanf("%d", &n);
70     for(int i = 2; i <= n; ++i) {
71         scanf("%d", &u);
72         edge[i].push_back(u);
73         edge[u].push_back(i);
74     }
75     dfs1(1, -1);
76     up[1] = 1;
77     dfs2(1, -1);
78     for(int i = 1; i <= n; ++i) {
79         printf("%lld%c", dp[i]*up[i]%mod, i == n ? '\n': ' ');
80     }
81     return 0;
82 }

 

posted @ 2016-09-27 16:06  Recoder  阅读(273)  评论(0编辑  收藏  举报