BNUOJ 52305 Around the World 树形dp

题目链接:

https://www.bnuoj.com/v3/problem_show.php?pid=52305

Around the World

Time Limit: 20000ms
Memory Limit: 1048576KB

题意

给你一棵树，每对相邻结点之间有 2c 条平行边（每条树边各自给出 c），问以 1 为起点和终点的欧拉回路有多少种（答案对 $10^9+7$ 取模）。

题解

树形dp
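兄弟之间考虑可重集的排列，父子之间则用插板法。
dp[u] 表示从 u 出发并回到 u、恰好经过 u 子树内每条边一次的欧拉回路数。设 u 的儿子 v 与 u 之间的边参数为 $c_v$，记 $\mathrm{cnt}_u=\sum_v c_v$（从 u 向下出发的总次数），则
$$dp[u]=\mathrm{cnt}_u!\prod_{v}\frac{dp[v]\,(2c_v)!\,\binom{\mathrm{cnt}_v+c_v-1}{c_v-1}}{c_v!}.$$
其中 $\mathrm{cnt}_u!/\prod_v c_v!$ 是兄弟间出发顺序的可重集排列；$(2c_v)!$ 是 $2c_v$ 条平行边被依次使用的顺序；组合数是把 v 子树内的回路按顺序切成 $c_v$ 段（每次进入 v 走一段，段可以为空）的插板方案数。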

#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<stack>
#include<ctime>
#include<vector>
#include<cstdio>
#include<string>
#include<bitset>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<functional>
using namespace std;
// Shorthand for std::pair members (used as G[u][i].X / G[u][i].Y below).
#define X first
#define Y second
#define mkp make_pair
// Segment-tree child indices and overflow-safe midpoint; they expect
// variables named o, l and r at the use site (unused in this file as shown).
#define lson (o<<1)
#define rson ((o<<1)|1)
#define mid (l+(r-l)/2)
// Container shorthands: x.sz() expands to x.size(), x.pb(v) to x.push_back(v).
#define sz() size()
#define pb(v) push_back(v)
#define all(o) (o).begin(),(o).end()
// Byte-wise fill of a whole array (memset semantics).
#define clr(a,v) memset(a,v,sizeof(a))
// Debug print of an expression together with its source text.
#define bug(a) cout<<#a<<" = "<<a<<endl
// Half-open loop: i runs over [a, b) as an int; b is re-evaluated on every
// iteration, and comparing against .size() mixes int and size_t.
#define rep(i,a,b) for(int i=a;i<(b);i++)
// Aliases for C stdio input/output.
#define scf scanf
#define prf printf

// Short type aliases.
typedef long long LL;
typedef unsigned long long ULL;
typedef vector<int> VI;
typedef pair<int,int> PII;
typedef vector<pair<int,int> > VPII;

// "Infinity" sentinels whose every byte is 0x3f, so clr(a,0x3f) fills an
// int (or long long) array with INF (or INFL); adding two of them does not
// overflow a signed int.
const int INF=0x3f3f3f3f;
const LL INFL=0x3f3f3f3f3f3f3f3fLL;
// Floating-point comparison tolerance and pi.
const double eps=1e-8;
const double PI = acos(-1.0);

//start----------------------------------------------------------------------

// Node capacity (vertex ids are used directly as indices, root is 1).
const int maxn=1e5+10;
// Size of the factorial tables; every argument passed to fac/invfac must be
// below this (assumed to be guaranteed by the problem limits — not checked).
const int maxm=2e6+100;
const int mod=1e9+7;

// Adjacency list: G[u] holds (neighbor, c) pairs; an input edge with
// parameter c stands for 2c parallel edges between its endpoints.
std::vector<std::pair<int,int> > G[maxn];

// Number of vertices read from input.
int n;

// dp[u]: number of closed walks from u using every edge of u's subtree once.
long long dp[maxn];
// cntv[u]: sum of c over u's child edges (how many times the walk leaves u
// downward).
int cntv[maxn];
// Modular inverses, inverse factorials and factorials, all mod `mod`.
long long inv[maxm],invfac[maxm],fac[maxm];

// Binomial coefficient C(n, m) mod `mod`, read from the precomputed tables.
// Returns 0 when m < 0 or m > n (this also covers n < 0) instead of reading
// fac/invfac at a negative index; e.g. an edge with c == 0 would otherwise
// lead to invfac[-1].
// Precondition: 0 <= n < maxm and the tables have been filled.
long long get_C(int n,int m){
    if(m<0||m>n) return 0;
    return fac[n]*invfac[m]%mod*invfac[n-m]%mod;
}

void dfs(int u,int fa){
    LL &res=dp[u]=1;
    cntv[u]=0;
    rep(i,0,G[u].sz()){
        int v=G[u][i].X,c=G[u][i].Y;
        if(v==fa) continue;
        dfs(v,u);
        cntv[u]+=c;
        //儿子的兄弟间的可重集排列
        res=res*invfac[c]%mod;
        //上下要插板
        res=res*get_C(cntv[v]+c-1,c-1)%mod;
        //内部阶乘
        res=res*fac[2*c]%mod;
        //分步乘法
        res=res*dp[v]%mod;
    }
    //儿子的兄弟间的可重集排列
    res=res*fac[cntv[u]]%mod;
}

void pre(){
    fac[0]=fac[1]=1;
    invfac[0]=invfac[1]=1;
    inv[1]=1;
    rep(i,2,maxm){
        fac[i]=fac[i-1]*i%mod;
        inv[i]=(mod-mod/i)*inv[mod%i]%mod;
        invfac[i]=invfac[i-1]*inv[i]%mod;
    }
}

// Reads n and the n-1 tree edges (u, v, c), roots the tree at vertex 1 and
// prints the number of Euler circuits starting and ending at 1, mod `mod`.
int main() {
    pre();
    scanf("%d",&n);
    // Each remaining line is one undirected edge; store it in both directions.
    for(int left=n-1;left>0;--left){
        int a,b,w;
        scanf("%d%d%d",&a,&b,&w);
        G[a].push_back(std::make_pair(b,w));
        G[b].push_back(std::make_pair(a,w));
    }
    dfs(1,-1);
    printf("%lld\n",dp[1]);
    return 0;
}

//end-----------------------------------------------------------------------
posted @ 2016-10-04 11:53  fenicnn  阅读(195)  评论(0)  编辑  收藏  举报