【XSY1301】原题的价值 第二类斯特林数 NTT
题目描述
给你\(n,m\),求所有\(n\)个点的简单无向图中每个点度数的\(m\)次方的和。
\(n\leq {10}^9,m\leq {10}^5\)
题解
\(g_n\)为\(n\)个点的无向图个数,\(f_n\)为\(n\)个点的答案。
\[\begin{align}
g_n&=2^{\binom{n}{2}}\\
f_n&=ng_{n-1}\sum_{i=0}^{n-1}\binom{n-1}{i}i^m\\
&=ng_{n-1}\sum_{i=0}^{n-1}\binom{n-1}{i}\sum_{j=0}^{i}\binom{i}{j}S(m,j)j!\\
&=ng_{n-1}\sum_{i=0}^{n-1}\sum_{j=0}^i\binom{n-1}{i}\binom{i}{j}S(m,j)j!\\
&=ng_{n-1}\sum_{i=0}^{n-1}\sum_{j=0}^i\binom{n-j}{j}\binom{n-1-i}{i-j}S(m,j)j!\\
&=ng_{n-1}\sum_{j=0}^m\binom{n-1}{j}S(m,j)j!\sum_{i=j}^{n-1}\binom{n-1-j}{i-j}\\
&=ng_{n-1}\sum_{j=0}^m{(n-1)}^\underline{j}S(m,j)2^{n-1-j}\\
\end{align}
\]
用ntt算斯特林数
时间复杂度:\(O(m\log m)\)
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
ll p=998244353;
ll fp(ll a,ll b)
{
ll s=1;
while(b)
{
if(b&1)
s=s*a%p;
a=a*a%p;
b>>=1;
}
return s;
}
ll fc[300010];
ll ifc[300010];
ll a[300010];
ll b[300010];
int rev[300010];
void ntt(ll *a,int n,int t)
{
ll u,v,w,wn;
int i,j,k;
rev[0]=0;
for(i=1;i<n;i++)
rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
for(i=0;i<n;i++)
if(rev[i]<i)
swap(a[rev[i]],a[i]);
for(i=2;i<=n;i<<=1)
{
if(t==1)
wn=fp(3,(p-1)/i);
else
wn=fp(fp(3,(p-1)/i),p-2);
for(j=0;j<n;j+=i)
{
w=1;
for(k=j;k<j+i/2;k++)
{
u=a[k];
v=a[k+i/2]*w%p;
a[k]=(u+v)%p;
a[k+i/2]=(u-v)%p;
w=w*wn%p;
}
}
}
if(t==-1)
{
ll inv=fp(n,p-2);
for(i=0;i<n;i++)
a[i]=a[i]*inv%p;
}
}
ll c[300010];
int main()
{
// freopen("b.in","r",stdin);
// freopen("b.out","w",stdout);
int n,m;
scanf("%d%d",&n,&m);
fc[0]=fc[1]=ifc[0]=ifc[1]=1;
int i;
int t=min(n-1,m);
for(i=2;i<=t;i++)
{
fc[i]=fc[i-1]*i%p;
ifc[i]=ifc[i-1]*fp(i,p-2)%p;
}
for(i=0;i<=t;i++)
{
a[i]=(i&1?-1:1)*ifc[i];
b[i]=fp(i,m)*ifc[i]%p;
}
int k=1;
while(k<=2*t)
k<<=1;
ntt(a,k,1);
ntt(b,k,1);
for(i=0;i<k;i++)
a[i]=a[i]*b[i]%p;
ntt(a,k,-1);
for(i=0;i<k;i++)
a[i]=(a[i]%p+p)%p;
ll ans=0;
c[0]=1;
for(i=1;i<=t;i++)
c[i]=c[i-1]*(n-i)%p;
for(i=0;i<=t;i++)
ans=(ans+c[i]%p*a[i]%p*fp(2,n-1-i)%p)%p;
ans=ans*n%p*fp(2,ll(n-1)*(n-2)/2%(p-1))%p;
printf("%lld\n",ans);
return 0;
}