「解题报告」ARC158E All Pair Shortest Paths
LA 群说是大原题,很简单,我不会啊。
考虑 DP。\(0/1\) 太丑了,全换成字母。
设 \(x_i = x_{1, i}, y_i = x_{2, i}\)。
然后设 \(a_{i, X}\) 表示从 \((1, i)\) 到 \(X\) 的最短路径,\(b_{i, X}\) 表示从 \((2, i)\) 到 \(X\) 的最短路径。
有转移 \(a_{i, X} = \min\{a_{i + 1, X} + x_i, b_{i + 1} + x_i + y_i\}\), \(b_{i, X} = \min\{b_{i + 1, X} + y_i, a_{i + 1} + x_i + y_i\}\)。
考虑维护 \((a_{i, X}, b_{i, X})\) 的集合,每次将 \(i\) 转移至 \(i + 1\),这样我们如果能维护出这个集合与元素和,那么总和就容易求了。
设 \(S_i\) 为 \(i\) 的集合,那么有转移:
更进一步的,我们将 \(\min\) 拆开,那么前面的式子就是:
注意到我们只关心两者的和,即 \((a + b)\),而取值范围都与 \((a-b)\) 有关,那么我们不妨将维护的 \((a, b)\) 变成维护 \((a - b, a + b)\),那么就能变换成:
发现 \(b\) 的差值只与 \(a\) 有关,那么我们可以不维护 \(b\),只维护 \(a\) 与对应的个数。而这个式子是很好维护的,可以维护一个全局加的 tag,那么每次对于超出了 \([-x_i, y_i]\) 的 \(a\),我们就可以将其推平为对应值。而发现新加入的两个元素变成了 \((-y_i, 2x_i + y_i), (x_i, x_i + 2y_i)\),正好对应集合中的最小值和最大值,那么我们直接拿个双端队列维护这个集合即可。
复杂度 \(O(n)\)。
点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 200005, P = 998244353;
int n;
int x[MAXN], y[MAXN];
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", &x[i]);
}
for (int i = 1; i <= n; i++) {
scanf("%d", &y[i]);
}
long long ans = 0;
long long sum = 0;
long long add = 0;
deque<pair<long long, long long>> q;
for (int i = 1; i <= n; i++) {
int x = ::x[i], y = ::y[i];
sum = (sum + 2ll * (x + y) * (i - 1)) % P;
int fcnt = 1, bcnt = 1;
while (!q.empty()) {
long long a = q.front().first + add, cnt = q.front().second;
if (a >= -x) break;
q.pop_front();
sum = (sum + 1ll * (a + x) * cnt) % P;
fcnt += cnt;
}
while (!q.empty()) {
long long a = q.back().first + add, cnt = q.back().second;
if (a <= y) break;
q.pop_back();
sum = (sum + 1ll * (-a + y) * cnt) % P;
bcnt += cnt;
}
sum += 3ll * (x + y) % P;
add += x - y;
q.push_front({-y - add, fcnt});
q.push_back({x - add, bcnt});
ans = (ans + sum) % P;
}
ans *= 2;
for (int i = 1; i <= n; i++) ans -= 3ll * (x[i] + y[i]);
ans = (ans % P + P) % P;
printf("%lld\n", ans);
return 0;
}
upd. 被 \(\color{#505050}{\text{L}}\color{red}{\text{YinMX}}\) D 了,说我不会题解做法就不要写题解做法。
我还是太菜了,我确实不会题解做法,但是我赛时上确实没有想到别的做法。
但是有一个很显然的分治做法,我确实是废物没有想到。
考虑所有的路径一定不会出现弯路,删掉弯路一定更优。那么一个路径就可以从中点处拆开。考虑分治,每次 DP 出左边每个点到 \((mid, 1), (mid, 2)\) 的最短路与右面每个点到 \((mid + 1, 1), (mid + 1, 2)\) 的距离。把它们写作 \((a, b), (c, d)\) 的形式。
那么我们要求的就是 \(\sum \min\{a + b, c + d\}\)。
改写一下条件,\(a + b < c + d \Leftrightarrow a - c < d - b\)。
那么这实际上是一个二维偏序的形式,直接拿树状数组维护即可。复杂度 \(O(n \log^2 n)\)。
点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 400005, P = 998244353;
int n;
int x[MAXN], y[MAXN];
int ans;
long long a[MAXN][2], b[MAXN][2];
struct BinaryIndexTree {
#define lowbit(x) (x & (-x))
int a[MAXN];
void add(int d, int v) {
while (d <= 2 * n) {
a[d] = (a[d] + v) % P;
d += lowbit(d);
}
}
int query(int d) {
if (!d) return 0;
int ans = 0;
while (d) {
ans = (ans + a[d]) % P;
d -= lowbit(d);
}
return ans;
}
} bit1, bit2, bit3, bit4;
struct Data {
long long a, b, key;
bool side;
Data(long long a, long long b, bool side) : a(a), b(b), side(side), key(side ? b - a : a - b) {}
bool operator<(const Data &b) {
return key == b.key ? side > b.side : key < b.key;
}
};
void solve(int l, int r) {
if (l == r) {
ans = (ans + 3ll * (x[l] + y[l])) % P;
return;
}
int mid = (l + r) >> 1;
solve(l, mid), solve(mid + 1, r);
vector<Data> q;
a[mid][0] = x[mid], a[mid][1] = x[mid] + y[mid];
b[mid][0] = x[mid] + y[mid], b[mid][1] = y[mid];
for (int i = mid - 1; i >= l; i--) {
a[i][0] = min(a[i + 1][0] + x[i], a[i + 1][1] + x[i] + y[i]);
a[i][1] = min(a[i + 1][1] + y[i], a[i + 1][0] + x[i] + y[i]);
b[i][0] = min(b[i + 1][0] + x[i], b[i + 1][1] + x[i] + y[i]);
b[i][1] = min(b[i + 1][1] + y[i], b[i + 1][0] + x[i] + y[i]);
}
for (int i = l; i <= mid; i++) {
q.push_back({a[i][0], b[i][0], 0});
q.push_back({a[i][1], b[i][1], 0});
}
a[mid + 1][0] = x[mid + 1], a[mid + 1][1] = x[mid + 1] + y[mid + 1];
b[mid + 1][0] = x[mid + 1] + y[mid + 1], b[mid + 1][1] = y[mid + 1];
for (int i = mid + 2; i <= r; i++) {
a[i][0] = min(a[i - 1][0] + x[i], a[i - 1][1] + x[i] + y[i]);
a[i][1] = min(a[i - 1][1] + y[i], a[i - 1][0] + x[i] + y[i]);
b[i][0] = min(b[i - 1][0] + x[i], b[i - 1][1] + x[i] + y[i]);
b[i][1] = min(b[i - 1][1] + y[i], b[i - 1][0] + x[i] + y[i]);
}
for (int i = mid + 1; i <= r; i++) {
q.push_back({a[i][0], b[i][0], 1});
q.push_back({a[i][1], b[i][1], 1});
}
sort(q.begin(), q.end());
vector<int> vals;
for (auto p : q) vals.push_back(p.key);
vals.erase(unique(vals.begin(), vals.end()), vals.end());
for (auto &p : q) p.key = lower_bound(vals.begin(), vals.end(), p.key) - vals.begin() + 1;
for (auto p : q) {
if (!p.side) {
ans = (ans + 2ll * bit3.query(p.key)) % P;
ans = (ans + 2ll * p.b % P * bit4.query(p.key)) % P;
bit1.add(p.key, p.a % P);
bit2.add(p.key, 1);
} else {
ans = (ans + 2ll * bit1.query(p.key - 1)) % P;
ans = (ans + 2ll * p.a % P * bit2.query(p.key - 1)) % P;
bit3.add(p.key, p.b % P);
bit4.add(p.key, 1);
}
}
for (auto p : q) {
if (!p.side) {
bit1.add(p.key, -p.a % P);
bit2.add(p.key, -1);
} else {
bit3.add(p.key, -p.b % P);
bit4.add(p.key, -1);
}
}
}
int main() {
// freopen("03_rand_1_01.txt", "r", stdin);
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", &x[i]);
}
for (int i = 1; i <= n; i++) {
scanf("%d", &y[i]);
}
solve(1, n);
printf("%d\n", ans);
return 0;
}