「解题报告」ARC158E All Pair Shortest Paths

LA 群说是大原题,很简单,我不会啊。

考虑 DP。\(0/1\) 太丑了,全换成字母。

\(x_i = x_{1, i}, y_i = x_{2, i}\)

然后设 \(a_{i, X}\) 表示从 \((1, i)\)\(X\) 的最短路径,\(b_{i, X}\) 表示从 \((2, i)\)\(X\) 的最短路径。

有转移 \(a_{i, X} = \min\{a_{i + 1, X} + x_i, b_{i + 1} + x_i + y_i\}\), \(b_{i, X} = \min\{b_{i + 1, X} + y_i, a_{i + 1} + x_i + y_i\}\)

考虑维护 \((a_{i, X}, b_{i, X})\) 的集合,每次将 \(i\) 转移至 \(i + 1\),这样我们如果能维护出这个集合与元素和,那么总和就容易求了。

\(S_i\)\(i\) 的集合,那么有转移:

\[S_i = \{(\min\{a + x_i, b + x_i + y_i\}, \min\{a + x_i + y_i, b + y_i\}) \mid (a, b) \in S_{i - 1}\} \cup \{(x_i, x_i + y_i), (x_i + y_i, y_i)\} \]

更进一步的,我们将 \(\min\) 拆开,那么前面的式子就是:

\[\begin{cases} (a + x_i, a + x_i + y_i)&(a-b < -x_i)\\ (a + x_i, b + y_i) &(-x_i \le a - b \le y_i)\\ (b + x_i + y_i, b + y_i) & (a - b > y_i) \end{cases} \]

注意到我们只关心两者的和,即 \((a + b)\),而取值范围都与 \((a-b)\) 有关,那么我们不妨将维护的 \((a, b)\) 变成维护 \((a - b, a + b)\),那么就能变换成:

\[\begin{cases} (-y_i, b + a + 2x_i + y_i)&(a < -x_i)\\ (a + x_i - y_i, b + x_i + y_i) &(-x_i \le a \le y_i)\\ (x_i, b - a + x_i + 2y_i) & (a > y_i) \end{cases} \]

发现 \(b\) 的差值只与 \(a\) 有关,那么我们可以不维护 \(b\),只维护 \(a\) 与对应的个数。而这个式子是很好维护的,可以维护一个全局加的 tag,那么每次对于超出了 \([-x_i, y_i]\)\(a\),我们就可以将其推平为对应值。而发现新加入的两个元素变成了 \((-y_i, 2x_i + y_i), (x_i, x_i + 2y_i)\),正好对应集合中的最小值和最大值,那么我们直接拿个双端队列维护这个集合即可。

复杂度 \(O(n)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 200005, P = 998244353;
int n;
int x[MAXN], y[MAXN];
int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &x[i]);
    }
    for (int i = 1; i <= n; i++) {
        scanf("%d", &y[i]);
    }
    long long ans = 0;
    long long sum = 0;
    long long add = 0;
    deque<pair<long long, long long>> q;
    for (int i = 1; i <= n; i++) {
        int x = ::x[i], y = ::y[i];
        sum = (sum + 2ll * (x + y) * (i - 1)) % P;
        int fcnt = 1, bcnt = 1;
        while (!q.empty()) {
            long long a = q.front().first + add, cnt = q.front().second;
            if (a >= -x) break;
            q.pop_front();
            sum = (sum + 1ll * (a + x) * cnt) % P;
            fcnt += cnt;
        }
        while (!q.empty()) {
            long long a = q.back().first + add, cnt = q.back().second;
            if (a <= y) break;
            q.pop_back();
            sum = (sum + 1ll * (-a + y) * cnt) % P;
            bcnt += cnt;
        }
        sum += 3ll * (x + y) % P;
        add += x - y;
        q.push_front({-y - add, fcnt});
        q.push_back({x - add, bcnt});
        ans = (ans + sum) % P;
    }
    ans *= 2;
    for (int i = 1; i <= n; i++) ans -= 3ll * (x[i] + y[i]);
    ans = (ans % P + P) % P;
    printf("%lld\n", ans);
    return 0;
}

upd. 被 \(\color{#505050}{\text{L}}\color{red}{\text{YinMX}}\) D 了,说我不会题解做法就不要写题解做法。

我还是太菜了,我确实不会题解做法,但是我赛时上确实没有想到别的做法。

但是有一个很显然的分治做法,我确实是废物没有想到。

考虑所有的路径一定不会出现弯路,删掉弯路一定更优。那么一个路径就可以从中点处拆开。考虑分治,每次 DP 出左边每个点到 \((mid, 1), (mid, 2)\) 的最短路与右面每个点到 \((mid + 1, 1), (mid + 1, 2)\) 的距离。把它们写作 \((a, b), (c, d)\) 的形式。

那么我们要求的就是 \(\sum \min\{a + b, c + d\}\)

改写一下条件,\(a + b < c + d \Leftrightarrow a - c < d - b\)

那么这实际上是一个二维偏序的形式,直接拿树状数组维护即可。复杂度 \(O(n \log^2 n)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 400005, P = 998244353;
int n;
int x[MAXN], y[MAXN];
int ans;
long long a[MAXN][2], b[MAXN][2];
struct BinaryIndexTree {
#define lowbit(x) (x & (-x))
    int a[MAXN];
    void add(int d, int v) {
        while (d <= 2 * n) {
            a[d] = (a[d] + v) % P;
            d += lowbit(d);
        }
    }
    int query(int d) {
        if (!d) return 0;
        int ans = 0;
        while (d) {
            ans = (ans + a[d]) % P;
            d -= lowbit(d);
        }
        return ans;
    }
} bit1, bit2, bit3, bit4;
struct Data {
    long long a, b, key;
    bool side;
    Data(long long a, long long b, bool side) : a(a), b(b), side(side), key(side ? b - a : a - b) {}
    bool operator<(const Data &b) {
        return key == b.key ? side > b.side : key < b.key;
    }
};
void solve(int l, int r) {
    if (l == r) {
        ans = (ans + 3ll * (x[l] + y[l])) % P;
        return;
    }
    int mid = (l + r) >> 1;
    solve(l, mid), solve(mid + 1, r);
    vector<Data> q;
    a[mid][0] = x[mid], a[mid][1] = x[mid] + y[mid];
    b[mid][0] = x[mid] + y[mid], b[mid][1] = y[mid];
    for (int i = mid - 1; i >= l; i--) {
        a[i][0] = min(a[i + 1][0] + x[i], a[i + 1][1] + x[i] + y[i]);
        a[i][1] = min(a[i + 1][1] + y[i], a[i + 1][0] + x[i] + y[i]);
        b[i][0] = min(b[i + 1][0] + x[i], b[i + 1][1] + x[i] + y[i]);
        b[i][1] = min(b[i + 1][1] + y[i], b[i + 1][0] + x[i] + y[i]);
    }
    for (int i = l; i <= mid; i++) {
        q.push_back({a[i][0], b[i][0], 0});
        q.push_back({a[i][1], b[i][1], 0});
    }
    a[mid + 1][0] = x[mid + 1], a[mid + 1][1] = x[mid + 1] + y[mid + 1];
    b[mid + 1][0] = x[mid + 1] + y[mid + 1], b[mid + 1][1] = y[mid + 1];
    for (int i = mid + 2; i <= r; i++) {
        a[i][0] = min(a[i - 1][0] + x[i], a[i - 1][1] + x[i] + y[i]);
        a[i][1] = min(a[i - 1][1] + y[i], a[i - 1][0] + x[i] + y[i]);
        b[i][0] = min(b[i - 1][0] + x[i], b[i - 1][1] + x[i] + y[i]);
        b[i][1] = min(b[i - 1][1] + y[i], b[i - 1][0] + x[i] + y[i]);
    }
    for (int i = mid + 1; i <= r; i++) {
        q.push_back({a[i][0], b[i][0], 1});
        q.push_back({a[i][1], b[i][1], 1});
    }
    sort(q.begin(), q.end());
    vector<int> vals;
    for (auto p : q) vals.push_back(p.key);
    vals.erase(unique(vals.begin(), vals.end()), vals.end());
    for (auto &p : q) p.key = lower_bound(vals.begin(), vals.end(), p.key) - vals.begin() + 1;
    for (auto p : q) {
        if (!p.side) {
            ans = (ans + 2ll * bit3.query(p.key)) % P;
            ans = (ans + 2ll * p.b % P * bit4.query(p.key)) % P;
            bit1.add(p.key, p.a % P);
            bit2.add(p.key, 1);
        } else {
            ans = (ans + 2ll * bit1.query(p.key - 1)) % P;
            ans = (ans + 2ll * p.a % P * bit2.query(p.key - 1)) % P;
            bit3.add(p.key, p.b % P);
            bit4.add(p.key, 1);
        }
    }
    for (auto p : q) {
        if (!p.side) {
            bit1.add(p.key, -p.a % P);
            bit2.add(p.key, -1);
        } else {
            bit3.add(p.key, -p.b % P);
            bit4.add(p.key, -1);
        }
    }
}
int main() {
    // freopen("03_rand_1_01.txt", "r", stdin);
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &x[i]);
    }
    for (int i = 1; i <= n; i++) {
        scanf("%d", &y[i]);
    }
    solve(1, n);
    printf("%d\n", ans);
    return 0;
}
posted @ 2023-03-13 11:17  APJifengc  阅读(105)  评论(5编辑  收藏  举报