luogu P5394 【模板】下降幂多项式乘法

https://www.luogu.com.cn/problem/P5394

之前一直以为这是什么阴间东西,就一直没有碰,现在才知道原来是个挺 naive 的东西。

设下降幂多项式的系数为 \(f[j]\),它在 \(x=i\) 处的点值为 \(F(i)=\sum\limits_j f[j]\,i^{\underline{j}}\)。我们设点值的指数型生成函数为$$G(x)=\sum_{i=0}^{\infty} \frac{F(i)}{i!} x^i$$

\(\large =\sum\limits_{i=0}^{\infty}\frac{x^i}{i!} \sum\limits_j{f[j]i^{\underline{j}}}=\sum\limits_j{f[j]\sum\limits_{i=0}^{\infty}\frac{i^{\underline{j}}}{i!}x^i}\)

我们考虑后面那坨东西是什么

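\[\sum\limits_{i=0}^{\infty}\frac{i^{\underline{j}}}{i!}x^i=\sum\limits_{i=j}^{\infty}\frac{1}{(i-j)!}x^i=x^j\sum\limits_{i=0}^{\infty}\frac{1}{i!}x^i=x^j e^x \]
(这里用到:当 \(i<j\) 时 \(i^{\underline{j}}=0\);当 \(i\ge j\) 时 \(\frac{i^{\underline{j}}}{i!}=\frac{1}{(i-j)!}\)。)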

带到上面

\[\large =e^x \sum\limits_j f[j] x^j \]

于是我们就得到了点值的生成函数\(G(x)=e^xF(x)\)
要注意,\(G(x)\) 是 \(EGF\),而 \(F(x)=\sum\limits_j f[j]x^j\) 是下降幂系数 \(f\) 的 \(OGF\)(这里的 \(F(x)\) 和点值 \(F(i)\) 不是一回事)。

所以求点值(\(DFT\))时,先算卷积 \(g=f*e^x\)(即 \(G(x)=e^xF(x)\) 的系数),点值就是 \(F(i)=g[i]\cdot i!\)。

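然后把两个多项式的点值对应相乘,得到乘积的点值;再除以 \(i!\) 得到乘积点值的 \(EGF\);最后乘上 \(e^{-x}\)(因为 \(F(x)=e^{-x}G(x)\)),就可以 \(IDFT\) 回去啦。乘积的次数是 \(n+m\),所以只需要 \(i=0,1,\dots,n+m\) 这 \(n+m+1\) 个点值。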

code:

#include<bits/stdc++.h>
#define mod 998244353
#define N 800050
using namespace std;
// Modular addition for operands already reduced into [0, mod): the raw sum is
// below 2*mod, so one conditional subtraction brings it back into range.
int add(int x, int y) {
    int sum = x + y;
    return sum >= mod ? sum - mod : sum;
}
// Modular subtraction for operands already reduced into [0, mod): if the raw
// difference went negative, adding mod once restores the canonical range.
int sub(int x, int y) {
    int diff = x - y;
    return diff < 0 ? diff + mod : diff;
}
// Modular multiplication: the product is formed in long long so two values
// below mod cannot overflow before the reduction.
int mul(int x, int y) {
    long long product = static_cast<long long>(x) * y;
    return static_cast<int>(product % mod);
}
// Binary exponentiation: returns x^y modulo mod. Walks the bits of y from the
// lowest upward, squaring the base each step and folding it into the result
// whenever the current bit is set.
int qpow(int x, int y) {
    int result = 1;
    int base = x;
    while(y) {
        if(y & 1) result = mul(result, base);
        base = mul(base, base);
        y >>= 1;
    }
    return result;
}
int fac[N], ifac[N];
// Precomputes factorials fac[0..n] and inverse factorials ifac[0..n] modulo
// mod (n must be below the array size N). Only one modular inverse is taken,
// that of n!. The rest follow downward from 1/(i-1)! = i * (1/i!).
void init(int n) {
    fac[0] = 1;
    for(int i = 1; i <= n; ++i)
        fac[i] = mul(fac[i - 1], i);
    ifac[n] = qpow(fac[n], mod - 2);
    for(int i = n; i > 0; --i)
        ifac[i - 1] = mul(ifac[i], i);
}
const int G = 3;                    // generator (primitive root) used to build the roots of unity mod 998244353
const int Ginv = qpow(3, mod - 2);  // inverse of G, used for the inverse transform
int rev[N];                         // bit-reversal permutation; filled by the caller for the transform length

// In-place iterative number-theoretic transform of a[0..n) modulo mod.
//   n : transform length; must be a power of two and match the length that
//       rev[] was built for.
//   o : 1 for the forward transform, -1 for the inverse transform (which also
//       scales every entry by 1/n).
void ntt(int *a, int n, int o) {
    // Permute into bit-reversed order so the butterflies can run in place.
    for(int i = 1; i < n; i ++) if(i > rev[i]) swap(a[i], a[rev[i]]);
    for(int len = 2; len <= n; len <<= 1) {
        int half = len >> 1;
        // Principal len-th root of unity (its inverse for the inverse transform).
        int w0 = qpow((o == 1)? G : Ginv, (mod - 1) / len);
        for(int j = 0; j < n; j += len) {
            int wn = 1;
            for(int k = j; k < j + half; k ++, wn = mul(wn, w0)) {
                int X = a[k], Y = mul(wn, a[k + half]);
                a[k] = add(X, Y), a[k + half] = sub(X, Y);
            }
        }
    }
    // Only the inverse transform needs the 1/n factor, so the modular inverse
    // is not computed at all on forward transforms.
    if(o == -1) {
        int ninv = qpow(n, mod - 2);
        for(int i = 0; i < n; i ++) a[i] = mul(a[i], ninv);
    }
}
int a[N], b[N], ex[N], efx[N];
// Falling-power coefficients -> point-value EGF. This is not a plain NTT. It
// cyclically convolves a[0..n) with the sequence held in ex[]. main() fills
// ex[] with 1/i! (a prefix of e^x) and forward-transforms it before calling
// this. The result is g = f * e^x, so g[i] equals the point value at i divided
// by i!, for i up to the prefix length stored in ex[].
// n must be the power-of-two length rev[] was built for.
void DFT(int *a, int n) {
    ntt(a, n, 1);  // coefficients -> frequency domain
    for(int k = n - 1; k >= 0; --k)
        a[k] = mul(a[k], ex[k]);  // pointwise product with transformed e^x
    ntt(a, n, -1);  // back to coefficients of the convolution
}
// Point-value EGF -> falling-power coefficients. This cyclically convolves
// a[0..n) with the sequence held in efx[]. main() fills efx[] with (-1)^i / i!
// (a prefix of e^{-x}) and forward-transforms it before calling this. Applied
// to the EGF of the point values, this recovers the falling-power coefficients,
// because F(x) = e^{-x} G(x).
// n must be the power-of-two length rev[] was built for.
void IDFT(int *a, int n) {
    ntt(a, n, 1);  // coefficients -> frequency domain
    for(int k = n - 1; k >= 0; --k)
        a[k] = mul(a[k], efx[k]);  // pointwise product with transformed e^{-x}
    ntt(a, n, -1);  // back to coefficients of the convolution
}
int n, m;
// Reads degrees n and m followed by the falling-power coefficients of two
// polynomials. Multiplies them through their point values at 0..n+m, then
// prints the n+m+1 falling-power coefficients of the product.
int main() {
    scanf("%d%d", &n, &m); int lim = n + m;
    // Only fac[0..lim] / ifac[0..lim] are ever read below.
    init(lim);
    for(int i = 0; i <= n; i ++) scanf("%d", &a[i]);
    for(int i = 0; i <= m; i ++) scanf("%d", &b[i]);

    // ex = e^x and efx = e^{-x}, both truncated after degree lim.
    for(int i = 0; i <= lim; i ++) {
        efx[i] = ex[i] = ifac[i];
        if(i & 1) efx[i] = sub(0, efx[i]);
    }

    // Transform length strictly greater than 2*lim, so none of the
    // convolutions below wraps around.
    int len = 1;
    for(; len <= 2 * lim; len <<= 1);
    for(int i = 1; i < len; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * (len >> 1));

    ntt(ex, len, 1), ntt(efx, len, 1);
    DFT(a, len), DFT(b, len);
    // After DFT, a[i] and b[i] hold (point value at i) / i!. The product's
    // point value is a[i] * b[i] * (i!)^2. Its EGF coefficient divides by i!
    // once more, which leaves a[i] * b[i] * i!.
    for(int i = 0; i <= lim; i ++) a[i] = mul(mul(a[i], b[i]), fac[i]);
    // Discard everything above degree lim. Stay strictly inside the transform
    // length (the old bound `i <= len` also wrote a[len]).
    for(int i = lim + 1; i < len; i ++) a[i] = 0;
    IDFT(a, len);
    for(int i = 0; i <= lim; i ++) printf("%d ", a[i]);
    putchar('\n');
    return 0;
}
posted @ 2021-12-21 21:43  lahlah  阅读(39)  评论(0)  编辑  收藏  举报