[NTT][组合]P4491 [HAOI2018]染色

题面

https://www.luogu.com.cn/problem/P4491

有 n 个空格,每个空格都可以填 m 种颜色,当恰好有 k 种颜色出现了 S 次时,权值为 w[k] ,求所有方案的权值和

分析

推导起来特别爽的题

首先 k 的最大值 $K=min(\left \lfloor \frac{n}{S} \right \rfloor,m)$

设当前至少有 i 种颜色出现了 S 次的方案数为 $a[i]$

则选择 i 种颜色是 $C(m,i)$

选出放这 i 种颜色的位置是 $C(n,iS)$

i 种颜色任意排列,除去同色内部排列是 $\frac{(iS)!}{(S!)^i}$

最后剩下位置任意选非该 i 种颜色是 $(m-i)^{n-iS}$

则 $a[i]=C(m,i)C(n,iS)\frac{(iS)!}{(S!)^i}(m-i)^{n-iS}$

考虑容斥,则有 $ans=\sum_{i=0}^{K} w[i] \sum_{j=i}^{K} (-1)^{j-i}C(j,i) a[j]$

其中 $C(j,i)$ 表示至少有 i 种颜色在至少有 j 种里面被计算了这么多次(二项式定理)

拆组合数

$ans=\sum_{i=0}^{K} w[i] \sum_{j=i}^{K} (-1)^{j-i}\frac{j!}{i!(j-i)!} a[j]$

将无关项移项,相近项合并

$ans=\sum_{i=0}^{K} w[i] i! \sum_{j=i}^{K} (-1)^{j-i}\frac{1}{(j-i)!}\times a[j]\times j!$

然后设 $F(x)=\sum_{i=0}^{K} (-1)^i \frac{1}{i!} x^i$ ,$G(x)=\sum_{i=0}^{K} a[i] i! x^i$

原 ans 后式为 $\sum_{j=i}^{K} F[j-i]\times G[j]$,仍不是卷积形式

将 $G(x)$ 系数高次项与低次项反转,则

$\sum_{j=i}^{K} F[j-i]\times G[K-j]$ 为卷积形式

最后计算的时候取卷积后的函数第 K-i 位系数作为结果,即

$ans=\sum_{i=0}^{K} w[i] i! \frac{H[K-i]}{x^{K-i}}$

代码

#include <iostream>
#include <cstdio>
using namespace std;
typedef long long ll;
const ll P=1004535809ll;
const int N=131072;
const int M=1e7+10;
int n,m,K,S,bit;
ll fact[M],inv[M],f[N<<2],g[N<<2],w[N],ref[N<<2],ans;

ll Pow(ll x,ll y) {ll ans=1;for (;y;y>>=1,x=x*x%P) if (y&1) ans=ans*x%P;return ans;}

ll C(int n,int m ) {return fact[n]*inv[m]%P*inv[n-m]%P;}

void Get_Ref(int bit,int mxb) {for (int i=1;i<mxb;i++) ref[i]=(ref[i>>1]>>1)|((i&1)<<bit-1);}

void NTT(ll *a,int n,int idft) {
    for (int i=0;i<n;i++) if (ref[i]>i) swap(a[ref[i]],a[i]);
    for (int mlen=1;mlen<n;mlen<<=1) {
        ll g1=Pow(3,(P-1)/(mlen<<1));
        if (idft<0) g1=Pow(g1,P-2);
        for (int l=0,len=mlen<<1;l<n;l+=len) {
            ll gk=1;
            for (int i=l;i<l+mlen;i++) {
                ll x=a[i],y=(a[i+mlen]*gk)%P;
                a[i]=(x+y)%P;a[i+mlen]=(x-y+P)%P;
                (gk*=g1)%=P;
            }
        }
    }
    if (idft<0) {
        ll inv=Pow(n,P-2);
        for (int i=0;i<n;i++) (a[i]*=inv)%=P;
    }
}

int main() {
    scanf("%d%d%d",&n,&m,&S);K=min(n/S,m);
    for (int i=0;i<=m;i++) scanf("%lld",&w[i]);
    fact[0]=1;for (int i=1;i<=max(n,m);i++) fact[i]=fact[i-1]*i%P;
    inv[max(n,m)]=Pow(fact[max(n,m)],P-2);for (int i=max(n,m)-1;~i;i--) inv[i]=inv[i+1]*(i+1)%P;
    for (int i=0;i<=K;i++) f[i]=(i&1?P-1:1)*inv[i]%P,g[K-i]=C(m,i)*C(n,i*S)%P*fact[i*S]%P*Pow(inv[S],i)%P*Pow(m-i,n-i*S)%P*fact[i]%P;
    for (int i=K;i;i>>=1,bit++);bit++;
    Get_Ref(bit,1<<bit);
    NTT(f,1<<bit,1);NTT(g,1<<bit,1);
    for (int i=0;i<(1<<bit);i++) (f[i]*=g[i])%=P;
    NTT(f,1<<bit,-1);
    for (int i=0;i<=K-i;i++) swap(f[i],f[K-i]);
    for (int i=0;i<=K;i++) (ans+=w[i]*f[i]%P*inv[i]%P)%=P;
    printf("%lld\n",ans);
}
View Code

 

posted @ 2021-03-30 20:54  Vagari  阅读(64)  评论(0编辑  收藏  举报