Tree

题目链接

https://ac.nowcoder.com/acm/contest/6226/C

tag

换根DP

solution

\(d[u]\)表示\(u\)的子树对\(u\)的贡献,根据乘法原理可以得到, \(d[u] = d[u] * (1 + d[v])\)\(v\)\(u\)的儿子,第一次从\(root\)节点\(dfs\)可以求出每个节点\(d[i]\),同时我们可以得到\(ans[root] = d[root]\),然后我们第二次\(dfs\)换根统计其他节点的答案,\(g[u]\)表示除去\(u\)\(u\)的子树,即\(u\)的父亲节点对答案的贡献,对于节点\(u, v\), \(v\)\(u\)的儿子,不难得到\(ans[v] = d[v] * (1 + g[v])\),当\((1 + d[v]) % mod != 0\) 时,\(g[v] = ans[u] /(1 + d[v])\) , 在取模意义下\((1+d[v])\)为0时,我们将\(u\)的出边分为三类,\(u\)的父亲节点, \(v\),剩余\(u\)的儿子节点\(son\),我们可以暴力更新\(g[v]\),首先我们已经得到了\(u\)的父亲节点对答案的贡献即\(g[u]\),然后我们需要计算\(u\)除去\(v\)以外的儿子节点对答案的贡献,即\(d[son]\),根据乘法原理, \(g[v] *= (1 + g[u]),g[v] *= (1 + d[son])\),然后根据\(ans[v] = d[v] * (1 + g[v])\),换根便可以得到所有点的答案

code

//created by pyoxiao on 2020/07/07
#include<bits/stdc++.h>
#define LL long long
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define CL(a, b) memset(a, b, sizeof(a))
using namespace std;
const int mod = 1e9 + 7;
LL fpow(LL a, LL b, LL p = mod){LL ans = 1; a %= p; while(b) {if(b & 1) ans = ans * a % p; b >>= 1; a = a * a % p;} return ans;}
LL gcd(LL a, LL b){return b == 0 ? a : gcd(b, a % b);}
LL inv(LL x) {return fpow(x, mod - 2); }
const int N = 1e6 + 7;
vector<int> ver[N];
int n;
LL d[N], f[N], g[N], ffa[N];
void dfs(int u, int fa) {
    d[u] = 1; ffa[u] = fa; 
    for(auto to : ver[u]) {
        if(to == fa) continue;
        dfs(to, u);
        d[u] = d[u] * (1 + d[to]) % mod;
    }
}
void dfs2(int u, int fa) {
    if(u != 1) {
        if((d[u] + 1) % mod) {
            g[u] = f[fa] * inv(d[u] + 1) % mod;
        } else {
            LL res = 1 + g[fa];
            for(auto to : ver[fa]) {
                if(to == u || to == ffa[fa]) continue;
                res *= (1 + d[to]);
                res %= mod;
            }
            g[u] = res;
        }
        f[u] = d[u] * (1 + g[u]) % mod;
    }
    for(auto to : ver[u]) {
        if(to == fa) continue;
        dfs2(to, u);
    }
}
void solve() {
    scanf("%d", &n);
    for(int i = 2; i <= n; i ++) {
        int u, v; 
        scanf("%d %d", &u, &v);
        ver[u].pb(v);
        ver[v].pb(u);
    }
    dfs(1, 0);
    f[1] = d[1];
    dfs2(1, 0);
    for(int i = 1; i <= n; i ++) printf("%lld\n", (f[i] % mod + mod) % mod);
}
int main() {
    int T = 1;
    // scanf("%d", &T);
    while(T --) 
        solve();
    return 0;
}
posted @ 2020-07-08 16:35  pyoxiao  阅读(180)  评论(0编辑  收藏  举报