Luogu 2000 拯救世界

胡小兔的博客那里过来的,简单记一下生成函数。

生成函数

数列$\{1, 1, 1, 1, \cdots\}$的生成函数是$f(x) = 1 + x + x^2 + x^3 + \cdots$,根据等比数列求和公式,可以得到$f(x) = \frac{1}{1 - x}$。

把两边分别平方,得到

$$\frac{1}{(1 - x)^2} = (1 + x + x^2 + x^3 + x^4 + \cdots)^2 = 1 + 2x + 3x^2 + 4x^3 + \cdots$$

相当于数列$\{1, 2, 3, 4, 5, \cdots \}$的生成函数。

两边三次方,得到

$$\frac{1}{(1 - x)^3} = 1 + 3x + 6x^2 + 10x^3 + \cdots$$

发现数列$\sum_{i = 0}^{\infty}\binom{i + k - 1}{k - 1}x^i$的生成函数是$\frac{1}{(1 - x)^k}$。

而数列$\sum_{i = 0}^{\infty}x^i[i \mod k == 0]$的生成函数是$\frac{1}{1 - x^k}$。

本题解法

把限制条件看成数列,第一个限制条件相当于数列$\{1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, \cdots \}$,而第二个限制条件则相当于有限长数列$\{1, 1, 1, 1, 1, 1, 1, 1, 1\}$,……,所有的限制条件都可以这样子类推出来。

把这个数列写成生成函数,第个限制条件的第$i$项系数可以看成是这一项取$i$个的方法,把这十个生成函数乘起来之后得到的生成函数的第$i$项就相当于一共取了$i$项的方案数。

最后乘起来得到了$\frac{1}{(1 - x)^5}$,对应了数列$\sum_{i = 0}^{\infty}\binom{i + 4}{4}x^i$的第$n$项,即$\frac{(n + 1)(n + 2)(n + 3)(n + 4)}{24}$。

然后就是一个高精了,因为$n$的位数很多,所以乘法的时候需要写$fft$。

Code:

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <cmath>
using namespace std;
typedef double db;

const int N = 2e6 + 5;
const db Pi = acos(-1.0);

int lim, pos[N];
char str[N];

struct Cpx {
    db x, y;
    
    inline Cpx (db _x = 0, db _y = 0) {
        x = _x, y = _y;
    }
    
    friend Cpx operator + (const Cpx u, const Cpx v) {
        return Cpx(u.x + v.x, u.y + v.y);
    }
    
    friend Cpx operator - (const Cpx u, const Cpx v) {
        return Cpx(u.x - v.x, u.y - v.y);
    }
    
    friend Cpx operator * (const Cpx u, const Cpx v) {
        return Cpx(u.x * v.x - u.y * v.y, u.x * v.y + v.x * u.y);
    }
    
} a[N], b[N];

inline void prework(int len) {
    int l = 0;
    for (lim = 1; lim < len; lim <<= 1, ++l);
    for (int i = 0; i < lim; i++)
        pos[i] = (pos[i >> 1] >> 1) | ((i & 1) << (l - 1));
}

inline void fft(Cpx *c, int opt) {
    for (int i = 0; i < lim; i++)
        if (i < pos[i]) swap(c[i], c[pos[i]]);
    for (int i = 1; i < lim; i <<= 1) {
        Cpx wn(cos(Pi / i), opt * sin(Pi / i));
        for (int len = i << 1, j = 0; j < lim; j += len) {
            Cpx w(1, 0);
            for (int k = 0; k < i; k++, w = w * wn) {
                Cpx x = c[j + k], y = w * c[j + k + i];
                c[j + k] = x + y, c[j + k + i] = x - y;
            }
        }
    }
}

struct BigInt {
    int len, s[N];
    
    inline void init() {
        len = 0;
        memset(s, 0, sizeof(s));
    }
    
    inline void readIn() {
        scanf("%s", str);
        len = strlen(str);
        for (int i = 0; i < len; i++) s[i] = str[len - i - 1] - '0';
    }
    
    inline void print() {
        if (!len) putchar('0');
        else {
            for (int i = len - 1; i >= 0; i--) printf("%d", s[i]);
        }
        printf("\n");
    }
    
    friend BigInt operator + (const BigInt &x, const int &y) {
        BigInt res = x;
        res.s[0] += y;
        for (int i = 0; i < res.len; i++) {
            if (i == res.len - 1) {
                if (res.s[i] >= 10) ++res.len;
                else break;
            }
            res.s[i + 1] += res.s[i] / 10, res.s[i] %= 10;
        }
        for (; res.s[res.len - 1] == 0 && res.len > 0; --res.len);
        return res;
    }
    
    friend BigInt operator * (const BigInt &x, const BigInt &y) {
        prework(x.len + y.len - 1);
        for (int i = 0; i < lim; i++) a[i] = b[i] = Cpx(0, 0);
        for (int i = 0; i < x.len; i++) a[i].x = x.s[i];
        for (int i = 0; i < y.len; i++) b[i].x = y.s[i];
        fft(a, 1), fft(b, 1);
        for (int i = 0; i < lim; i++) a[i] = a[i] * b[i];
        fft(a, -1);
        
        BigInt res;
        res.init();
        res.len = x.len + y.len - 1;
        for (int i = 0; i < res.len; i++) res.s[i] = int(a[i].x / lim + 0.5);
        
        for (int i = 0; i < res.len; i++) {
            if (i == res.len - 1) {
                if (res.s[i] >= 10) ++res.len;
                else break;
            }
            res.s[i + 1] += res.s[i] / 10, res.s[i] %= 10;
        }
        for (; res.s[res.len - 1] == 0 && res.len > 0; --res.len);
        
        return res;
    }
    
    friend BigInt operator / (const BigInt &x, const int y) {
        BigInt res;
        res.init(); 
        int rest = 0;
        for (int i = x.len - 1; i >= 0; i--) {
            rest = rest * 10 + x.s[i];
            if (rest >= y) res.s[res.len++] = rest / y, rest %= y;
            else if (res.len) res.s[res.len++] = 0;
        }
        for (int i = 0; i < res.len / 2; i++) swap(res.s[i], res.s[res.len - 1 - i]); 
        return res;
    }
    
} n;

int main() {
    #ifndef ONLINE_JUDGE
        freopen("Sample.txt", "r", stdin);
    #endif
    
    n.readIn();
//    n.print();
    
    BigInt n1 = n + 1, n2 = n + 2, n3 = n + 3, n4 = n + 4;
//    n1.print(), n2.print(), n3.print(), n4.print();
    
    BigInt ans = n1 * n2 * n3 * n4;
//    ans.print();
    
    ans = ans / 24;
    ans.print();
    return 0;
}
View Code

 

posted @ 2019-01-16 08:21  CzxingcHen  阅读(168)  评论(0编辑  收藏  举报