[洛谷P4491] HAOI2018 染色
问题描述
为了报答小 C 的苹果, 小 G 打算送给热爱美术的小 C 一块画布, 这块画布可 以抽象为一个长度为 \(N\) 的序列, 每个位置都可以被染成 \(M\) 种颜色中的某一种.
然而小 C 只关心序列的 \(N\) 个位置中出现次数恰好为 \(S\) 的颜色种数, 如果恰 好出现了 \(S\) 次的颜色有 \(K\) 种, 则小 C 会产生 \(W_k\) 的愉悦度.
小 C 希望知道对于所有可能的染色方案, 他能获得的愉悦度的和对 \(1004535809\) 取模的结果是多少。
输入格式
从标准输入读入数据. 第一行三个整数 \(N, M, S\).
接下来一行 \(M + 1\) 个整数, 第 \(i\) 个数表示 \(W_{i-1}\).
输出格式
输出到标准输出中. 输出一个整数表示答案.
样例输入
8 8 3
3999 8477 9694 8454 3308 8961 3018 2255 4910
样例输出
524070430
数据范围
\(N\le 10^7, M\le 10^5, S\le 150\)
解析
考虑容斥。设 \(f_i\) 表示至少有 \(i\) 种颜色出现了 \(S\) 次的方案数。实际上钦定出现了 \(S\) 次的颜色后就是一个可重排列,没有被钦定的颜色任意选择。即:
\[f_i={m\choose i} \frac{n!}{(S!)^i(n-iS)!} (n-iS)^{m-i}
\]
设 \(g_i\) 表示恰好有 \(i\) 种颜色出现了 \(S\) 次的方案数。不难得到:
\[\begin{aligned}g_i&=\sum_{j=i}^m (-1)^{j-i}{j\choose i} f_j\\ &=\sum_{j=i}^m \frac{(-1)^{j-i}}{(j-i)!}\times j!f_j\end{aligned}
\]
差卷积一下即可。
代码
#include <iostream>
#include <cstdio>
#define N 10000002
#define M 500002
#define int long long
using namespace std;
const int mod=1004535809;
const int G=3;
int n,m,n1=1,m1,lim,s,i,w[M],f[M],g[M],r[M],fac[N],inv[N];
int read()
{
char c=getchar();
int w=0;
while(c<'0'||c>'9') c=getchar();
while(c<='9'&&c>='0'){
w=w*10+c-'0';
c=getchar();
}
return w;
}
int poww(int a,int b)
{
int ans=1,base=a;
while(b){
if(b&1) ans=ans*base%mod;
base=base*base%mod;
b>>=1;
}
return ans;
}
int C(int n,int m)
{
return fac[n]*inv[m]%mod*inv[n-m]%mod;
}
void NTT(int *a,int n,int inv)
{
for(int i=0;i<n;i++){
if(i<r[i]) swap(a[i],a[r[i]]);
}
for(int l=2;l<=n;l<<=1){
int mid=l/2;
int cur=poww(G,(mod-1)/l);
if(inv==-1) cur=poww(cur,mod-2);
for(int i=0;i<n;i+=l){
int omg=1;
for(int j=0;j<mid;j++,omg=omg*cur%mod){
int tmp=omg*a[i+j+mid]%mod;
a[i+j+mid]=(a[i+j]-tmp+mod)%mod;
a[i+j]=(a[i+j]+tmp)%mod;
}
}
}
if(inv==-1){
for(int i=0;i<n;i++) a[i]=a[i]*poww(n,mod-2)%mod;
}
}
signed main()
{
n=read();m=read();s=read();
for(i=0;i<=m;i++) w[i]=read();
for(i=fac[0]=1;i<=max(n,m);i++) fac[i]=fac[i-1]*i%mod;
inv[max(n,m)]=poww(fac[max(n,m)],mod-2);
for(i=max(n,m)-1;i>=0;i--) inv[i]=inv[i+1]*(i+1)%mod;
m1=min(m,n/s);
for(i=0;i<=m1;i++) f[i]=fac[i]*C(m,i)%mod*fac[n]%mod*poww(inv[s],i)%mod*inv[n-s*i]%mod*poww(m-i,n-i*s)%mod;
for(i=0;i<=m1;i++){
g[i]=inv[m1-i];
if((m1-i)%2!=0) g[i]=mod-g[i];
}
while(n1<=2*m1) n1<<=1,lim++;
for(i=0;i<n1;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(lim-1));
NTT(f,n1,1);NTT(g,n1,1);
for(i=0;i<n1;i++) f[i]=f[i]*g[i]%mod;
NTT(f,n1,-1);
int ans=0;
for(i=0;i<=m1;i++) ans=(ans+w[i]*f[m1+i]%mod*inv[i]%mod)%mod;
printf("%lld\n",ans);
return 0;
}