luogu P5162 WD与积木

https://www.luogu.com.cn/problem/P5162

问题可以转化为:把 $n$ 个有标号的小球放进 $m$ 个有标号的盒子里(每个盒子对应一层),盒子不能为空;对所有可能的 $m$ 与放法求层数的期望,即"层数总和 / 方案总数"。

因为盒子是有标号的(层与层之间有先后顺序),所以不能直接对单个盒子的生成函数取 $\exp$($\exp$ 对应的是无序的集合划分)。
考虑一个(非空)盒子的指数型生成函数
$$F(x)=e^x-1$$
枚举盒子的个数 $k$,方案数的指数型生成函数即为
$$\sum_{k\ge 0}F^k(x)=\frac{1}{1-F(x)}=\frac{1}{2-e^x}$$
然后再计算层数总和的指数型生成函数
$$\sum_{k\ge 0}kF^k(x)$$

乍一看这个式子和求导很像:提取一个 $F(x)$ 出来,再把 $\sum_{k} kF^{k-1}$ 看成 $\frac{1}{1-F}$ 对 $F$(而不是对 $x$)的导数,得到
$$F(x)\sum_{k\ge 0}kF^{k-1}(x)=F(x)\,\frac{\mathrm{d}}{\mathrm{d}F}\left(\frac{1}{1-F}\right)=\frac{F(x)}{\left(1-F(x)\right)^2}=\frac{e^x-1}{\left(2-e^x\right)^2}$$

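多项式求逆得到 $\frac{1}{2-e^x}$,再用一次 NTT 卷积把它平方并乘上 $e^x-1$ 即可(不需要一般的多项式快速幂)。两个 EGF 的 $n!$ 相互抵消,询问 $n$ 的答案为
$$\frac{[x^n]\,\frac{e^x-1}{(2-e^x)^2}}{[x^n]\,\frac{1}{2-e^x}}$$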
code:

#include<bits/stdc++.h>
#define int long long
#define mod 998244353
#define G 3
#define N 800005
using namespace std;
// Binary exponentiation: returns x^y modulo `mod`.
// x should already be reduced below `mod`; otherwise x * x may overflow.
int qpow(int x, int y){
    int result = 1;
    while(y){
        if(y & 1) result = result * x % mod;
        x = x * x % mod;
        y >>= 1;
    }
    return result;
}
// rev: bit-reversal permutation table for the current transform size.
// G_inv: inverse of the primitive root G; main() must set it before any inverse transform.
// len_inv: 1 / len from the most recent inverse transform (kept global for compatibility).
int rev[N], G_inv, len_inv;
// In-place number-theoretic transform of a[0 .. len-1] modulo `mod`.
//   len : transform size; must be a power of two dividing mod - 1.
//   o   : 1 for the forward transform, -1 for the inverse transform
//         (the inverse also divides every coefficient by len).
// Only indices 0 .. len-1 are read or written. a[len] and anything beyond
// is left untouched, so callers need not keep that slot zero.
void ntt(int *a, int len, int o){
    // Bit-reversal permutation so the iterative butterflies work in place.
    for(int i = 0; i < len; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * len >> 1);
    for(int i = 0; i < len; i ++) if(i < rev[i]) swap(a[i], a[rev[i]]);
    for(int i = 2; i <= len; i <<= 1){
        // Principal i-th root of unity (or its inverse for the inverse transform).
        int wn = qpow((o == 1) ? G : G_inv, (mod - 1) / i);
        for(int j = 0, p = i / 2; j < len; j += i){
            int w0 = 1;
            for(int k = j; k < j + p; k ++, w0 = w0 * wn % mod){
                int X = a[k], Y = w0 * a[k + p] % mod;
                a[k] = (X + Y) % mod;
                a[k + p] = (X - Y + mod) % mod;
            }
        }
    }
    if(o == -1){
        // 1/len is only needed for the inverse transform; skip the qpow otherwise.
        len_inv = qpow(len, mod - 2);
        for(int i = 0; i < len; i ++) a[i] = a[i] * len_inv % mod;
    }
}
int c[N];   // scratch buffer holding the transformed prefix of a
// Power-series inverse by Newton iteration.
// On return b[0 .. sz] holds the first sz + 1 coefficients of 1 / a(x), and
// b[sz + 1 .. len] is zeroed (len = smallest power of two > 2 * sz).
// a[0] must be invertible modulo `mod`. Entries of b above index 0 are
// assumed to be zero on entry, because indices beyond the working range are
// never cleared here.
void inv(int *a, int *b, int sz){
    if(sz == 0){
        b[0] = qpow(a[0], mod - 2);
        return;
    }
    // First obtain 1/a modulo x^(sz/2 + 1), then double the precision.
    inv(a, b, sz / 2);
    int len = 1;
    while(len <= 2 * sz) len <<= 1;
    for(int i = 0; i <= len; i ++) c[i] = (i <= sz) ? a[i] : 0;
    ntt(c, len, 1);
    ntt(b, len, 1);
    // Newton step, pointwise in the transformed domain: b <- b * (2 - a * b).
    for(int i = 0; i <= len; i ++){
        int twice = b[i] * 2 % mod;
        int correction = b[i] * b[i] % mod * c[i] % mod;
        b[i] = (twice - correction + mod) % mod;
    }
    ntt(b, len, -1);
    // Drop everything beyond the precision we are sure of.
    for(int i = sz + 1; i <= len; i ++) b[i] = 0;
}
// Formal derivative in place: a(x) <- a'(x) for a polynomial of degree sz.
// Coefficient i becomes (i + 1) * a[i + 1]; the top coefficient a[sz] is cleared.
void qiudao(int *a, int sz) {
    int i = 0;
    while(i < sz){
        a[i] = (i + 1) * a[i + 1] % mod;
        ++i;
    }
    a[sz] = 0;
}
// Formal integral in place, with zero constant term.
// For 1 <= i <= sz, coefficient i becomes a[i - 1] / i; a[0] is cleared.
// Walks from the top down so each a[i - 1] is read before it is overwritten.
void jifen(int *a, int sz) {
    for(int i = sz; i > 0; -- i){
        int inv_i = qpow(i, mod - 2);
        a[i] = a[i - 1] * inv_i % mod;
    }
    a[0] = 0;
}
int Ad[N], An[N];   // scratch: derivative of A, and inverse of A
// Power-series logarithm in place: A[0 .. n] <- first n + 1 coefficients of
// ln A(x), computed as the integral of A'(x) / A(x). The constant term of the
// result is always 0, so A[0] is presumably expected to be 1.
// Ad and An are cleared up to the transform size before returning.
void ln(int *A, int n) {
    int len = 1;
    while(len <= n + n) len <<= 1;

    // Ad <- A'(x)
    for(int i = 0; i <= n; i ++) Ad[i] = A[i];
    qiudao(Ad, n);
    // An <- 1 / A(x) mod x^(n + 1)
    inv(A, An, n);

    // Ad <- A' * (1/A); the degree is below 2n < len, so there is no wrap-around.
    ntt(Ad, len, 1), ntt(An, len, 1);
    for(int i = 0; i <= len; i ++) Ad[i] = Ad[i] * An[i] % mod;
    ntt(Ad, len, -1);

    // Integrate, write the result back, and reset both scratch arrays.
    jifen(Ad, n);
    for(int i = 0; i <= len; i ++){
        if(i <= n) A[i] = Ad[i];
        Ad[i] = 0;
        An[i] = 0;
    }
}
int fln[N];   // scratch: 1 - ln(b) + a for the current Newton step
// Power-series exponential by Newton iteration.
// On return b[0 .. n] holds the first n + 1 coefficients of exp(a(x)), and
// b[n + 1 .. len] is zero (len = smallest power of two > 2n), so b can be fed
// straight into another convolution.
// Requires a[0] == 0: a[0] is never read and the result's constant term is 1.
// Uses ln() and the global scratch array fln, which is cleared before returning.
void exp(int *a, int *b, int n) {
    if(n == 0) {b[0] = 1; return;}
    // b <- exp(a) mod x^(n/2 + 1)
    exp(a, b, n / 2);
    // fln <- 1 - ln(b) + a   (mod x^(n + 1))
    for(int i = 0; i <= n; i ++) fln[i] = b[i];
    ln(fln, n);
    fln[0] = 1;
    for(int i = 1; i <= n; i ++) fln[i] = (a[i] - fln[i] + mod) % mod;
    int len = 1;
    for(; len <= n + n;) len <<= 1;
    // Newton step: b <- b * (1 - ln(b) + a)
    ntt(b, len, 1), ntt(fln, len, 1);
    for(int i = 0; i <= len; i ++) b[i] = b[i] * fln[i] % mod;
    ntt(b, len, -1);
    // The product has terms up to degree ~2n, but only b[0 .. n] are
    // coefficients of exp(a). Clear the rest so they do not leak to the caller.
    for(int i = n + 1; i <= len; i ++) b[i] = 0;
    for(int i = 0; i <= len; i ++) fln[i] = 0;
}
int a[N], b[N], n, m, fac[N], ifac[N], ib[N], fm[N];
// Fills fac[0 .. n] with factorials and ifac[0 .. n] with their inverses modulo `mod`.
// Uses a single modular exponentiation for ifac[n] = 1 / n!, then walks down
// with ifac[i - 1] = ifac[i] * i.
void init(int n) {
    fac[0] = 1;
    for(int i = 1; i <= n; i ++) fac[i] = fac[i - 1] * i % mod;
    ifac[n] = qpow(fac[n], mod - 2);
    for(int i = n; i > 0; i --) ifac[i - 1] = ifac[i] * i % mod;
}
// Precomputes every answer up to n = 100000, then answers each query in O(1).
// For a query n the answer is
//   (n! [x^n] (e^x - 1) / (2 - e^x)^2) / (n! [x^n] 1 / (2 - e^x)),
// i.e. total number of layers divided by number of arrangements.
signed main(){
    init(N - 10);
    G_inv = qpow(G, mod - 2);
    n = 100000;

    // a(x) = e^x - 1 (a[0] stays 0) and b(x) = 2 - e^x, truncated to degree n.
    for(int i = 0; i <= n; i ++){
        if(i > 0) a[i] = ifac[i];
        b[i] = ((i == 0 ? 2 : 0) + mod - ifac[i]) % mod;
    }

    // ib(x) = 1 / (2 - e^x); fm[i] = inverse of the number of arrangements, 1 / (i! [x^i] ib).
    inv(b, ib, n);
    for(int i = 0; i <= n; i ++) fm[i] = qpow(ib[i] * fac[i] % mod, mod - 2);

    int len = 1;
    while(len <= n + n) len <<= 1;

    // ib <- ib^2, truncated back to degree n.
    ntt(ib, len, 1);
    for(int i = 0; i <= len; i ++) ib[i] = ib[i] * ib[i] % mod;
    ntt(ib, len, -1);
    for(int i = n + 1; i <= len; i ++) ib[i] = 0;

    // a <- (e^x - 1) / (2 - e^x)^2, the EGF of the total number of layers.
    ntt(a, len, 1), ntt(ib, len, 1);
    for(int i = 0; i <= len; i ++) a[i] = a[i] * ib[i] % mod;
    ntt(a, len, -1);

    int t;
    scanf("%lld", &t);
    while(t --) {
        scanf("%lld", &n);
        int total = fac[n] * a[n] % mod;
        printf("%lld\n", total * fm[n] % mod);
    }
    return 0;
}
posted @ 2021-08-23 11:26  lahlah  阅读(40)  评论(0编辑  收藏  举报