luogu P6031 CF1278F Cards 加强版

https://www.luogu.com.cn/problem/P6031

首先发现每次洗牌都是独立的,所以概率 P = 1 m P=\frac{1}{m} P=m1
要计算的实际上是
∑ i = 0 n ( n i ) p i ( 1 − p ) n − i i k \sum_{i=0}^n\binom{n}{i}p^i(1-p)^{n-i}i^k i=0n(in)pi(1p)niik
按照套路拆开 i k i^k ik
∑ j = 0 k S ( k , j ) j ! ∑ i = j k ( i j ) ( n i ) p i ( 1 − p ) n − i \sum\limits_{j=0}^{k}S(k,j)j!\sum\limits_{i=j}^k\binom{i}{j}\binom{n}{i}p^i(1-p)^{n-i} j=0kS(k,j)j!i=jk(ji)(in)pi(1p)ni

用二项式定理把右边那个东西合并一下就可以得到

∑ j = 0 k S ( k , j ) j ! ( n j ) p j \sum_{j=0}^kS(k,j)j!\binom{n}{j}p^j j=0kS(k,j)j!(jn)pj
这样就可以做简单版了

显然可以用ntt优化到 O ( n l o g 2 n ) O(nlog_2n) O(nlog2n),但还是不够

把第二类斯特林数再拆开,推一家式子就可以得到线性的做法了

code:

#include<bits/stdc++.h>
#define ll long long
#define N 10000050
#define mod 998244353
using namespace std;
ll qpow(ll x, ll y) {
    ll ret = 1;
    for(; y; y >>= 1, x = x * x % mod) if(y & 1) ret = ret * x % mod;
    return ret;
}
int vis[N], prime[N], sz, k;
int idk[N];
void get(int n) {
    idk[1] = 1;
    for(int i = 2; i <= n; i ++) {
        if(!vis[i]) {
            idk[i] = qpow(i, k);
            prime[++ sz] = i;
        }
        for(int j = 1; j <= sz && i * prime[j] <= n; j ++) {
            vis[prime[j] * i] = 1; idk[prime[j] * i] = 1ll * idk[prime[j]] * idk[i] % mod;
            if(i % prime[j] == 0) break;
        }
    }
}
int inv[N], S[N];
void init(int n) {
    inv[1] = 1;
    for(int i = 2; i <= n; i ++) inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
}
int n, m;
int main() {
    scanf("%d%d%d", &n, &m, &k);
    get(k + 2), init(k + 2);

    ll P = qpow(m, mod - 2), Pm = 1, Cm = 1;
    S[0] = 1;
    for(int i = 1; i <= k; i ++) {
        Pm = Pm * (mod - P) % mod;
        Cm = Cm * ((n - k + i - 1 + mod) % mod) % mod * inv[i] % mod;
        S[i] = (1ll * (1 - P + mod) % mod * S[i - 1] % mod + Pm * Cm % mod) % mod;
    }
    //for(int i = 0; i <= k; i ++) printf("%lld ", S[i]); printf("\n");
    ll ans = 0; Cm = 1, Pm = 1;
    for(int i = 0; i <= k; i ++) {
        (ans += idk[i] * Cm % mod * Pm % mod * S[k - i] % mod) %= mod;
        Cm = Cm * ((n - i + mod) % mod) % mod * inv[i + 1] % mod;
        Pm = Pm * P % mod;
    }
    printf("%lld", ans);
    return 0;
}
posted @ 2021-12-13 20:03  lahlah  阅读(39)  评论(0编辑  收藏  举报