【XSY2843】「地底蔷薇」 NTT什么的 扩展拉格朗日反演
题目大意
给定集合\(S\),请你求出\(n\)个点的“所有极大点双连通分量的大小都在\(S\)内”的不同简单无向连通图的个数对\(998244353\)取模的结果。
\(n\leq {10}^5,(m=\sum_{x\in S})\leq {10}^5\)
题解
首先你要会求\(n\)个点带标号有根简单无向图的个数。bzoj3456就是求这个东西。
记\(H(x)\)为带标号有根简单无向图个数的EGF。
记\(b_i\)为\(i+1\)个点的带标号点双个数,\(B(x)=\sum_{i\geq 0}\frac{b_i}{i!}x^i\)。
考虑一个有根连通图是长怎样的。
先把根删掉,然后整个图会分为很多个连通块。每个连通块内部都有一些点和根在同一个点双内,把点双里面的所有边删掉之后整个点双会分成很多个以这些点为根的连通图。我们枚举单个点双还剩下多少个点,则单个连通块的答案是
把所有连通块合在一起,有
现在我们知道\(H(x)\),要求\(B(x)\)中某些项的系数。
记\(H^{-1}(x)\)为\(H(x)\)的复合逆。
然后就有
如果直接用拉格朗日反演求\(H^{-1}(x)\)的系数再求\(B(x)\)的话,要求出全部\(O(n)\)项,要花费\(O(n^2)\)的时间。这太慢了。
有个东西叫扩展拉格朗日反演:
我们要构造\(G(x)\)使得\(G(H^{-1}(x))=B(x)\)。
所以我们就可以求出\(G(x)\),然后在\(O(n\log n)\)内求出\([x^n]B(x)=[x^n]G(F^{-1}(x))\)了。
因为\(m\)是\(\leq {10}^5\)的,总复杂度就是\(O(m\log m)\)
记
,\(C(x)\)为满足题目要求的带标号有根简单无向图个数的EGF,那么满足
再做一次拉格朗日反演就可以得到\([x^n]C(x)\)了。
最终答案为\((n-1)![x^n]C(x)\)。因为我们求的是EGF所以要乘以\(n!\),然后是有根变无根要除以一个\(n\)。
时间复杂度:\(O(n\log n+m\log m)\)
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int p=998244353;
const int N=300000;
const int W=262144;
ll fp(ll a,ll b){ll s=1;for(;b;b>>=1,a=a*a%p)if(b&1)s=s*a%p;return s;}
int iv[N];
int ifac[N];
int fac[N];
int w[W];
void ntt(int *a,int n,int t)
{
static int rev[N];
for(int i=1;i<n;i++)
{
rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
if(rev[i]>i)
swap(a[i],a[rev[i]]);
}
for(int i=2;i<=n;i<<=1)
for(int j=0;j<n;j+=i)
for(int k=0;k<i/2;k++)
{
int u=a[j+k];
int v=(ll)a[j+k+i/2]*w[W/i*k]%p;
a[j+k]=(u+v)%p;
a[j+k+i/2]=(u-v)%p;
}
if(t==-1)
{
reverse(a+1,a+n);
ll invn=fp(n,p-2);
for(int i=0;i<n;i++)
a[i]=a[i]*invn%p;
}
}
void copy(int *a,int *b,int l,int r){memcpy(a+l,b+l,sizeof(a[0])*(r-l));}
void clear(int *a,int l,int r){memset(a+l,0,sizeof(a[0])*(r-l));}
void mul(int *a,int *b,int *c,int n,int m,int l=-1)
{
static int a1[N],a2[N];
if(l==-1)
l=n+m;
n=min(n,l);
m=min(m,l);
int k=1;
while(k<=n+m)
k<<=1;
copy(a1,a,0,n+1);
clear(a1,n+1,k);
copy(a2,b,0,m+1);
clear(a2,m+1,k);
ntt(a1,k,1);
ntt(a2,k,1);
for(int i=0;i<k;i++)
a1[i]=(ll)a1[i]*a2[i]%p;
ntt(a1,k,-1);
copy(c,a1,0,l+1);
}
void mul2(int *a,int *b,int *c,int n)
{
mul(a,b,c,n-1,n-1,n-1);
}
void inv(int *a,int *b,int n)
{
if(n==1)
{
b[0]=fp(a[0],p-2);
return;
}
inv(a,b,n>>1);
static int a1[N],a2[N];
copy(a1,a,0,n);
clear(a1,n,n<<1);
copy(a2,b,0,n>>1);
clear(a2,n>>1,n<<1);
ntt(a1,n<<1,1);
ntt(a2,n<<1,1);
for(int i=0;i<n<<1;i++)
a1[i]=a2[i]*(2-(ll)a1[i]*a2[i]%p)%p;
ntt(a1,n<<1,-1);
copy(b,a1,0,n);
}
void ln(int *a,int *b,int n)
{
static int a1[N],a2[N];
for(int i=1;i<n;i++)
a1[i-1]=(ll)a[i]*i%p;
a1[n-1]=0;
inv(a,a2,n);
mul2(a1,a2,a1,n);
for(int i=1;i<n;i++)
b[i]=(ll)a1[i-1]*iv[i]%p;
b[0]=0;
}
void exp(int *a,int *b,int n)
{
if(n==1)
{
b[0]=1;
return;
}
exp(a,b,n>>1);
static int a1[N],a2[N];
clear(b,n>>1,n);
ln(b,a1,n);
for(int i=0;i<n>>1;i++)
a1[i]=(a[i+(n>>1)]-a1[i+(n>>1)])%p;
mul2(a1,b,a2,n>>1);
for(int i=0;i<n>>1;i++)
b[i+(n>>1)]=a2[i];
}
void pow(int *a,int *b,int n,int m)
{
static int a1[N],a2[N],a3[N];
int k=1;
while(k<=n)
k<<=1;
copy(a1,a,0,n+1);
clear(a1,n+1,k);
ln(a1,a2,k);
for(int i=0;i<k;i++)
a2[i]=(ll)a2[i]*m%p;
exp(a2,a3,k);
copy(b,a3,0,n+1);
}
int a[N],b[N],g[N],h[N],f[N];
int n,m;
void geth()
{
int k=1;
while(k<=n+2)
k<<=1;
for(int i=0;i<=n+2;i++)
f[i]=fp(2,ll(i-1)*i/2)*ifac[i]%p;
ln(f,h,k);
for(int i=0;i<=n+1;i++)
h[i]=(ll)h[i]*i%p;
}
void getg()
{
static int a1[N];
for(int i=0;i<=n;i++)
a1[i]=h[i+1];
int k=1;
while(k<=n)
k<<=1;
ln(a1,g,k);
for(int i=1;i<=n;i++)
g[i-1]=(ll)g[i]*i%p;
}
int getb(int x)
{
int k=1;
while(k<x)
k<<=1;
static int a1[N],a2[N];;
for(int i=0;i<x;i++)
a1[i]=(ll)h[i]*(-x)%p;
exp(a1,a2,k);
mul(g,a2,a1,x-1,x-1,x-1);
return (ll)a1[x-1]*iv[x]%p;
}
int geta()
{
static int a1[N];
int k=1;
while(k<n)
k<<=1;
for(int i=0;i<k;i++)
b[i]=(ll)b[i]*n%p;
exp(b,a1,k);
return (ll)a1[n-1]*iv[n]%p;
}
int c[N];
int main()
{
#ifndef ONLINE_JUDGE
freopen("d.in","r",stdin);
freopen("d.out","w",stdout);
#endif
fac[0]=fac[1]=ifac[0]=ifac[1]=iv[1]=1;
for(int i=2;i<=W;i++)
{
iv[i]=(ll)-p/i*iv[p%i]%p;
ifac[i]=(ll)ifac[i-1]*iv[i]%p;
fac[i]=(ll)fac[i-1]*i%p;
}
ll w1=fp(3,(p-1)/W);
w[0]=1;
for(int i=1;i<W;i++)
w[i]=w[i-1]*w1%p;
scanf("%d%d",&n,&m);
geth();
getg();
int x;
int k=1;
for(int i=1;i<=m;i++)
{
scanf("%d",&c[i]);
c[i]--;
while(k<c[i])
k<<=1;
}
for(int i=0;i<k;i++)
h[i]=h[i+1];
ln(h,h,k);
for(int i=1;i<=m;i++)
b[c[i]]=getb(c[i]);
ll ans=geta();
ans=ans*fac[n-1]%p;
ans=(ans+p)%p;
printf("%lld\n",ans);
return 0;
}