P7092 计数题 题解
前置题目:P5748 集合划分计数。我们令 \(Bell_n\) 表示将 \(n\) 个有标号的球划分为若干集合的方案数。且 \(Bell_n=n![x^n]e^{e^x-1}\)。
首先,当 \(k=0\) 时,\(\mu(S)=0\),答案为 \(0\)。
当 \(k=1\) 时,\(\mu(S)=(-1)^{|S|},\varphi(S)=\prod\limits_{x\in S}(x-1)\)。记 \(\pi(S)=\prod\limits_{x\in S}x\),推式子:
当 \(S'=\varnothing\) 时,\(\sum\limits_{\substack{S\subset T\\S\neq \varnothing\\S'\subset S}} 1=2^{|T|}-1\)。否则,我们就能除去 \(S\neq \varnothing\) 的限制,\(\sum\limits_{\substack{S\subset T\\S'\subset S}} 1=2^{|T|-|S'|}\)。
所以式子为:\(\sum\limits_{S'\subset T}\left(\pi(S')(-1)^{|S' |}2^{|T|-|S'|}\right)-1=\prod\limits_{x\in T}(2-x)-1\)。
注意到当 \(n>1\) 时 \(2\in T\),于是答案为 \(-1\),否则答案为 \(0\)。
当 \(k=2\) 时,注意到当 \(S\) 中存在两个有相同质因子的数时答案为 \(0\)。令 \(P=\{p_1,p_2,\dots,p_{\pi(n)}\}\),其中 \(p_i\) 表示第 \(i\) 个素数。
我们枚举 \(S\) 中的质因子集合,式子为:
\(\sum\limits_{\substack{S\subset P\\S\neq \varnothing}}(-1)^{|S|}\prod\limits_{p\in S}(p-1)Bell_{|S|}=\sum\limits_{i=1}^{|P|}(-1)^iBell_i\sum\limits_{\substack{S\subset P\\|S|=i}}\prod\limits_{p\in S}(p-1)\)。
注意到 \(\sum\limits_{\substack{S\subset P\\|S|=i}}\prod\limits_{p\in S}(p-1)=[x^i]\prod\limits_{p\in S} ((p-1)x+1)\),对于这个式子分治 NTT 计算即可。
复杂度 \(O(\pi(n)\log^2 n)=O(\dfrac{n}{\ln n}\log^2n)=O(n\log n)\)。
代码:
#include<bits/stdc++.h>
#define LL long long
#define fr(x) freopen(#x".in","r",stdin);freopen(#x".out","w",stdout);
using namespace std;
const int N=4e6+5,mod=998244353;
int n,k,pr[N],v[N],a[N],b[N],w[N],jc[N],*f[N],mmax;
#define md(x) ((x)>=mod?(x)-mod:(x))
inline int bger(int x){return x|=x>>1,x|=x>>2,x|=x>>4,x|=x>>8,x|=x>>16,x+1;}
inline int ksm(int x,int p){int s=1;for(;p;(p&1)&&(s=1ll*s*x%mod),x=1ll*x*x%mod,p>>=1);return s;}
inline void dao(int *a,int n){for(int i=1;i<n;i++) a[i-1]=1ll*i*a[i]%mod;a[n-1]=0;}
inline void ji(int *a,int n){for(int i=n-1;i>=1;i--) a[i]=1ll*ksm(i,mod-2)*a[i-1]%mod;a[0]=0;}
inline void init(int mmax)
{
for(int i=1,j,k;i<mmax;i<<=1)
for(w[j=i]=1,k=ksm(3,(mod-1)/(i<<1)),j++;j<(i<<1);j++)
w[j]=1ll*w[j-1]*k%mod;
}
inline void DNT(int *a,int mmax)
{
for(int i,j,k=mmax>>1,L,*W,*x,*y,z;k;k>>=1)
for(L=k<<1,i=0;i<mmax;i+=L)
for(j=0,W=w+k,x=a+i,y=x+k;j<k;j++,W++,x++,y++)
*y=1ll*(*x+mod-(z=*y))* *W%mod,*x=md(*x+z);
}
inline void IDNT(int *a,int mmax)
{
for(int i,j,k=1,L,*W,*x,*y,z;k<mmax;k<<=1)
for(L=k<<1,i=0;i<mmax;i+=L)
for(j=0,W=w+k,x=a+i,y=x+k;j<k;j++,W++,x++,y++)
z=1ll* *W* *y%mod,*y=md(*x+mod-z),*x=md(*x+z);
reverse(a+1,a+mmax);
for(int inv=ksm(mmax,mod-2),i=0;i<mmax;i++) a[i]=1ll*a[i]*inv%mod;
}
inline void NTT(int *a,int *b,int n,int m)
{
mmax=bger(n+m);init(mmax);
DNT(a,mmax);DNT(b,mmax);
for(int i=0;i<mmax;i++) a[i]=1ll*a[i]*b[i]%mod;
IDNT(a,mmax);
}
void INV(int num,int *a,int *b)
{
if(num==1) return b[0]=ksm(a[0],mod-2),void();
INV((num+1)>>1,a,b);
int mmax=bger(num<<1);init(mmax);
static int c[N];
for(int i=0;i<num;i++) c[i]=a[i];for(int i=num;i<mmax;i++) c[i]=0;
DNT(c,mmax);DNT(b,mmax);
for(int i=0;i<mmax;i++) b[i]=1ll*(2-1ll*c[i]*b[i]%mod+mod)%mod*b[i]%mod;
IDNT(b,mmax);
for(int i=num;i<mmax;i++) b[i]=0;
}
inline void Ln(int *a,int n){static int b[N];for(int i=0;i<bger(n<<1);i++) b[i]=0;INV(n,a,b);dao(a,n);NTT(a,b,n,n);ji(a,n);for(int i=n;i<bger(n<<1);i++) a[i]=0;}
inline void Exp(int *a,int *b,int n)
{
if(n==1) return b[0]=1,void();
Exp(a,b,(n+1)>>1);static int c[N];for(int i=0;i<bger(n<<1);i++) c[i]=0;
for(int i=0;i<n;i++) c[i]=b[i];Ln(c,n);
for(int i=0;i<n;i++) c[i]=md(mod-c[i]+a[i]);c[0]=md(c[0]+1);
NTT(b,c,n,n);for(int i=n;i<bger(n<<1);i++) b[i]=0;
}
void sol(int l,int r,int wz)
{
f[wz]=new int[r-l+2];
if(l==r) return f[wz][0]=1,f[wz][1]=pr[l]-1,void();
int mid=(l+r)>>1;sol(l,mid,wz<<1);sol(mid+1,r,wz<<1|1);
static int A[N],B[N];for(int i=0;i<=mid-l+1;i++) A[i]=f[wz<<1][i];
for(int i=0;i<=r-mid;i++) B[i]=f[wz<<1|1][i];NTT(A,B,mid-l+1,r-mid);
for(int i=0;i<=r-l+1;i++) f[wz][i]=A[i];for(int i=0;i<mmax;i++) A[i]=B[i]=0;
}
inline void Init(int M)
{
for(int i=2;i<=M;i++)
{
if(!v[i]) pr[++pr[0]]=i;
for(int j=1;j<=pr[0]&&i*pr[j]<=M;j++)
{
v[i*pr[j]]=1;
if(i%pr[j]==0) break;
}
}
}
int main()
{
ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);cin>>n>>k;
if(!k||n<=1) return cout<<"0",0;
if(k==1) return cout<<mod-1,0;Init(n);n=pr[0];
for(int i=jc[0]=1;i<=n;i++) jc[i]=1ll*jc[i-1]*i%mod;
a[n]=ksm(jc[n],mod-2);for(int i=n-1;~i;i--) a[i]=1ll*a[i+1]*(i+1)%mod;a[0]=0;
Exp(a,b,n+1);for(int i=0;i<=n;i++) b[i]=1ll*b[i]*jc[i]%mod;sol(1,n,1);int ans=0;
for(int i=1,t;i<=n;i++) t=1ll*b[i]*f[1][i]%mod,ans=md(ans+((i&1)?mod-t:t));
return cout<<ans,0;
}