牛客国庆集训派对Day3 B Tree(树形dp + 组合计数)

传送门:https://www.nowcoder.com/acm/contest/203/B

思路及参考:https://blog.csdn.net/u013534123/article/details/82934820

这篇博客写得非常详细,但我没有掌握文中提到的“立flag”法,于是学习了其他同学的做法:换根时需要除以 dp[u]+1,如果它在模 1e9+7 意义下为 0(没有逆元,不能做除法),就直接把其余兄弟子树的贡献暴力重新相乘来计数。我琢磨了比较久才想明白的地方都写在代码注释里了。

//#pragma GCC optimize(3)
//#pragma comment(linker, "/STACK:102400000,102400000")  //c++
#include <algorithm>
#include  <iterator>
#include  <iostream>
#include   <cstring>
#include   <iomanip>
#include   <cstdlib>
#include    <cstdio>
#include    <string>
#include    <vector>
#include    <bitset>
#include    <cctype>
#include     <queue>
#include     <cmath>
#include      <list>
#include       <map>
#include       <set>
using namespace std;

// ---- Shorthand macros (typical competitive-programming template) ----
// lson / rson expand to a parenthesised argument list and are meant to be
// written right after a function name: `build lson;` -> `build(l , mid , rt << 1);`
#define lson (l , mid , rt << 1)
#define rson (mid + 1 , r , rt << 1 | 1)
// NOTE: the expansion already ends in ';', so `if (c) debug(x); else ...`
// does not compile -- use debug(x) only as a standalone statement.
#define debug(x) cerr << #x << " = " << x << "\n";
#define pb push_back
#define pq priority_queue



typedef long long ll;
typedef unsigned long long ull;

typedef pair<ll ,ll > pll;
typedef pair<int ,int > pii;
typedef pair<int ,pii> p3;
//priority_queue<int> q;                                  // max-heap q
//priority_queue<int,vector<int>,greater<int> >q;         // min-heap q
#define fi first
#define se second
//#define endl '\n'

#define OKC ios::sync_with_stdio(false);cin.tie(0)   // fast iostreams; do not mix with stdio afterwards
#define FT(A,B,C) for(int A=B;A <= C;++A)  // inclusive loop A = B..C, used to keep code compact
#define REP(i , j , k)  for(int i = j ; i <  k ; ++i)   // half-open loop i = j..k-1
//priority_queue<int ,vector<int>, greater<int> >que;

const ll mos = 0x7FFFFFFFLL;   // 2147483647 (= INT_MAX)
const ll nmos = 0x80000000LL;  // 2147483648 (= 2^31). As a long long literal this is POSITIVE:
                               // it is -INT_MIN, not INT_MIN (-2147483648).
const int mod = 1e9+7;
const int inf = 0x3f3f3f3f;    // 1061109567 (~1.06e9); every byte is 0x3f, so memset(a, 0x3f, ...) yields it
const ll inff = 0x3f3f3f3f3f3f3f3fLL; // 4557430888798830399 (~4.56e18), the 64-bit counterpart
const double PI=acos(-1.0);

// Fast reader for one signed decimal integer from stdin.
// Skips every non-digit character before the number; a '-' anywhere in that
// skipped run makes the result negative. The character right after the last
// digit is consumed. Stores the value in x and also returns it.
// If end of input is reached before any digit, x is set to 0 and 0 is
// returned (the previous version spun forever at EOF).
template<typename T>
inline T read(T&x){
    x=0;
    bool neg=false;
    int ch=getchar();               // int, not char: it must be able to hold EOF
    while (ch!=EOF && (ch<'0'||ch>'9')){
        neg|=(ch=='-');
        ch=getchar();
    }
    while (ch>='0'&&ch<='9'){
        x=x*10+(ch-'0');
        ch=getchar();
    }
    return x=neg?-x:x;
}
// Local-testing hook: define _DEBUG (e.g. uncomment the next line) to read the
// input from the file "input" instead of the console.
// NOTE: MSVC also defines _DEBUG itself in debug-CRT builds (/MDd, /MTd, /LDd).
// #define _DEBUG
#ifdef _DEBUG
// A bare freopen(...) call is not allowed at namespace scope (it does not
// compile), so the redirect is done by initializing a file-scope variable,
// which happens before main() starts.
static FILE* const debug_stdin = freopen("input", "r", stdin);
// static FILE* const debug_stdout = freopen("output.txt", "w", stdout);
#endif
/*-----------------------show time----------------------*/
            // Global tree storage indexed by vertex id (main uses ids 1..n).
            const int maxn = 1000009;
            vector<int> mp[maxn];   // adjacency lists of the undirected tree
            ll dp[maxn];            // filled by dfs1: product over children v of (dp[v] + 1), mod `mod`
            ll ans[maxn];           // filled by dfs2: the per-vertex answer printed by main
            ll pa[maxn];            // filled by dfs1: parent of each vertex (-1 for the root)
            ll la[maxn];            // filled by dfs2: contribution from outside the vertex's subtree
            // Modular exponentiation by squaring: returns a^b mod `mod`
            // (1 when b <= 0, since the loop body never runs).
            ll ksm(ll a, ll b){
                ll result = 1;
                ll base = a;
                for(; b > 0; b >>= 1){
                    if(b & 1) result = result * base % mod;
                    base = base * base % mod;
                }
                return result;
            }
            // Bottom-up pass over the tree rooted at u (fa = u's parent, -1 for the root).
            // Sets pa[x] to the parent of every vertex x reached from u, and
            //   dp[x] = product over children y of (dp[y] + 1)   (mod `mod`),
            // i.e. the number of connected vertex sets that contain x and lie
            // entirely inside x's subtree.
            // Written iteratively on purpose: a path of ~1e6 vertices would need
            // ~1e6 nested recursive calls, which can overflow the default stack
            // on many judges.
            void dfs1(int u,int fa){
                vector<int> order;          // vertices in DFS pre-order
                vector<int> stk(1, u);
                pa[u] = fa;
                while(!stk.empty()){
                    int x = stk.back();
                    stk.pop_back();
                    order.push_back(x);
                    for(size_t i=0; i<mp[x].size(); i++){
                        int y = mp[x][i];
                        if(y == pa[x]) continue;
                        pa[y] = x;
                        stk.push_back(y);
                    }
                }
                // Reverse pre-order visits every child before its parent,
                // so all dp[y] of x's children are final when x is processed.
                for(size_t k = order.size(); k-- > 0; ){
                    int x = order[k];
                    dp[x] = 1;
                    for(size_t i=0; i<mp[x].size(); i++){
                        int y = mp[x][i];
                        if(y == pa[x]) continue;
                        dp[x] = dp[x] * (dp[y] + 1) % mod;
                    }
                }
            }
            void dfs2(int u,int fa){
                if(fa == -1)ans[u] = dp[u];
                else {
                    if((dp[u]+1)%mod == 0){     //这里是暴力计算的。
                        ans[u] = la[fa] + 1;    //la数组记录,通过fa到u但不包括fa及子树的点的影响。
                        for(int i=0; i<mp[fa].size(); i++){
                            int v = mp[fa][i];
                            if(v == u || v == pa[fa])continue;
                            ans[u] = 1ll*ans[u] * (dp[v] + 1) % mod;
                        }
                    }
                    else ans[u] = 1ll*ans[fa] * ksm(dp[u]+1,mod-2) % mod;

                    la[u] = ans[u];
                    ans[u] = 1ll*(ans[u]+1) * dp[u] % mod;
                }

                for(int i=0; i<mp[u].size(); i++){
                    int v = mp[u][i];
                    if(fa == v)continue;
                    dfs2(v, u);
                }
 
            }
int main(){
            int n;  scanf("%d", &n);
            for(int i=1; i<n; i++){
                int u,v;
                scanf("%d%d", &u, &v);
                mp[u].pb(v);
                mp[v].pb(u);
            }
            dfs1(1,-1);
            dfs2(1,-1);
            for(int i=1; i<=n; i++){
                printf("%lld\n", ans[i]);
            }
            return 0;
}
View Code

 

posted @ 2018-10-07 23:11  ckxkexing  阅读(187)  评论(0编辑  收藏  举报