[NTT][组合]P4491 [HAOI2018]染色
题面
https://www.luogu.com.cn/problem/P4491
有 n 个空格,每个空格都可以填 m 种颜色,当恰好有 k 种颜色出现了 S 次时,权值为 w[k] ,求所有方案的权值和
分析
推导起来特别爽的题
首先 k 的最大值 $K=min(\left \lfloor \frac{n}{S} \right \rfloor,m)$
设当前至少有 i 种颜色出现了 S 次的方案数为 $a[i]$
则选择 i 种颜色是 $C(m,i)$
选出放这 i 种颜色的位置是 $C(n,iS)$
i 种颜色任意排列,除去同色内部排列是 $\frac{(iS)!}{(S!)^i}$
最后剩下位置任意选非该 i 种颜色是 $(m-i)^{n-iS}$
则 $a[i]=C(m,i)C(n,iS)\frac{(iS)!}{(S!)^i}(m-i)^{n-iS}$
考虑容斥,则有 $ans=\sum_{i=0}^{K} w[i] \sum_{j=i}^{K} (-1)^{j-i}C(j,i) a[j]$
其中 $C(j,i)$ 表示至少有 i 种颜色在至少有 j 种里面被计算了这么多次(二项式定理)
拆组合数
$ans=\sum_{i=0}^{K} w[i] \sum_{j=i}^{K} (-1)^{j-i}\frac{j!}{i!(j-i)!} a[j]$
将无关项移项,相近项合并
$ans=\sum_{i=0}^{K} w[i] i! \sum_{j=i}^{K} (-1)^{j-i}\frac{1}{(j-i)!}\times a[j]\times j!$
然后设 $F(x)=\sum_{i=0}^{K} (-1)^i \frac{1}{i!} x^i$ ,$G(x)=\sum_{i=0}^{K} a[i] i! x^i$
原 ans 后式为 $\sum_{j=i}^{K} F[j-i]\times G[j]$,仍不是卷积形式
将 $G(x)$ 系数高次项与低次项反转,则
$\sum_{j=i}^{K} F[j-i]\times G[K-j]$ 为卷积形式
最后计算的时候取卷积后的函数第 K-i 位系数作为结果,即
$ans=\sum_{i=0}^{K} w[i] i! \frac{H[K-i]}{x^{K-i}}$
代码
#include <iostream> #include <cstdio> using namespace std; typedef long long ll; const ll P=1004535809ll; const int N=131072; const int M=1e7+10; int n,m,K,S,bit; ll fact[M],inv[M],f[N<<2],g[N<<2],w[N],ref[N<<2],ans; ll Pow(ll x,ll y) {ll ans=1;for (;y;y>>=1,x=x*x%P) if (y&1) ans=ans*x%P;return ans;} ll C(int n,int m ) {return fact[n]*inv[m]%P*inv[n-m]%P;} void Get_Ref(int bit,int mxb) {for (int i=1;i<mxb;i++) ref[i]=(ref[i>>1]>>1)|((i&1)<<bit-1);} void NTT(ll *a,int n,int idft) { for (int i=0;i<n;i++) if (ref[i]>i) swap(a[ref[i]],a[i]); for (int mlen=1;mlen<n;mlen<<=1) { ll g1=Pow(3,(P-1)/(mlen<<1)); if (idft<0) g1=Pow(g1,P-2); for (int l=0,len=mlen<<1;l<n;l+=len) { ll gk=1; for (int i=l;i<l+mlen;i++) { ll x=a[i],y=(a[i+mlen]*gk)%P; a[i]=(x+y)%P;a[i+mlen]=(x-y+P)%P; (gk*=g1)%=P; } } } if (idft<0) { ll inv=Pow(n,P-2); for (int i=0;i<n;i++) (a[i]*=inv)%=P; } } int main() { scanf("%d%d%d",&n,&m,&S);K=min(n/S,m); for (int i=0;i<=m;i++) scanf("%lld",&w[i]); fact[0]=1;for (int i=1;i<=max(n,m);i++) fact[i]=fact[i-1]*i%P; inv[max(n,m)]=Pow(fact[max(n,m)],P-2);for (int i=max(n,m)-1;~i;i--) inv[i]=inv[i+1]*(i+1)%P; for (int i=0;i<=K;i++) f[i]=(i&1?P-1:1)*inv[i]%P,g[K-i]=C(m,i)*C(n,i*S)%P*fact[i*S]%P*Pow(inv[S],i)%P*Pow(m-i,n-i*S)%P*fact[i]%P; for (int i=K;i;i>>=1,bit++);bit++; Get_Ref(bit,1<<bit); NTT(f,1<<bit,1);NTT(g,1<<bit,1); for (int i=0;i<(1<<bit);i++) (f[i]*=g[i])%=P; NTT(f,1<<bit,-1); for (int i=0;i<=K-i;i++) swap(f[i],f[K-i]); for (int i=0;i<=K;i++) (ans+=w[i]*f[i]%P*inv[i]%P)%=P; printf("%lld\n",ans); }