P4491 [HAOI2018]染色
题目描述
为了报答小 C 的苹果, 小 G 打算送给热爱美术的小 C 一块画布, 这块画布可 以抽象为一个长度为 \(N\) 的序列, 每个位置都可以被染成 \(M\) 种颜色中的某一种.
然而小 C 只关心序列的 \(N\) 个位置中出现次数恰好为 \(S\) 的颜色种数, 如果恰好出现了 \(S\) 次的颜色有 \(K\) 种, 则小 C 会产生 \(W_k\) 的愉悦度.
小 C 希望知道对于所有可能的染色方案, 他能获得的愉悦度的和对 \(1004535809\) 取模的结果是多少.
数据范围:\(n\leq 10^7,m\leq 10^5,S\leq 150,0\leq w_i < 1004535809\)
solution
容斥+二项式反演+\(NTT\)。
设 \(g(k)\) 表示恰好出现了 \(S\) 次的颜色恰好有 \(k\) 种的方案数。
那么答案显然为:\(\displaystyle\sum_{i=0}^{m} g(k)\times w_k\) 。
考虑怎么求出 \(g(k)\) 来,因为恰好不太好求,所以可以用二项式反演转化一下。
设 \(f(k)\) 表示恰好出现了 \(S\) 次的颜色至少有 \(k\) 种的方案数。
显然: \(\displaystyle f(k) = {m\choose k} \times {n\choose ks} \times {ks!\over (s!)^{k}}\times {(m-k)^{n-ks}}\)
解释一下上面的柿子是怎么来的。
首先我们先从 \(m\) 个颜色里面选出 \(k\) 个,让他恰好出现 \(m\) 次,方案数显然为 \(\displaystyle {m\choose k}\), 之后我们还要在 \(n\) 个位置里面选出 \(ks\) 个位置,使他们的颜色为这 \(k\) 中颜色的一种,方案数为 \(\displaystyle {n\choose ks}\), 又因为每个颜色之间的顺序不同,所以还要在乘上一个 \(\displaystyle {ks!\over ({s!})^i}\) ,剩下的 \(m-ks\) 的位置中可以涂 \(m-k\) 种颜色的任意一种,方案数为 \(\displaystyle {(m-k)^{n-ks}}\)。
显然有:\(\displaystyle {f(k) = \sum_{i=k}^{min(n,m/s)}} {i\choose k} \times g(i)\)
根据二项式反演可得(以下默认 \(upper = {min(n,m/s)}\)):
\(\displaystyle g(k) = \sum_{i=k}^{upper} (-1)^{i-k}\times {i\choose k} \times f(i)\)
把组合数拆开尝试构造一下卷积:
\(\displaystyle {g(k) = \sum_{i=k}^{upper} (-1)^{i-k} \times {i!\over {k!\times (i-k!)}}\times f(i)}\)
把 \(k!\) 移到左边去可得:
\(\displaystyle g(k)\times k! = \sum_{i=k}^{upper} {(-1)^{i-k}\over (i-k)!} \times i!f(i)\)
我们构造多项式,\(A(x),B(x),C(x)\) 其中:
\(\displaystyle A(x) = \sum_{i=0}^{\infin} g(i)i!\times x^i\)
\(\displaystyle B(x) = \sum_{i=0}^{\infin} i!f(i) x^i\)
\(\displaystyle C(x) = \sum_{i=0}^{\infin} {{(-1)}^{i}\over i!} x^i\)
我们会发现这其实是一个差卷积的形式,我们可以通过构造一下使他变为正常的加法卷积的形式。
具体来说:首先把 \(B(x)\) 的每一项系数反转得到 \(B'(x)\) ,设 \(A'(x) = B'(x) * C(x)\), 那么 \(A'(x)\) 第 \(i\) 项的系数其实就是 \(A(x)\) 第 \(n-i\) 项的系数。
简单证明一下,原来的时候是 \(B[i] \times C[i-k] = A[k]\), 经过反转后变为:\(B[n-i]\times C[i-k] = A'[k’]\) ,因为我们后面的是加法卷积的形式,所以 \(k' = n-i+i-k = n-k\) 。因此 \(A'(x)\) 的第 \(i\) 项的系数其实就是 \(A(x)\) 第 \(n-i\) 项的系数。
我们用 \(NTT\) 求出来这个 \(A'(x)\) 之后,把系数反转一下就可以得到 \(A(x)\) 。
预处理出 \(f(i)\) 和 \(g(i)\) 就可以直接求了。
复杂度:\(O(nlogn)\)
注意:卷积数组要开大点,开 \(1e7\) 差不多就可以过了。
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
#define int long long
const int p = 1004535809;
const int N = 1e7+10;
int n,m,s,len,ans,w[N],rev[N],jz[N],inv[N],a[N],b[N];
inline int read()
{
int s = 0,w = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
return s * w;
}
int ksm(int a,int b)
{
int res = 1;
for(; b; b >>= 1)
{
if(b & 1) res = res * a % p;
a = a * a % p;
}
return res;
}
int C(int n,int m)
{
return jz[n] * inv[m] % p * inv[n-m] % p;
}
void NTT(int *a,int lim,int opt)
{
for(int i = 0; i < lim; i++)
{
if(i < rev[i]) swap(a[i],a[rev[i]]);
}
for(int h = 1; h < lim; h <<= 1)
{
int wn = ksm(3,(p-1)/(h<<1));
if(opt == -1) wn = ksm(wn,p-2);
for(int j = 0; j < lim; j += (h<<1))
{
int w = 1;
for(int k = 0; k < h; k++)
{
int u = a[j + k];
int v = w * a[j + h + k] % p;
a[j + k] = (u + v) % p;
a[j + h + k] = (u - v + p) % p;
w = w * wn % p;
}
}
}
if(opt == -1)
{
int inv = ksm(lim,p-2);
for(int i = 0; i < lim; i++) a[i] = a[i] * inv % p;
}
}
signed main()
{
n = read(); m = read(); s = read(); len = min(m,n/s) + 1;
jz[0] = inv[0] = 1;
for(int i = 0; i <= m; i++) w[i] = read();
for(int i = 1; i <= N-5; i++) jz[i] = jz[i-1] * i % p;
inv[N-5] = ksm(jz[N-5],p-2);
for(int i = N-6; i >= 1; i--) inv[i] = inv[i+1] * (i+1) % p;
for(int i = 0; i < len; i++)
{
int tmp = ksm(ksm(jz[s],i),p-2);
a[i] = C(m,i) * C(n,i*s) % p * jz[i*s] % p * tmp % p * ksm(m-i,n-i*s) % p;
a[i] = a[i] * jz[i];
}
for(int i = 0, tmp = 1; i < len; i++, tmp *= -1) b[i] = (tmp * inv[i] + p) % p;
reverse(a,a+len);
int lim = 1, tim = 0;
while(lim < (len<<1)) tim++, lim <<= 1;
for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
NTT(a,lim,1); NTT(b,lim,1);
for(int i = 0; i < lim; i++) a[i] = a[i] * b[i] % p;
NTT(a,lim,-1);
reverse(a,a+len);
for(int i = 0; i < len; i++) a[i] = a[i] * inv[i] % p;
for(int i = 0; i < len; i++) ans = (ans + w[i] * a[i] % p) % p;
printf("%lld\n",ans);
return 0;
}