【XSY1301】原题的价值 第二类斯特林数 NTT

题目描述

  给你\(n,m\),求所有\(n\)个点的简单无向图中每个点度数的\(m\)次方的和。

  \(n\leq {10}^9,m\leq {10}^5\)

题解

  \(g_n\)\(n\)个点的无向图个数,\(f_n\)\(n\)个点的答案。

\[\begin{align} g_n&=2^{\binom{n}{2}}\\ f_n&=ng_{n-1}\sum_{i=0}^{n-1}\binom{n-1}{i}i^m\\ &=ng_{n-1}\sum_{i=0}^{n-1}\binom{n-1}{i}\sum_{j=0}^{i}\binom{i}{j}S(m,j)j!\\ &=ng_{n-1}\sum_{i=0}^{n-1}\sum_{j=0}^i\binom{n-1}{i}\binom{i}{j}S(m,j)j!\\ &=ng_{n-1}\sum_{i=0}^{n-1}\sum_{j=0}^i\binom{n-j}{j}\binom{n-1-i}{i-j}S(m,j)j!\\ &=ng_{n-1}\sum_{j=0}^m\binom{n-1}{j}S(m,j)j!\sum_{i=j}^{n-1}\binom{n-1-j}{i-j}\\ &=ng_{n-1}\sum_{j=0}^m{(n-1)}^\underline{j}S(m,j)2^{n-1-j}\\ \end{align} \]

  用ntt算斯特林数

  时间复杂度:\(O(m\log m)\)

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
ll p=998244353;
ll fp(ll a,ll b)
{
	ll s=1;
	while(b)
	{
		if(b&1)
			s=s*a%p;
		a=a*a%p;
		b>>=1;
	}
	return s;
}
ll fc[300010];
ll ifc[300010];
ll a[300010];
ll b[300010];
int rev[300010];
void ntt(ll *a,int n,int t)
{
	ll u,v,w,wn;
	int i,j,k;
	rev[0]=0;
	for(i=1;i<n;i++)
		rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
	for(i=0;i<n;i++)
		if(rev[i]<i)
			swap(a[rev[i]],a[i]);
	for(i=2;i<=n;i<<=1)
	{
		if(t==1)
			wn=fp(3,(p-1)/i);
		else
			wn=fp(fp(3,(p-1)/i),p-2);
		for(j=0;j<n;j+=i)
		{
			w=1;
			for(k=j;k<j+i/2;k++)
			{
				u=a[k];
				v=a[k+i/2]*w%p;
				a[k]=(u+v)%p;
				a[k+i/2]=(u-v)%p;
				w=w*wn%p;
			}
		}
	}
	if(t==-1)
	{
		ll inv=fp(n,p-2);
		for(i=0;i<n;i++)
			a[i]=a[i]*inv%p;
	}
}
ll c[300010];
int main()
{
//	freopen("b.in","r",stdin);
//	freopen("b.out","w",stdout);
	int n,m;
	scanf("%d%d",&n,&m);
	fc[0]=fc[1]=ifc[0]=ifc[1]=1;
	int i;
	int t=min(n-1,m);
	for(i=2;i<=t;i++)
	{
		fc[i]=fc[i-1]*i%p;
		ifc[i]=ifc[i-1]*fp(i,p-2)%p;
	}
	for(i=0;i<=t;i++)
	{
		a[i]=(i&1?-1:1)*ifc[i];
		b[i]=fp(i,m)*ifc[i]%p;
	}
	int k=1;
	while(k<=2*t)
		k<<=1;
	ntt(a,k,1);
	ntt(b,k,1);
	for(i=0;i<k;i++)
		a[i]=a[i]*b[i]%p;
	ntt(a,k,-1);
	for(i=0;i<k;i++)
		a[i]=(a[i]%p+p)%p;
	ll ans=0;
	c[0]=1;
	for(i=1;i<=t;i++)
		c[i]=c[i-1]*(n-i)%p;
	for(i=0;i<=t;i++)
		ans=(ans+c[i]%p*a[i]%p*fp(2,n-1-i)%p)%p;
	ans=ans*n%p*fp(2,ll(n-1)*(n-2)/2%(p-1))%p;
	printf("%lld\n",ans);
	return 0;
}
posted @ 2018-03-05 20:55  ywwyww  阅读(401)  评论(0编辑  收藏  举报