拓扑序计数

拓扑序计数

  时间限制: 1 s

 空间限制: 128000 KB
题目描述 Description

求一颗有根树/树形图的拓扑序个数.

输入描述 Input Description

第一行一个n和一个素数P,表示这颗树有n个节点,要求拓扑序个数modP之后的结果.

接下来n-1行,每行有两个数字x和y,表示x是y的父亲.

保证1<=n<=1500000, n<P<2^31,P为质数.

输出描述 Output Description

一行一个数字,为该树形图拓扑序个数modP的结果.

样例输入 Sample Input

样例1
4 100000007
1 2
1 3
2 4

样例2
6 100000007
1 2
2 3
1 4
3 5
5 6

样例输出 Sample Output

样例1
3

样例2
5

数据范围及提示 Data Size & Hint

每个非根节点的儿子个数平均较多,数据量较大,建议c/c++选手才用scanf的读入方式

题目链接:http://codevs.cn/problem/1304/


树形DP,其实更像是树上计数。

题目分析:题目要求的是一棵树上的拓扑序,树上的拓扑序肯定比图上的拓扑序有更好的性质。我们大致可以推出二叉树上子树合并到根的公式,然后再推广。

状态设计:设ans1[x]为以结点x为根的子树的拓扑序个数,ans2[x]为以x为根的子树的节点个数。

状态转移:当树为二叉树时,我们很容易发现ans2[u]=ans2[v1]+ans2[v2]+1,ans1[u]=C(ans2[v1]+ans2[v2],ans2[v1])*ans1[v1]*ans1[v2]。当树不为二叉树的时候,我们可以先把两个子树合并为一个子树,再把这个子树依次和其他子树合并。
#include<bits/stdc++.h>
#define N 3000055
using namespace std;
long long mod;
long long jc[N],ny[N];

long long qmod(long long x,long long y)
{
    long long ans=1;
    while(y)
    {
        if(y%2==1)ans=(ans*x)%mod;
        y/=2;
        x=(x*x)%mod;
    }
    return ans;
}

void init()
{
    jc[1]=1;
    for(int i=2;i<N;i++)jc[i]=(jc[i-1]*i)%mod;
    
    ny[N-1]=qmod(jc[N-1],mod-2);
    for(int i=N-2;i>=1;i--)ny[i]=(ny[i+1]*(i+1))%mod;
}

long long C(long long m,long long n)
{
    if(n==0||m==n)return 1;
    return jc[m]*ny[n]%mod*ny[m-n]%mod;
}

struct ss
{
    int v,next;
};
ss edg[N/2];
int head[N/2],now_edge=0;
int d[N/2]={0};

void addedge(int u,int v)
{
    edg[now_edge]=(ss){v,head[u]};
    head[u]=now_edge++;
}

long long ans1[N/2]={0},ans2[N/2]={0};

void dfs(int x)
{
    ans1[x]=1;
    
    for(int i=head[x];i!=-1;i=edg[i].next)
    {
        int v=edg[i].v;
        dfs(v);
        ans2[x]+=ans2[v];
        ans1[x]=C(ans2[x],ans2[v])*ans1[x]%mod*ans1[v]%mod;
    }
    ans2[x]++;
}

long long read()
{
    long long ans=0;
    char ch=getchar();
    
    while(!(ch>='0'&&ch<='9'))ch=getchar();
    while(ch>='0'&&ch<='9')
    {
        ans*=10;
        ans+=ch-'0';
        ch=getchar();
    }
    return ans;
}
int main()
{
    int n;
    //scanf("%d %lld",&n,&mod);
    n=(int)read();mod=read();
    
    init();
    memset(head,-1,sizeof(head));
    for(int i=1;i<n;i++)
    {
        int x,y;
    //    scanf("%d %d",&x,&y);
        x=read();y=read();
        addedge(x,y);
        d[y]++;
    }
    
    for(int i=1;i<=n;i++)
    if(!d[i])
    {
        dfs(i);
        printf("%lld\n",ans1[i]);
        break;
    }
    
    return 0;
}
View Code

 

posted @ 2018-10-04 00:21  1371767389  阅读(532)  评论(0编辑  收藏  举报