[HAOI2018]染色
题目描述
为了报答小 C 的苹果, 小 G 打算送给热爱美术的小 C 一块画布, 这块画布可 以抽象为一个长度为 N 的序列, 每个位置都可以被染成 M 种颜色中的某一种.
然而小 C 只关心序列的 N 个位置中出现次数恰好为 S 的颜色种数, 如果恰 好出现了 S 次的颜色有 K 种, 则小 C 会产生 Wk 的愉悦度.
小 C 希望知道对于所有可能的染色方案, 他能获得的愉悦度的和对 1004535809取模的结果是多少。
题解
碰到这种等于什么什么的题要考虑容斥。
既然是容斥,那么先设cnt[i]表示至少有i种出现了s次的颜色的方案数。
这个可以直接算,先C(m,i)选颜色,然后每种颜色的s种看做一个整体,其余的n-i*s看做一个整体,算一下排列n!/((s!)i*(n-s*i)!)先求一下排列数。
然后剩下的n-s*i个位置可以随便填m-i种颜色,乘上(m-i)n-s*i就可以了。
然后继续容斥。
ans[k]=∑(-1)i-kC(i,k)*cnt[i]
然后我们把组合数拆开,用FFT优化就好了。
不过我们发现它的卷积是个反着的,把其中一个数组reverse一下就好了。
代码
#include<iostream> #include<cstdio> #define N 10000009 #define M 100009 using namespace std; typedef long long ll; const int mod=1004535809; const int Gi=334845270; const int G=3; ll n,m,jie[N],ni[N],a[M<<2],val[M<<2],w[M],s,b[M<<2],l,L,rev[M<<2],ans; inline int rd(){ ll x=0;char c=getchar();bool f=0; while(!isdigit(c)){if(c=='-')f=1;c=getchar();} while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();} return f?-x:x; } inline ll power(ll x,ll y){ ll ans=1; while(y){if(y&1)ans=ans*x%mod;x=x*x%mod;y>>=1;} return ans; } inline ll ny(ll x){return power(x,mod-2);} inline ll C(int n,int m){return jie[n]*ni[m]%mod*ni[n-m]%mod;} inline void NTT(ll *a,int tag){ for(int i=1;i<l;++i)if(i>rev[i])swap(a[i],a[rev[i]]); for(int i=1;i<l;i<<=1){ ll wn=power(tag==1?G:Gi,(mod-1)/(i<<1)); for(int j=0;j<l;j+=(i<<1)){ ll w=1; for(int k=0;k<i;++k,w=w*wn%mod){ ll x=a[k+j],y=a[i+j+k]*w%mod; a[k+j]=(x+y)%mod;a[i+j+k]=(x-y+mod)%mod; } } } } int main(){ n=rd();m=rd();s=rd();int num=max(n,m); for(int i=0;i<=m;++i)w[i]=rd(); jie[0]=1; for(int i=1;i<=num;++i)jie[i]=jie[i-1]*i%mod;ni[num]=power(jie[num],mod-2); for(int i=num-1;i>=0;--i)ni[i]=ni[i+1]*(i+1)%mod; for(int i=0;i<=m;++i)if(n-i*s>=0){ val[i]=C(m,i)*jie[n]%mod*power(ni[s],i)%mod*ni[n-s*i]%mod*power(m-i,n-s*i)%mod; } for(int i=0;i<=m;++i){ a[m-i]=jie[i]*val[i]; b[i]=ni[i];if(i&1)b[i]=(-b[i]+mod)%mod; } l=1;L=0; while(l<=(m<<1))l<<=1,L++; for(int i=1;i<l;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1)); NTT(a,1);NTT(b,1); for(int i=0;i<l;++i)a[i]=a[i]*b[i]%mod; NTT(a,-1);int nn=power(l,mod-2); for(int i=0;i<l;++i)a[i]=a[i]*nn%mod; for(int i=0;i<=m;++i)(ans+=a[m-i]*w[i]%mod*ni[i]%mod)%=mod; cout<<ans; return 0; }