luogu P6078 [CEOI2004] Sweets(生成函数入门题)

我的生成函数只有幼儿园水平/kk
https://www.luogu.com.cn/problem/P6078
在这里插入图片描述
首先对于一种糖,它的生成函数可以写成这样,封闭形式就是
∑ j = 0 m i x j = 1 − x m i + 1 1 − x \large \sum\limits_{j=0}^{m_i}x^j=\frac{1-x^{m_i+1}}{1-x} j=0mixj=1x1xmi+1
然后把这些全部乘起来就是答案的对于每个糖果生成函数
∏ i = 1 n 1 − x m i + 1 1 − x \large \prod\limits_{i=1}^n\frac{1-x^{m_i+1}}{1-x} i=1n1x1xmi+1
暴力乘显然不行
考虑先把 1 1 − n \frac{1}{1-n} 1n1拉出来
∏ i = 1 n 1 − x m i + 1 1 − x = 1 ( 1 − x ) n ∏ i = 1 n ( 1 − x m i + 1 ) \large \prod\limits_{i=1}^n\frac{1-x^{m_i+1}}{1-x} \\=\frac{1}{(1-x)^n}\prod\limits_{i=1}^n(1-x^{m_i+1}) i=1n1x1xmi+1=(1x)n1i=1n(1xmi+1)
前面那个通过牛顿二项式定理推广到负数的情况
∑ i = 0 ∞ ( m + i i ) x i = 1 ( 1 − x ) m + 1 \sum\limits_{i=0}^{\infin}\binom{m+i}{i}x^i=\frac{1}{(1-x)^{m + 1}} i=0(im+i)xi=(1x)m+11
可得

∑ j = 0 ∞ ( n − 1 + j j ) x j ∏ i = 1 n ( 1 − x m i + 1 ) \large \sum\limits_{j=0}^{\infin}\binom{n-1+j}{j}x^j\prod\limits_{i=1}^n(1-x^{m_i+1}) j=0(jn1+j)xji=1n(1xmi+1)
答案显然可以表示为两个系数前缀和相减的结果,这里先只考虑一个 x b x^b xb的系数前缀和
考虑,那么这个次数对答案的贡献(生成函数对应项系数)可以写成

∑ j = 0 b − k ( n − 1 + j j ) × [ x k ] ∏ i = 1 n ( 1 − x m i + 1 ) \large \sum\limits_{j=0}^{b-k}\binom{n-1+j}{j}\times[x^k]\prod\limits_{i=1}^n(1-x^{m_i+1}) j=0bk(jn1+j)×[xk]i=1n(1xmi+1)
前面那个组合数求和可以化简,最后写成
( n + b − k n ) × [ x k ] ∏ i = 1 n ( 1 − x m i + 1 ) \binom{n+b-k}{n}\times[x^k]\prod\limits_{i=1}^n(1-x^{m_i+1}) (nn+bk)×[xk]i=1n(1xmi+1)

因为 n n n很小,所以对于右边那个直接大力dfs枚举每类糖果选或不选,然后大力乘上系数前缀和求和即可
到这里这道题的重点就做完了
后面是关于模数非质数,可能存在没有逆元的情况,具体的解决办法就是先把模数乘上除数,计算完之后再除回去得到答案
具体证明:
在这里插入图片描述

摘自巨佬Rui_R的blog

code:

#include<bits/stdc++.h>
#define mod 2004
#define ll long long
using namespace std;
ll fac = 1;
ll C(int n, int m) {
    ll MOD = mod * fac, ans = 1;
    for(int i = n - m + 1; i <= n; i ++) ans = ans * i % MOD;
    return (ans / fac) % mod;
}
ll ans;
int a[15], n, l, r;
void dfs(int x, int o, int mi, int lim) {
    if(mi > lim) return ;
    if(x > n) {
        ans += C(n + lim - mi, n) * o % mod; ans %= mod;
        return ;
    }
    dfs(x + 1, o, mi, lim), dfs(x + 1, mod - o, mi + a[x] + 1, lim);
}
int calc(int lim) {
    ans = 0; dfs(1, 1, 0, lim);
    return ans;
}
int main() {
    scanf("%d%d%d", &n, &l, &r);
    for(int i = 1; i <= n; i ++) {
        scanf("%d", &a[i]);
        fac *= i;
    }
    printf("%lld", (calc(r) - calc(l - 1) + mod) % mod);
    return 0;
}
posted @ 2021-08-08 22:40  lahlah  阅读(55)  评论(0编辑  收藏  举报