题解-CTS2019 珍珠
题面
有 \(n\) 个在 \([1,d]\) 内的整数,求使可以拿出 \(2m\) 个整数凑成 \(m\) 个相等的整数对的方案数。
数据范围:\(0\le m\le 10^9\),\(1\le n\le 10^9\),\(1\le d\le 10^5\)。
蒟蒻语
非常巧妙的题,主要要用到二项式反演、指数级生成函数和 NTT
。
做个广告,这是我读过最好的生成函数讲解:link
。
蒟蒻解
设 \(c_i\) 表示 \(i\) 这个数的出现次数。
设 \(odd=\sum [c_i\in {\rm odd} ]\),即 \(c_i\) 奇数个数。
很明显最多能凑成 \(\frac{n-odd}{2}\) 对,按题意:
\[\begin{aligned}
\frac{n-odd}{2}&\ge m\\
odd&\le n-2m
\end{aligned}
\]
这里有两个特判,如果 \(n-2m<0\) 答案是 \(0\),如果 \(n-2m\ge d\) 答案是 \(d^n\)。
设 \(g(i)\) 表示 \(odd=i\) 的方案数。
设 \(f(i)\) 表示钦定 \(i\) 个 \(c_i\) 是奇数,剩下随意的方案数(不是 \(odd\ge i\) 的方案数,这里对一些排列会重复统计,但是反演完就没事了)。
\[f(i)=\sum_{x=i}^d{x\choose i}g(x)\Longleftrightarrow g(i)=\sum_{x=i}^d(-1)^{x-i}{x\choose i} f(x)
\]
所以可以先求 \(f(i)\),用到了指数级生成函数,中间把每个 \(e\) 的幂次项展开,最后归成卷积形式:
\[\begin{aligned}
f(i)=&{d \choose i} \left(\frac{e^x-e^{-x}}{2}\right)^i e^{(d-i)x} n![n]\\
=&{d \choose i} (e^x-e^{-x})^i e^{(d-i)x} \frac{n!}{2^i}[n]\\
=&{d \choose i} e^{(d-i)x} \sum_{j=0}^i(-1)^{i-j} {i\choose j}e^{(j-(i-j))x} \frac{n!}{2^i}[n]\\
=&{d \choose i} e^{(d-i)x} \sum_{j=0}^i(-1)^{i-j} {i\choose j}e^{(2j-i)x} \frac{n!}{2^i}[n]\\
=&{d \choose i} \frac{n!}{2^i}\sum_{j=0}^i(-1)^{j} {i\choose j}e^{(d-2j)x} [n]\\
=&{d \choose i} \frac{n!}{2^i}\sum_{j=0}^i(-1)^{j} {i\choose j} \frac{(d-2j)^n}{n!}\\
=&\frac{d!}{i!(d-i)!2^i}\sum_{j=0}^i(-1)^{j} \frac{i!}{j!(i-j)!} (d-2j)^n\\
=&\frac{d!}{(d-i)!2^i}\sum_{j=0}^i\frac{(-1)^{j}(d-2j)^n}{j!}\cdot \frac{1}{(i-j)!} \\
\end{aligned}
\]
最后一个难点是如何求:
\[g(i)=\sum_{x=i}^d(-1)^{x-i}{x\choose i} f(x)
\]
感觉可以凑成卷积形式,但总差一点。尝试把 \(f\) 和 \(g\) 都反过来,即令 \(f'(x)=f(d-x)\),\(g'(x)=g(d-x)\):
\[\begin{aligned}
g(i)=&\sum_{x=i}^d(-1)^{x-i}{x\choose i} f(x)\\
=&\sum_{x=i}^d(-1)^{x-i}{x\choose i} f'(d-x)\\
=&\sum_{x=0}^{d-i}(-1)^{d-x-i}{d-x\choose i} f'(x)\\
\end{aligned}\\
g'(d-i)=\sum_{x=0}^{d-i}(-1)^{d-x-i}{d-x\choose i} f'(x)\\
\begin{aligned}
g'(i)=&\sum_{x=0}^{i}(-1)^{i-x}{d-x\choose d-i} f'(x)\\
=&\sum_{x=0}^{i}(-1)^{i-x}\frac{(d-x)!}{(d-i)!(i-x)!} f'(x)\\
=&\frac{1}{(d-i)!}\sum_{x=0}^{i} (-1)^{i-x}\frac{1}{(i-x)!}\cdot f'(x)(d-x)!\\
\end{aligned}\\
\]
然后就可以求 \(g(i)\) 了。
答案便是 \(\sum_{i=0}^{n-2m} g(i)\)。
代码
#include <bits/stdc++.h>
using namespace std;
//Start
typedef long long ll;
typedef double db;
#define mp(a,b) make_pair((a),(b))
#define x first
#define y second
#define bg begin()
#define ed end()
#define sz(a) int((a).size())
#define pb(a) push_back(a)
#define R(i,a,b) for(int i=(a),i##E=(b);i<i##E;i++)
#define L(i,a,b) for(int i=(b)-1,i##E=(a)-1;i>i##E;i--)
const int iinf=0x3f3f3f3f;
const ll linf=0x3f3f3f3f3f3f3f3f;
//Data
const int N=1e5;
int n,m,d,ans;
//Math
const int mod=998244353;
int Pow(int a,int x){
int res=1; for(;x;x>>=1,a=1ll*a*a%mod)
if(x&1) res=1ll*res*a%mod; return res;
}
const int mN=N+1;
int fac[mN],ifac[mN];
void math_init(){
fac[0]=1;R(i,1,d+1) fac[i]=1ll*fac[i-1]*i%mod;
ifac[d]=Pow(fac[d],mod-2);
L(i,0,d) ifac[i]=1ll*ifac[i+1]*(i+1)%mod;
}
//Poly
const int pN=mN<<2;
int f[pN],g[pN];
const int G=3,iG=Pow(3,mod-2);
int rev[pN],pn;
void poly_init(){
pn=1<<int(ceil(log2(d*2+2)));
R(i,0,pn) rev[i]=(rev[i>>1]>>1)|((i&1)*(pn>>1));
}
void NTT(int* a,int t){
R(i,0,pn)if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int mid=1;mid<pn;mid<<=1){
int wn=Pow(~t?G:iG,(mod-1)/(mid<<1));
for(int i=0;i<pn;i+=(mid<<1)){
int w=1;
R(j,i,mid+i){
int x=a[j],y=1ll*a[mid+j]*w%mod;
a[j]=(x+y)%mod,a[mid+j]=(x-y+mod)%mod,w=1ll*w*wn%mod;
}
}
}
if(!~t){
int in=Pow(pn,mod-2);
R(i,0,pn) a[i]=1ll*a[i]*in%mod;
}
}
//Main
int main(){
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>d>>n>>m;
if(n-2*m<0) return cout<<0<<'\n',0;
if(d<=n-2*m) return cout<<Pow(d,n)<<'\n',0;
math_init(),poly_init();
R(i,0,d+1){
f[i]=1ll*Pow((d-2*i+mod)%mod,n)*ifac[i]%mod;
if(i&1) f[i]=(mod-f[i])%mod;
g[i]=ifac[i];
}
R(i,d+1,pn) f[i]=g[i]=0;
NTT(f,1),NTT(g,1);
R(i,0,pn) f[i]=1ll*f[i]*g[i]%mod;
NTT(f,-1);
R(i,d+1,pn) f[i]=g[i]=0;
R(i,0,d+1) f[i]=1ll*f[i]*fac[d]%mod*ifac[d-i]%mod*Pow(2,mod-1-i)%mod;
reverse(f,f+d+1);
R(i,0,d+1) f[i]=1ll*f[i]*fac[d-i]%mod;
R(i,0,d+1){
g[i]=ifac[i];
if(i&1) g[i]=(mod-g[i])%mod;
}
R(i,d+1,pn) f[i]=g[i]=0;
NTT(f,1),NTT(g,1);
R(i,0,pn) g[i]=1ll*f[i]*g[i]%mod;
NTT(g,-1);
R(i,d+1,pn) f[i]=g[i]=0;
R(i,0,d+1) g[i]=1ll*g[i]*ifac[d-i]%mod;
reverse(g,g+d+1);
R(i,0,n-2*m+1) (ans+=g[i])%=mod;
cout<<ans<<'\n';
return 0;
}
祝大家学习愉快!