cf1251 F. Red-White Fence NTT + 组合数学

传送门
比较好的一个理解多项式乘法的题。

首先,周长\(C = (红板长度 + 板子数目)\times 2\)

那么对于每一个查询\(Q\)即周长,可以得到选取的板子数目为\(\frac{Q}{2} - 红板长度\), 白板个数为\(\frac{Q}{2} - 红板长度 - 1\)

那么只需要求出对于每一个红板的情况,最后对于每一个查询,枚举在该红板下取白板数,求和即可。

现考虑对于每一个红板,求白板的情况。

首先,对于只出现一次的白板,有\(tot1\)个,那么取\(i\)个白板的情况为\(2^i\times C_{tot1}^{ i}\), 因为有两种可能,在左边或在右边。

对于出现次数大于等于2的白板,设有\(tot2\)个,那么取\(i\)个白板的情况为\(C_{2tot2}^{i}\),相同的白板就放在两侧,大于等于2的都等价于等于2的情况。

相当于就是\(ans[k] = f_i + g_{k - i}\)\(g_i\) = \(2^i\times C_{tot1}^{ i}\), \(g_i =C_{2tot2}^{i}\)

那么对于每一个红板长度,就要去求出所有的\(ans_i\),这就用多项式乘法就行了,这就相当于是多项式乘法的定义了。

其实有个地方不好理解,就是对于出现次数大于等于2的白板,如果你是成对地取,相当于是\(i\)是偶数时比较好理解,如果是取奇数个情况呢?那不就得乘2吗,因为两侧都可以放,但其实取奇数时是存在重复的,比如两个1的时候,我是\(C_2^1\),这时取重复了,但其实是\(C_1^1\),所以取奇数个时,奇数那部分得先除以2再乘以2,相当于不乘。

#include <bits/stdc++.h>
#define ll long long
#define CASE int Kase = 0; cin >> Kase; for(int kase = 1; kase <= Kase; kase++)
using namespace std;
template<typename T = long long> inline T read() {
    T s = 0, f = 1; char ch = getchar();
    while(!isdigit(ch)) {if(ch == '-') f = -1; ch = getchar();}
    while(isdigit(ch)) {s = (s << 3) + (s << 1) + ch - 48; ch = getchar();} 
    return s * f;
}
const int N = 3e5 + 5, MAXN = 3e5 + 5, MOD = 1e9 + 7, CM = 998244353, INF = 0x3f3f3f3f;
namespace NTT{
    ll ksm(ll a, ll b, ll p){ ll ans = 1; a %= p; while(b){ if(b & 1) ans = ans * a % p; a = a * a % p; b >>= 1;} return ans; }
    const int P = 998244353, G = 3, inv_G = 332748118; // p的原根和原根在p下的逆元
    const int N = 2e6 + 5; // n * 4
    int r[N], A[N], B[N], n, m;
    void NTT(int *a, int n, int op){
        for(int i = 0; i < n; i++)
            if(i < r[i]) swap(a[i], a[r[i]]);
        for(int mid = 1; mid < n; mid <<= 1){
            int x = ksm(op == 1 ? G: inv_G, (P - 1) / (mid << 1), P);
            for(int j = 0; j < n; j += (mid << 1)){
                int w = 1;
                for(int k = 0; k < mid; k++, w = 1ll * w * x % P){
                    int t1 = a[j + k], t2 = 1ll * w * a[j + k + mid] % P;
                    a[j + k] = (t1 + t2) % P;
                    a[j + k + mid] = (t1 - t2 + P) % P;
                }
            }
        }
        if(op == -1) {
            int inv = ksm(n, P - 2, P);
            for(int i = 0; i <= n; i++) A[i] = 1ll * A[i] * inv % P;
        }
    }
    void mul(int *a, int *b, int *c, int nn, int mm, int &k){
        int l = 0;
        n = nn, m = mm;
        for(m = n + m, n = 1; n <= m; n <<= 1, l++);
        for(int i = 0; i < n; i++) {
            r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
            A[i] = B[i] = 0;
        }
        for(int i = 0; i <= nn; i++) A[i] = a[i];
        for(int i = 0; i <= mm; i++) B[i] = b[i];
        NTT(A, n, 1); NTT(B, n, 1);
        for(int i = 0; i < n; i++) A[i] = 1ll * A[i] * B[i] % P;
        NTT(A, n, -1); k = m;
        for(int i = 0; i <= m; i++) c[i] = A[i];
    }
};
ll ksm(ll a, ll b, ll p){
    ll ans = 1; a %= p;
    while(b){
        if(b & 1) ans = ans * a % p;
        a = a * a % p;
        b >>= 1;
    }
    return ans;
}
namespace Combination{
    ll fac[MAXN], invfac[MAXN], mod;
    void init(int n, ll MOD){ // 线性求[1, n]的组合数和逆元
        fac[0] = 1; mod = MOD;
        for(int i = 1; i <= n; i++)
            fac[i] = fac[i - 1] * i % mod;
        invfac[n] = ksm(fac[n], mod - 2, mod);
        for(int i = n; i >= 1; i--)
            invfac[i - 1] = invfac[i] * i % mod;
    }
    ll C(ll n, ll m){
        return n >= m ? fac[n] * invfac[n - m] % mod * invfac[m] % mod: 0;
    }
}
int a[N], b[N], cnt[N], n, k;
int ans[N][6];
int aa[N], bb[N], cc[N];
void cal(int pos, int red) {
    for(int i = 1; i <= n; i++) cnt[a[i]] = 0;
    for(int i = 1; i <= n; i++) if(a[i] < red) cnt[a[i]]++;
    int one = 0, two = 0;
    for(int i = 1; i < red; i++) one += cnt[i] == 1, two += cnt[i] >= 2;
    two *= 2;
    for(int i = 0; i <= one; i++) aa[i] = 1ll * ksm(2, i, CM) * Combination::C(one, i) % CM;
    for(int i = 0; i <= two; i++) bb[i] = Combination::C(two, i);

    int three = 0;
    NTT::mul(aa, bb, cc, one, two, three);
    for(int i = 0; i <= three; i++) ans[i][pos] = cc[i];
}
void solve(int kase){
    Combination::init(N - 4, CM);
    n = read(), k = read();
    for(int i = 1; i <= n; i++) a[i] = read();
    for(int i = 1; i <= k; i++) b[i] = read();
    for(int i = 1; i <= k; i++) cal(i, b[i]);
    int q = read();
    for(int i = 1; i <= q; i++) {
        int Q = read();
        int res = 0;
        for(int j = 1; j <= k; j++) {
            int num = Q / 2 - b[j] - 1;
            if(num < 0) continue;
            res += ans[num][j];
            res %= CM;
        }
        printf("%d\n", res);
    }
}
const bool DUO = 0;
int main(){
    clock_t start, finish; double totaltime; start = clock();
    if(DUO) {CASE solve(kase);} else solve(1);
    finish = clock(); 
    #ifdef ONLINE_JUDGE
        return 0;
    #endif
    printf("\nTime: %lfms\n", (double)(finish - start) / CLOCKS_PER_SEC * 1000);
    return 0;
}
posted @ 2021-02-17 15:24  Emcikem  阅读(61)  评论(0编辑  收藏  举报