luogu P5339 [TJOI2019]唱、跳、rap和篮球 (容斥,指数型母函数,NTT)

https://www.luogu.com.cn/problem/P5339

要求不含1234的方案,反过来求含至少一个1234的方案。
钦定存在i个位置有1234,位置的方案是Cn-3i, i. 其他n-4i个位置的方案是多重集排列:
1 的生成函数$\sum_{i=0}^{num [1]} \frac{x^{i}}{i !}$1-4的生成函数相乘(NTT)得到cal(i),结果的n - 4i项就是其他位置的方案
钦定存在i个位置有1234的方案就是 f(i)=(Cn-3i, i)* cal(i)
钦定存在1个位置有1234的方案, 将恰好含有2个位置1234的方案重复计算了2次,减去钦定存在2个位置有1234的方案,再将钦定存在3个位置有1234的方案加回来...
答案就是:
\(f(0) - \sum_{i=1}^{m}(-1)^{i + 1}f(n)\)

#include<bits/stdc++.h>
using namespace std;
#define IOS ios::sync_with_stdio(false) ,cin.tie(0), cout.tie(0);
//#pragma GCC optimize(3,"Ofast","inline")
#define ll long long
//#define int long long
const int N = 5e2 + 5;
const int M = 1e3 + 5;
const int INF = 0x3f3f3f3f;
const ll LNF = 0x3f3f3f3f3f3f3f3f;
const int mod = 998244353;
const double PI = acos(-1.0);
const double eps = 1e-13;
const int MATN = 102;
const int p = 998244353, G = 3, Gi = 332748118;//这里的Gi是G的除法逆元

ll num[5]; int n;
ll F[M << 1];ll inv[M << 1];
ll qmi(ll m, ll k ) {
    ll res = 1;
    while(k) {
        if(k & 1) res = res *m % mod;
        m = m * m % mod;
        k >>= 1;
    }
    return res;
}
ll C(ll n, ll m ) {
    return F[n] * inv[m] % mod * inv[n - m] % mod;
}
void init(int n){
    F[0]=inv[0]= 1;
    for(int i=1;i<=n;i++)F[i]=F[i-1]*i%mod;//求阶乘
    inv[n] = qmi(F[n],mod - 2);//随便求一下n的逆元
    for(int i=n-1;i>=1;i--)inv[i]=inv[i+1]*(i+1)%mod;//由最后一个往前推
}
int limit = 1;//
int L;//二进制的位数
int RR[M << 2];
ll a[5][M << 2];
void NTT(ll *A, int type)
{
    for(int i = 0; i < limit; ++ i)
        if(i < RR[i])
            swap(A[i], A[RR[i]]);
    for(int mid = 1; mid < limit; mid <<= 1) {//原根代替单位根
        //ll wn = qpow(type == 1 ? G : Gi, (p - 1) / (mid << 1));
        ll wn = qmi(G, (p - 1) / (mid * 2));
        if(type == -1) wn = qmi(wn, p - 2); //wn的-1次方
        //如果超时了上面if这句话删掉,在下面的if(type == -1)里加上下面这个循环
        /*for (int i = 1; i < limit / 2; i ++)
        swap(A[i], A[limit - i]); */
        //逆变换则乘上逆元,因为我们算出来的公式中逆变换是(a^-ij),也就是(a^ij)的逆元
        for(int pos = 0; pos < limit; pos += mid * 2) {
            ll w = 1;
            for(int j = 0; j < mid; ++ j, w = (w * wn) % p) {
                int x = A[pos + j], y = w * A[pos + mid + j] % p;
                A[pos + j] = (x + y) % p;
                A[pos + j + mid] = (x - y + p) % p;

            }
        }
    }

    if(type == -1) {
        ll limit_inv = qmi(limit, mod - 2);//N的逆元(N是limit, 指的是2的整数幂)
        for(int i = 0; i < limit; ++ i)
            A[i] = (A[i] * limit_inv) % p;//NTT还是要除以n的,但是这里把除换成逆元了,inv就是n在模p意义下的逆元
    }
}
ll cal(int k) {
    for(limit = 1, L = 0; limit <= num[1] + num[2] + num[3] + num[4]; limit <<= 1) L ++ ;
    for(int i = 0; i < limit; ++ i) {
        RR[i] = (RR[i >> 1] >> 1) | ((i & 1) << (L - 1));
    }
    for ( int i = 1; i <= 4; ++ i ) {
        for ( int j = 0; j <= num[i] - k; ++ j ) {

            a[i][j] = inv[j];
        }
        for ( int j = num[i] - k + 1; j < limit; ++ j ) a[i][j] = 0;
    }
    for ( int i = 1; i <= 4; ++ i ) NTT(a[i], 1);
    for ( int i = 2; i <= 4; ++ i ) {
        for ( int j = 0; j < limit; ++ j ) {
            a[1][j] = a[1][j] * a[i][j] % mod;
        }
    }
    NTT(a[1], -1);
    return a[1][n - 4 * k] * F[n - 4 * k] % mod;
}
int main () {
    IOS
    init(2000);
    cin >> n;
    for ( int i = 1; i <= 4; ++ i ) cin >> num[i];
    sort(num + 1, num + 1 + 4);
    ll ans = 0;
    for ( int i = 0; i <= num[1] && i <= n / 4; ++ i ) {
        ll tmp = C(n - 3 * i, i) * cal(i) % mod;
        if(i & 1) ans = (ans -  tmp) % mod;
        else ans = (ans + tmp) % mod;
    }
    cout << (ans + mod) % mod << '\n';
    return 0;
}
posted @ 2022-10-17 10:59  qingyanng  阅读(25)  评论(0编辑  收藏  举报