Tree 换根dp

链接:https://ac.nowcoder.com/acm/contest/6226/C
来源:牛客网

题目描述

修修去年种下了一棵树,现在它已经有n个结点了。

修修非常擅长数数,他很快就数出了包含每个点的连通点集的数量。

澜澜也想知道答案,但他不会数数,于是他把问题交给了你。

输入描述:

第一行一个整数n (1≤ n ≤ 106),接下来n-1行每行两个整数ai,bi表示一条边 (1≤ ai,bi≤ n)。

输出描述:

输出n行,每行一个非负整数。第i行表示包含第i个点的连通点集的数量对109+7取模的结果。

示例1

输入

6
1 2
1 3
2 4
4 5
4 6

输出

12
15
7
16
9
9

题解

拿到题画了画, 明白了啥是联通点集, 即所有点(包含k点)和边连成一棵树, 即以k点为根的子树

那解样例, 自然发现读于根u的答案是通过子树统计的

      ans[u] = 1(只有自己)
      for v 子树
            ans[u] = ans[u] + ans[v] * ans[u](子树答案和当前根u的答案组合)
                   = ans[u] * (ans[v] + 1)

这样就只能求根的答案

但是我们发现, 对于子树的根, 就来自父节点的贡献没算!!

这不就是换根dp吗?

当子树的根做主根的时候, 那么原根的答案就为(把这棵子树当作是最后对原根贡献的子树)

//ans[fa] = ans[fa] * (ans[u] + 1) (把这棵子树当作是最后对原根贡献的子树)
//则这棵子树没贡献之前的 ans[fa]为 
res = ans[fa] / (ans[u] + 1)

那么, 就可以直接算子树根的答案了

ans[u] = ans[u] + res * ans[u]
       = ans[u] * (ans[fa] / (ans[u] + 1) + 1) 
       = ans[u] * (ans[fa] * power(ans[u] + 1, mod - 2) + 1)
//注意当ans[fa] = mod - 1, 就没有意义,至于为什么要提, 接下来说

这不就可以换根, 两个dfs ac了吗?

呵呵, 这题要取模的, 就存在ans[i] = 0的情况

那么谁让ans[i] == 0了呢?

ans[u] = ans[u] * (ans[v] + 1)

显然是子树 ans[v] == mod - 1, 我们记 cnt[i] 表示以 i为根的树含 ans[j] == mod - 1的数量

(且当子树为零树时, 我们就只标记, 就不算ans[u]了, 就不会出现除数为0的情况)

(且换根时, 由于ans[u] == mod - 1, ans[u]就没对ans[fa]贡献)

对于 cnt[i] > 0 的 ans[i] = 0, 但是我们在换根的时候就要考虑了,

父根cnt[fa] > 0, 则就没贡献, 但是如果 cnt[fa] = 1, 且ans[u] == mod - 1 呢?

没错我们要考虑父根只有一颗零树, 且当前子树就是这棵时的情况

那么就处理完了

具体看代码

代码

#include <bits/stdc++.h>
#define all(n) (n).begin(), (n).end()
#define se second
#define fi first
#define pb push_back
#define mp make_pair
#define sqr(n) (n)*(n)
#define rep(i,a,b) for(int i=a;i<=(b);++i)
#define per(i,a,b) for(int i=a;i>=(b);--i)
#define IO ios::sync_with_stdio(0);cin.tie(0)
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;
typedef pair<ll, ll> PLL;
typedef vector<int> VI;
typedef double db;

const int N = 1e6 + 5;
const int mod = 1e9 + 7;

int n, cnt[N];
int h[N], to[N << 1], ne[N << 1], tot;
ll ans[N];

void add(int u, int v) {
    ne[++tot] = h[u]; h[u] = tot; to[tot] = v;
}

void calc(int u, ll res) {
    if (res == mod - 1) ++cnt[u];
    else ans[u] = ans[u] * (res % mod + 1) % mod;
}

void dfs1(int u, int f) {
    ans[u] = 1;
    for (int i = h[u]; i; i = ne[i]) {
        int y = to[i];
        if (y == f) continue;

        dfs1(y, u);
        calc(u, cnt[y] ? 0 : ans[y]);
    }
}

ll power(ll a, int b) {
    ll res = 1;
    for (; b; a = a * a % mod, b >>= 1)
        if (b & 1) res = res * a % mod;
    return res;
}

void dfs2(int u, int f) {
    if (f) {
        if (!cnt[u] && ans[u] == mod - 1) {
            if (cnt[f] == 1) calc(u, ans[f]);
        } else if (cnt[f] == 0) 
            calc(u, ans[f] * power(cnt[u] ? 1 : ans[u] % mod + 1, mod - 2));
    }

    for (int i = h[u]; i; i = ne[i]) {
        int y = to[i];
        if (y == f) continue;
        dfs2(y, u);
    }
}

int main() {
    IO;
    cin >> n;
    rep (i, 2, n) {
        int u, v; cin >> u >> v;
        add(u, v); add(v, u);
    }
    dfs1(1, 0); dfs2(1, 0);
    rep (i, 1, n) cout << (cnt[i] ? 0 : ans[i]) << '\n';
    return 0;
}
posted @ 2020-07-08 00:22  洛绫璃  阅读(187)  评论(0编辑  收藏  举报