【CF917D】Stranger Trees 树形DP+Prufer序列

【CF917D】Stranger Trees

题意:给你一棵n个点的树,对于k=1...n,问你有多少有标号的n个点的树,与给出的树有恰好k条边相同?

$n\le 100$

题解:我们先考虑容斥,求出和给出的树至少有k个点相同的树的数量。我们先选出原树中的k条边,然后剩下的边随便连。选出k条边后,原树被分成n-k个连通块,设其大小分别为$siz_1,siz_2...siz_{n-k}$。那么剩下的边随便连的方案数是多少呢?我们不妨把每个连通块看成一个点,答案变成n个点的完全图的生成树个数,根据Prufer序列知道这个答案是$n^{n-2}$。但是这里一个连通块的大小并不是1。对于一个大小为$siz$的连通块,如果它在Prufer序列中出现了j次,那么它对答案的贡献其实是$siz^{j+1}$(因为它的度数是j+1)。我们可以先把$\prod\limits_{i=1}^{n-k}siz_i$提出来,然后对于Prufer序列中的每个位置,如果它是第i个连通块,则贡献为$siz_i$,所以总的贡献为$\sum\limits_{i=1}^{n-k}siz_i=n$,那么答案就是$\prod\limits_{i=1}^{n-k}siz_i\times n^{n-k-2}$。

所以我们考虑树形DP,用f[x][a][b]表示在x的子树中,已经连了a条边,包含x的连通块大小为b的总贡献。最后容斥一发即可。

#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
typedef long long ll;
const ll P=1000000007;
int n,m,cnt;
int to[210],nxt[210],head[110],siz[110];
ll f[110][110][110],g[110][110],c[110][110],h[110],bt[110];
inline int rd()
{
	int ret=0,f=1;	char gc=getchar();
	while(gc<'0'||gc>'9')	{if(gc=='-')	f=-f;	gc=getchar();}
	while(gc>='0'&&gc<='9')	ret=ret*10+gc-'0',gc=getchar();
	return ret*f;
}
void dfs(int x,int fa)
{
	f[x][1][0]=1,siz[x]=1;
	for(int i=head[x],j,k,a,b,y;i!=-1;i=nxt[i])	if(to[i]!=fa)
	{
		y=to[i],dfs(to[i],x);
		memset(g,0,sizeof(g));
		for(j=1;j<=siz[x];j++)	for(k=1;k<=siz[y];k++)
		{
			for(a=0;a<siz[x];a++)	for(b=0;b<siz[y];b++)
			{
				g[j+k][a+b+1]=(g[j+k][a+b+1]+f[x][j][a]*f[y][k][b])%P;
				g[j][a+b]=(g[j][a+b]+f[x][j][a]*f[y][k][b]%P*k)%P;
			}
		}
		memcpy(f[x],g,sizeof(g));
		siz[x]+=siz[y];
	}
}
inline void add(int a,int b)
{
	to[cnt]=b,nxt[cnt]=head[a],head[a]=cnt++;
}
int main()
{
	n=rd();
	int i,j,a,b;
	memset(head,-1,sizeof(head));
	for(i=1;i<n;i++)	a=rd(),b=rd(),add(a,b),add(b,a);
	dfs(1,0);
	for(bt[0]=i=1;i<=n;i++)	bt[i]=bt[i-1]*n%P;
	for(i=0;i<=n;i++)	for(c[i][0]=j=1;j<=i;j++)	c[i][j]=(c[i-1][j-1]+c[i-1][j])%P;
	for(i=1;i<n;i++)	for(j=0;j<n;j++)	h[j]=(h[j]+f[1][i][j]*i)%P;
	h[n-1]=1;
	for(i=0;i<n-1;i++)	h[i]=h[i]*bt[n-i-2]%P;
	for(i=n-1;i>=0;i--)
	{
		for(j=i+1;j<n;j++)	h[i]=(h[i]-c[j][i]*h[j])%P;
		h[i]=(h[i]+P)%P;
	}
	for(i=0;i<n;i++)	printf("%lld ",h[i]);
	return 0;
}
posted @ 2018-02-21 14:59  CQzhangyu  阅读(389)  评论(0编辑  收藏  举报