P4491 [HAOI2018]染色

题目描述

洛谷

为了报答小 C 的苹果, 小 G 打算送给热爱美术的小 C 一块画布, 这块画布可 以抽象为一个长度为 \(N\) 的序列, 每个位置都可以被染成 \(M\) 种颜色中的某一种.

然而小 C 只关心序列的 \(N\) 个位置中出现次数恰好为 \(S\) 的颜色种数, 如果恰好出现了 \(S\) 次的颜色有 \(K\) 种, 则小 C 会产生 \(W_k\) 的愉悦度.

小 C 希望知道对于所有可能的染色方案, 他能获得的愉悦度的和对 \(1004535809\) 取模的结果是多少.

数据范围:\(n\leq 10^7,m\leq 10^5,S\leq 150,0\leq w_i < 1004535809\)

solution

容斥+二项式反演+\(NTT\)

\(g(k)\) 表示恰好出现了 \(S\) 次的颜色恰好有 \(k\) 种的方案数。

那么答案显然为:\(\displaystyle\sum_{i=0}^{m} g(k)\times w_k\)

考虑怎么求出 \(g(k)\) 来,因为恰好不太好求,所以可以用二项式反演转化一下。

\(f(k)\) 表示恰好出现了 \(S\) 次的颜色至少有 \(k\) 种的方案数。

显然: \(\displaystyle f(k) = {m\choose k} \times {n\choose ks} \times {ks!\over (s!)^{k}}\times {(m-k)^{n-ks}}\)

解释一下上面的柿子是怎么来的。

首先我们先从 \(m\) 个颜色里面选出 \(k\) 个,让他恰好出现 \(m\) 次,方案数显然为 \(\displaystyle {m\choose k}\), 之后我们还要在 \(n\) 个位置里面选出 \(ks\) 个位置,使他们的颜色为这 \(k\) 中颜色的一种,方案数为 \(\displaystyle {n\choose ks}\), 又因为每个颜色之间的顺序不同,所以还要在乘上一个 \(\displaystyle {ks!\over ({s!})^i}\) ,剩下的 \(m-ks\) 的位置中可以涂 \(m-k\) 种颜色的任意一种,方案数为 \(\displaystyle {(m-k)^{n-ks}}\)

显然有:\(\displaystyle {f(k) = \sum_{i=k}^{min(n,m/s)}} {i\choose k} \times g(i)\)

根据二项式反演可得(以下默认 \(upper = {min(n,m/s)}\)):

\(\displaystyle g(k) = \sum_{i=k}^{upper} (-1)^{i-k}\times {i\choose k} \times f(i)\)

把组合数拆开尝试构造一下卷积:

\(\displaystyle {g(k) = \sum_{i=k}^{upper} (-1)^{i-k} \times {i!\over {k!\times (i-k!)}}\times f(i)}\)

\(k!\) 移到左边去可得:

\(\displaystyle g(k)\times k! = \sum_{i=k}^{upper} {(-1)^{i-k}\over (i-k)!} \times i!f(i)\)

我们构造多项式,\(A(x),B(x),C(x)\) 其中:

\(\displaystyle A(x) = \sum_{i=0}^{\infin} g(i)i!\times x^i\)

\(\displaystyle B(x) = \sum_{i=0}^{\infin} i!f(i) x^i\)

\(\displaystyle C(x) = \sum_{i=0}^{\infin} {{(-1)}^{i}\over i!} x^i\)

我们会发现这其实是一个差卷积的形式,我们可以通过构造一下使他变为正常的加法卷积的形式。

具体来说:首先把 \(B(x)\) 的每一项系数反转得到 \(B'(x)\) ,设 \(A'(x) = B'(x) * C(x)\), 那么 \(A'(x)\)\(i\) 项的系数其实就是 \(A(x)\)\(n-i\) 项的系数。

简单证明一下,原来的时候是 \(B[i] \times C[i-k] = A[k]\), 经过反转后变为:\(B[n-i]\times C[i-k] = A'[k’]\) ,因为我们后面的是加法卷积的形式,所以 \(k' = n-i+i-k = n-k\) 。因此 \(A'(x)\) 的第 \(i\) 项的系数其实就是 \(A(x)\)\(n-i\) 项的系数。

我们用 \(NTT\) 求出来这个 \(A'(x)\) 之后,把系数反转一下就可以得到 \(A(x)\)

预处理出 \(f(i)\)\(g(i)\) 就可以直接求了。

复杂度:\(O(nlogn)\)

注意:卷积数组要开大点,开 \(1e7\) 差不多就可以过了。

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
#define int long long
const int p = 1004535809;
const int N = 1e7+10;
int n,m,s,len,ans,w[N],rev[N],jz[N],inv[N],a[N],b[N];
inline int read()
{
    int s = 0,w = 1; char ch = getchar();
    while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
    return s * w;
}
int ksm(int a,int b)
{
    int res = 1;
    for(; b; b >>= 1)
    {
        if(b & 1) res = res * a % p;
        a = a * a % p;
    }
    return res;
}
int C(int n,int m)
{
    return jz[n] * inv[m] % p * inv[n-m] % p;
}
void NTT(int *a,int lim,int opt)
{
    for(int i = 0; i < lim; i++)
    {
        if(i < rev[i]) swap(a[i],a[rev[i]]);
    }
    for(int h = 1; h < lim; h <<= 1)
    {
        int wn = ksm(3,(p-1)/(h<<1));
        if(opt == -1) wn = ksm(wn,p-2);
        for(int j = 0; j < lim; j += (h<<1))
        {
            int w = 1;
            for(int k = 0; k < h; k++)
            {
                int u = a[j + k];
                int v = w * a[j + h + k] % p;
                a[j + k] = (u + v) % p;
                a[j + h + k] = (u - v + p) % p;
                w = w * wn % p;
            }
        }
    }
    if(opt == -1)
    {
        int inv = ksm(lim,p-2);
        for(int i = 0; i < lim; i++) a[i] = a[i] * inv % p;
    }
}
signed main()
{
    n = read(); m = read(); s = read(); len = min(m,n/s) + 1;
    jz[0] = inv[0] = 1; 
    for(int i = 0; i <= m; i++) w[i] = read();
    for(int i = 1; i <= N-5; i++) jz[i] = jz[i-1] * i % p;
    inv[N-5] = ksm(jz[N-5],p-2);
    for(int i = N-6; i >= 1; i--) inv[i] = inv[i+1] * (i+1) % p;
    for(int i = 0; i < len; i++) 
    {
        int tmp = ksm(ksm(jz[s],i),p-2);
        a[i] = C(m,i) * C(n,i*s) % p * jz[i*s] % p * tmp % p * ksm(m-i,n-i*s) % p;
        a[i] = a[i] * jz[i];
    }
    for(int i = 0, tmp = 1; i < len; i++, tmp *= -1) b[i] = (tmp * inv[i] + p) % p;
    reverse(a,a+len);
    int lim = 1, tim = 0;
    while(lim < (len<<1)) tim++, lim <<= 1;
    for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    NTT(a,lim,1); NTT(b,lim,1);
    for(int i = 0; i < lim; i++) a[i] = a[i] * b[i] % p;
    NTT(a,lim,-1);
    reverse(a,a+len);
    for(int i = 0; i < len; i++) a[i] = a[i] * inv[i] % p;
    for(int i = 0; i < len; i++) ans = (ans + w[i] * a[i] % p) % p;
    printf("%lld\n",ans);
    return 0; 
}
posted @ 2021-03-15 22:37  genshy  阅读(89)  评论(0编辑  收藏  举报