「解题报告」ARC129F Let's Play Tag

挺暴力的题。

考虑我们在移动过程中,实际上只关心什么时候转向。我们把转向的几个点设为 \(a_1, a_2, \cdots, a_k\),并令开头和结尾也算一次转向。

那么我们可以先手动模拟一下,发现每个 \(a_i\) 的贡献都很有规律。具体来说,对于某一种转向方案,答案等于 \(a_k + \sum_{i=1}^{k - 1} 4 \times 3^{k - i - 1} a_i\)。这个手动模拟一下即可得出。

那么发现这个贡献只与某一个数与它所在的位置有关,所以我们很容易能够将贡献拆开来计算。

具体的,就是枚举 \(a_1 \sim a_k\) 中有 \(x\) 个在左侧,\(a_i\) 在从左到右第 \(j\) 位(从右到左第 \(k-j\) 位),在枚举右边选择的方案等等,右边选择的方案需要枚举第一个数选左边/右边和最后一个数选左边/右边。反正大力列式子就能列出以下式子:

\[\frac{4}{3}\sum_{i=1}^{n-1}L_{i}\sum_{x=1}^{n}\left(4{m-1\choose x-1}+{m-1\choose x-2}+3{m-1\choose x}\right)\sum_{j=1}^{x-1}9^{j}{n-i-1\choose j-1}{i-1\choose x-j-1}+ \]

\[\sum_{x=1}^{n}{n-1\choose x-1}\left(5{m-1\choose x-1}+{m-1\choose x-2}+4{m-1\choose x}\right)L_{n} \]

这个系数是因为在枚举第一个数和最后一个数的位置时,\(3\) 的指数可能为 \(2j / 2j+1\)

然后换下求和指标,可以范德蒙德卷积消掉一个 \(\sum\),然后剩下的式子可以 NTT 优化。

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 800005, P = 998244353, G = 3;
int n, m;
int l[MAXN];
int fac[MAXN], inv[MAXN];
int qpow(int a, int b) {
    int ans = 1;
    while (b) {
        if (b & 1) ans = 1ll * ans * a % P;
        a = 1ll * a * a % P;
        b >>= 1;
    }
    return ans;
}
int r[MAXN];
const int GI = qpow(G, P - 2);
struct Polynomial {
    vector<int> a;
    int len;
    Polynomial(int len = 0) : len(len) { a.resize(len + 1); }
    void set(int len) { this->len = len, a.resize(len + 1); }
    int& operator[](int b) { return a[b]; }
    void ntt(int limit, bool rev) {
        set(limit);
        for (int i = 0; i < limit; i++)
            if (i < r[i]) swap(a[i], a[r[i]]);
        for (int mid = 1; mid < limit; mid <<= 1) {
            int step = qpow(rev ? GI : G, (P - 1) / (mid << 1));
            for (int l = 0; l < limit; l += (mid << 1)) {
                int w = 1;
                for (int i = 0; i < mid; i++, w = 1ll * w * step % P) {
                    int x = a[l + i], y = 1ll * w * a[l + i + mid] % P;
                    a[l + i] = (x + y) % P, a[l + i + mid] = (x - y + P) % P;
                }
            }
        }
        if (rev) {
            int inv = qpow(limit, P - 2);
            for (int i = 0; i < limit; i++)
                a[i] = 1ll * a[i] * inv % P;
        }
    }
    Polynomial operator*(Polynomial b) {
        Polynomial a = *this, c;
        int len = a.len + b.len;
        int limit = 1;
        while (limit <= len) limit <<= 1;
        for (int i = 0; i < limit; i++)
            r[i] = (r[i >> 1] >> 1) | ((i & 1) * limit >> 1);
        a.ntt(limit, false), b.ntt(limit, false);
        c.set(limit);
        for (int i = 0; i < limit; i++) c[i] = 1ll * a[i] * b[i] % P;
        c.ntt(limit, true);
        c.set(len);
        return c;
    }
    void print() {
        for (int i : a) printf("%d ", i);
        printf("\n");
    }
} a, b, c;
int C(int n, int m) {
    if (n < 0 || m < 0 || n < m) return 0;
    return 1ll * fac[n] * inv[m] % P * inv[n - m] % P;
}
int ans;
void calc() {
    int x = 4ll * qpow(3, P - 2) % P;
    a.set(n - 1);
    for (int i = 1; i < n; i++)
        a[i] = 1ll * l[i] * fac[n - i - 1] % P * fac[i + m - 2] % P;
    b.set(min(n, m) - 1);
    for (int i = 1; i < min(n, m); i++)
        b[i] = 1ll * qpow(9, i) * inv[i - 1] % P * inv[m - i - 1] % P;
    c = a * b;
    for (int i = 1; i <= n; i++)
        ans = (ans + 4ll * c[i] * inv[n - i] % P * inv[i - 1] % P * x) % P;
    
    b.set(min(n, m + 1) - 1);
    for (int i = 1; i < min(n, m + 1); i++)
        b[i] = 1ll * qpow(9, i) * inv[i - 1] % P * inv[m - i] % P;
    c = a * b;
    for (int i = 2; i <= n; i++)
        ans = (ans + 1ll * c[i] * inv[n - i] % P * inv[i - 2] % P * x) % P;
    
    b.set(min(n, m - 1) - 1);
    for (int i = 1; i < min(n, m - 1); i++)
        b[i] = 1ll * qpow(9, i) * inv[i - 1] % P * inv[m - i - 2] % P;
    c = a * b;
    for (int i = 1; i <= n; i++)
        ans = (ans + 3ll * c[i] * inv[n - i] % P * inv[i] % P * x) % P;
    
    for (int i = 1; i <= n; i++) {
        ans = (ans + 5ll * l[n] * C(n - 1, i - 1) % P * C(m - 1, i - 1)) % P;
        ans = (ans + 1ll * l[n] * C(n - 1, i - 1) % P * C(m - 1, i - 2)) % P;
        ans = (ans + 4ll * l[n] * C(n - 1, i - 1) % P * C(m - 1, i)) % P;
    }
}
int main() {
    scanf("%d%d", &n, &m);
    const int N = 500000;
    fac[0] = 1;
    for (int i = 1; i <= N; i++)
        fac[i] = 1ll * fac[i - 1] * i % P;
    inv[N] = qpow(fac[N], P - 2);
    for (int i = N; i >= 1; i--)
        inv[i - 1] = 1ll * inv[i] * i % P;
    for (int i = 1; i <= n; i++) {
        scanf("%d", &l[i]);
    }
    calc();
    for (int i = 1; i <= m; i++) {
        scanf("%d", &l[i]);
    }
    swap(n, m);
    calc();
    printf("%d\n", ans);
    return 0;
}
posted @ 2023-03-20 16:22  APJifengc  阅读(57)  评论(0编辑  收藏  举报