CSP-S2022 题解
A. 假期计划(holiday)
首先我们可以跑 \(n\) 次 bfs,预处理出哪些景点之间可以转 \(k\) 次车到达。
然后我们设路径是 \(\text{Home} \to A \to B \to C \to D \to \text{Home}\)。我们对于每个点 \(X\),预处理 \(\text{Home} \to Y \to X\) 路径中 \(a_X + a_Y\) 前三大的值,以及此时对应的 \(Y\)。然后我们枚举 \(A, B\),然后暴力 \(3 \times 3\) 枚举 \(A\) 和 \(B\) 的前三大。可以证明答案一定在前三大内。
考场代码
#include <bits/stdc++.h>
typedef long long LL;
typedef __int128_t Lint;
const int N = 2500 + 5;
const int M = 1e4 + 5;
const int INF = 0x3f3f3f3f;
int n, m, K;
LL a[N];
struct Edge { int to, nxt; } edge[M << 1];
int head[N];
void add_edge(int u, int v) { static int k = 1; edge[k] = (Edge){v, head[u]}, head[u] = k++; }
int dis[N][N];
bool vis[N];
std::queue<int> q;
void bfs(int st) {
while(!q.empty()) q.pop();
for(int i = 1; i <= n; i++) vis[i] = false, dis[st][i] = INF;
dis[st][st] = 0, vis[st] = true, q.push(st);
while(!q.empty()) {
int u = q.front();
q.pop();
for(int i = head[u]; i; i = edge[i].nxt) {
int v = edge[i].to;
if(vis[v]) continue;
dis[st][v] = dis[st][u] + 1, vis[v] = true, q.push(v);
}
}
}
bool e[N][N];
int r[N][3];
void print(Lint x) {
char stk[200]; int top = 0;
if(x == 0) stk[top++] = '0';
while(x) stk[top++] = x % 10 + '0', x /= 10;
for(top--; top >= 0; top--) putchar(stk[top]);
}
int main() {
freopen("holiday.in", "r", stdin);
freopen("holiday.out", "w", stdout);
scanf("%d%d%d", &n, &m, &K);
for(int i = 2; i <= n; i++) scanf("%lld", &a[i]);
for(int i = 1; i <= m; i++) { int u, v; scanf("%d%d", &u, &v); add_edge(u, v), add_edge(v, u); }
for(int i = 1; i <= n; i++) bfs(i);
for(int i = 1; i <= n; i++) for(int j = 1; j <= n; j++) e[i][j] = (i == j ? false : dis[i][j] <= K + 1);
for(int i = 1; i <= n; i++) {
for(int j = 1; j <= n; j++) if(e[1][j] && e[j][i] && (r[i][0] == 0 || a[r[i][0]] < a[j])) r[i][0] = j;
for(int j = 1; j <= n; j++) if(e[1][j] && e[j][i] && j != r[i][0] && (r[i][1] == 0 || a[r[i][1]] < a[j])) r[i][1] = j;
for(int j = 1; j <= n; j++) if(e[1][j] && e[j][i] && j != r[i][0] && j != r[i][1] && (r[i][2] == 0 || a[r[i][2]] < a[j])) r[i][2] = j;
}
Lint ans = 0;
for(int i = 1; i <= n; i++) for(int j = 1; j <= n; j++) if(e[i][j]) {
int x1 = r[i][0], x2 = r[i][1], y1 = r[j][0], y2 = r[j][1];
if(x1 == j) x1 = r[i][1], x2 = r[i][2];
else if(x2 == j) x2 = r[i][2];
if(y1 == i) y1 = r[j][1], y2 = r[j][2];
else if(y2 == i) y2 = r[j][2];
if(x1 != y1) {
if(x1 && y1) ans = std::max(ans, (Lint)a[i] + a[j] + a[x1] + a[y1]);
} else {
if(x1 && y2) ans = std::max(ans, (Lint)a[i] + a[j] + a[x1] + a[y2]);
if(x2 && y1) ans = std::max(ans, (Lint)a[i] + a[j] + a[x2] + a[y1]);
}
print(ans), printf("\n");
return 0;
}
B. 策略游戏(game)
考虑如果小 L 已经选了一个正数,那么小 Q 一定会选一个最小值;反之则小 Q 一定会选一个最大值。
所以如果小 L 已经决定要选一个正数,那么他相当于已经知道小 Q 会选哪个数了。如果小 Q 将要选的是一个正数,那么小 L 一定会选择最大的正数,否则他会选最小的正数。反过来,如果小 L 已经决定要选一个负数,也是一样的道理。由于最开始的选择权在小 L,所以他一定是在正数和负数的解中取最大的一个。
至于 \(0\),我们现在一直都没考虑过它,但其实我们可以随便将他归到正数或者负数。
然后就做完了,好水的 T2 啊。
考场代码
#include <bits/stdc++.h>
typedef long long LL;
const int N = 1e5 + 5;
const LL LLINF = 0x3f3f3f3f3f3f3f3fLL;
const int NONE = 1e9 + 1;
int n, m, Q;
int a[N], b[N];
struct ST {
int lg[N], mn[N][21], mx[N][21];
int calcmin(int x, int y) { return x == NONE || y == NONE ? (x == NONE ? y : x) : std::min(x, y); }
int calcmax(int x, int y) { return x == NONE || y == NONE ? (x == NONE ? y : x) : std::max(x, y); }
void preprocess(int *arr, int bd, int type) {
lg[0] = -1;
for(int i = 1; i <= bd; i++) lg[i] = lg[i >> 1] + 1;
for(int i = 1; i <= bd; i++) mn[i][0] = mx[i][0] = (type == 2 ? arr[i] : (type == 1 ? (arr[i] >= 0 ? arr[i] : NONE) : (arr[i] <= 0 ? arr[i] : NONE)));
for(int j = 1; j <= 20; j++)
for(int i = 1; i + (1 << j) - 1 <= bd; i++) {
mn[i][j] = calcmin(mn[i][j - 1], mn[i + (1 << (j - 1))][j - 1]);
mx[i][j] = calcmax(mx[i][j - 1], mx[i + (1 << (j - 1))][j - 1]);
}
}
int getmin(int l, int r) {
int k = lg[r - l + 1];
return calcmin(mn[l][k], mn[r - (1 << k) + 1][k]);
}
int getmax(int l, int r) {
int k = lg[r - l + 1];
return calcmax(mx[l][k], mx[r - (1 << k) + 1][k]);
}
} sta[2], stb;
int main() {
freopen("game.in", "r", stdin);
freopen("game.out", "w", stdout);
scanf("%d%d%d", &n, &m, &Q);
for(int i = 1; i <= n; i++) scanf("%d", &a[i]);
for(int i = 1; i <= m; i++) scanf("%d", &b[i]);
sta[0].preprocess(a, n, 0), sta[1].preprocess(a, n, 1), stb.preprocess(b, m, 2);
while(Q--) {
int l1, r1, l2, r2;
scanf("%d%d%d%d", &l1, &r1, &l2, &r2);
int mn2 = stb.getmin(l2, r2), mx2 = stb.getmax(l2, r2);
LL ans = -LLINF;
if(mn2 >= 0 && sta[1].getmax(l1, r1) != NONE) ans = std::max(ans, (LL)sta[1].getmax(l1, r1) * mn2);
if(mn2 <= 0 && sta[1].getmin(l1, r1) != NONE) ans = std::max(ans, (LL)sta[1].getmin(l1, r1) * mn2);
if(mx2 >= 0 && sta[0].getmax(l1, r1) != NONE) ans = std::max(ans, (LL)sta[0].getmax(l1, r1) * mx2);
if(mx2 <= 0 && sta[0].getmin(l1, r1) != NONE) ans = std::max(ans, (LL)sta[0].getmin(l1, r1) * mx2);
printf("%lld\n", ans);
}
return 0;
}
C. 星战(galaxy)
首先发现一个结论:输出 YES
当且仅当每个点都有恰好一条出边。
证明也很显然。充分性:对于每个点,我们考虑一直沿着唯一的出边走,那么最坏情况下走 \(n - 1\) 条边就能走完 \(n\) 个点(否则一定有重复的点),所以走 \(n\) 条边一定能走完所有点,并且会有重复,即出现了环。必要性:根据条件,边数等于 \(n\),如果有点有多条出边,那么一定会出现有的点没有出边,不符合题意。
然后我们显然可以维护边数,但是怎么维护每个点只有一条出边这个信息呢?根号算法应该是过不了的,难道要用一些 nb 的数据结构?其实不用,我们考虑直接哈希。显然我们只需要有每 \(n\) 条边,而且每个点各一条出边,至于是哪 \(n\) 条,我们不关心。所以哈希应该和边无关,我们可以考虑只哈希每条边的起点。然后,这 \(n\) 个点的顺序其实也是无关紧要的。这就让我们想到了 XOR Hashing 或者 Sum Hashing。考虑对于每次操作维护边的起点的 XOR/SUM,但是这样很容易哈希冲突。所以我们给每个点随机赋一个很大的权值,就可以有效避免哈希冲突。这个很好维护。
挺巧妙的一道题。
赛后代码
#include <bits/stdc++.h>
typedef long long LL;
const int N = 5e5 + 5;
int n, m, Q;
int ind[N], d[N];
int cnt = 0;
LL bigrand() { return (LL)rand() << 31 | rand(); } // linux 下 rand() 的范围是 0~2^31-1
LL a[N], b[N], c[N];
int main() {
freopen("galaxy.in", "r", stdin);
freopen("galaxy.out", "w", stdout);
scanf("%d%d", &n, &m);
cnt = m;
LL val = 0, xor_all = 0;
for(int i = 1; i <= n; i++) a[i] = bigrand(), xor_all ^= a[i];
for(int i = 1; i <= m; i++) { int u, v; scanf("%d%d", &u, &v); ind[v]++, b[v] ^= a[u], val ^= a[u]; }
for(int i = 1; i <= n; i++) d[i] = ind[i], c[i] = b[i];
scanf("%d", &Q);
while(Q--) {
int t, x, y;
scanf("%d", &t);
if(t == 1) scanf("%d%d", &x, &y), cnt--, d[y]--, val ^= a[x], c[y] ^= a[x];
else if(t == 2) scanf("%d", &x), cnt -= d[x], d[x] = 0, val ^= c[x], c[x] = 0;
else if(t == 3) scanf("%d%d", &x, &y), cnt++, d[y]++, val ^= a[x], c[y] ^= a[x];
else if(t == 4) scanf("%d", &x), cnt += ind[x] - d[x], d[x] = ind[x], val ^= b[x] ^ c[x], c[x] = b[x];
puts(cnt == n && val == xor_all ? "YES" : "NO");
}
return 0;
}
D. 数据传输(transmit)
首先我们考虑哪些点能够出现在答案中。对于一个 \(u\) 到 \(v\) 的询问,显然 \(u\) 到 \(v\) 路径上的点是有可能出现在答案中的。然后和路径的距离为 \(1\) 的点也是有可能出现在答案中的,这只在 \(K = 3\) 时出现,样例 2 就有这种情况。那么和路径距离为 \(2\) 或者更大的点有没有可能呢?其实是不可能的,因为如果我们跳到某一个和路径距离为 \(2\) 的点,就必须还要跳回来,这样显然不如直接不经过这个点优秀。
所以我们其实可以直接在 \(u\) 到 \(v\) 的路径上进行 DP:设 \(f(i, j)\) 表示考虑了路径上前 \(i\) 个点,上一台需要计算代价的机器距离当前机器已经 \(j(j < K)\) 根网线了(\(K = 0\) 表示当前机器需要计算代价),传输的最小代价。转移也很显然:
其中 \(i'\) 表示路径上的上一个点,\(son(u)\) 表示 \(u\) 的权值最小的儿子的编号,第二行的转移是跳到了 \(i\) 的一个儿子的情况。
初值就是 \(f(0, j) = \begin{cases}0, &j = K - 1\\+\infty, &0 \le j < K - 1\end{cases}\),答案是 \(f(m, 0)\)。
但是这样是 \(O(n^2)\) 的。怎么办呢?虽然说每个询问的链都不一样,但是每个点的转移都是一样的。这让我们想到了矩阵乘法。所以如果我们可以预处理每个点的转移矩阵,那转移相当于是路径乘积。这个可以用倍增解决,不必用树剖。注意矩阵没有交换律,所以需要预处理从上往下乘和从下往上乘。
还有一种例外,就是走到了 \(u, v\) 的 \(lca\) 的父亲,这个相当于是路径上多了两个点(\(\cdots \times trans_{lca} \times \cdots\) 变成 \(\cdots \times trans_{lca} \times trans_{fa(lca)} \times trans_{lca} \times \cdots\)),也可以解决。
但是怎么把转移写成矩阵呢?我们发现,式子里要求 \(\min\) 和 \(\operatorname{sum}\),但是不需要乘积,这是一个套路,我们可以重新定义矩阵乘法,把 \(\min\limits_{1 \le j' \le K} \{f(i - 1, j')\} + v_i\) 写成 \(\min\limits_{1 \le j' \le K} \{f(i - 1, j') + v_i\}\),然后把原来的“乘”定义为加,原来的“加”定义为求 \(\min\)。换句话说就是 \(C = A \times B \iff C_{i, j} = \min\limits_{k}\{A_{i, k} + B_{k, j}\}\)。
那么我们就可以写出矩阵了:
\(K = 1\) 时:
\(K = 2\) 时:
\(K = 3\) 时:
然后就做完了,感觉比较板,比 T3 简单。
另外,qiuzx 有非矩阵做法并且表示出对矩阵做法的唾弃,大家可以去膜拜围观。
赛后代码
#include <bits/stdc++.h>
typedef long long LL;
const int N = 2e5 + 5;
const LL LLINF = 0x3f3f3f3f3f3f3f3fLL;
int n, K, Q;
LL a[N];
struct Edge { int to, nxt; } edge[N << 1];
int head[N];
void add_edge(int u, int v) { static int k = 1; edge[k] = (Edge){v, head[u]}, head[u] = k++; }
int fa[N], dep[N], son[N];
void dfs(int u) {
dep[u] = dep[fa[u]] + 1;
for(int i = head[u]; i; i = edge[i].nxt) if(edge[i].to != fa[u]) {
int v = edge[i].to;
fa[v] = u;
dfs(v);
if(!son[u] || a[son[u]] > a[v]) son[u] = v;
}
}
struct Matrix { LL a[3][3]; } e, initial;
Matrix mul(Matrix x, Matrix y) {
Matrix z;
for(int i = 0; i < K; i++) for(int j = 0; j < K; j++) z.a[i][j] = LLINF;
for(int i = 0; i < K; i++) for(int j = 0; j < K; j++) for(int k = 0; k < K; k++) z.a[i][k] = std::min(z.a[i][k], x.a[i][j] + y.a[j][k]);
return z;
}
Matrix trans[N];
Matrix prod1[N][21], prod2[N][21];
int go[N][21];
void preprocess() {
for(int i = 1; i <= n; i++) go[i][0] = fa[i], prod1[i][0] = prod2[i][0] = trans[i];
for(int j = 1; j <= 20; j++)
for(int i = 1; i <= n; i++) {
go[i][j] = go[go[i][j - 1]][j - 1];
prod1[i][j] = mul(prod1[i][j - 1], prod1[go[i][j - 1]][j - 1]);
prod2[i][j] = mul(prod2[go[i][j - 1]][j - 1], prod2[i][j - 1]);
}
}
int lca(int u, int v) {
if(dep[u] < dep[v]) std::swap(u, v);
int d = dep[u] - dep[v];
for(int i = 0; i <= 20; i++) if(d >> i & 1) u = go[u][i];
if(u == v) return u;
for(int i = 20; i >= 0; i--) if(go[u][i] != go[v][i]) u = go[u][i], v = go[v][i];
return fa[u];
}
Matrix multiply1(int u, int d) {
Matrix x = e;
for(int i = 0; i <= 20; i++) if(d >> i & 1) x = mul(x, prod1[u][i]), u = go[u][i];
return x;
}
Matrix multiply2(int u, int d) {
Matrix x = e;
for(int i = 0; i <= 20; i++) if(d >> i & 1) x = mul(prod2[u][i], x), u = go[u][i];
return x;
}
int main() {
freopen("transmit.in", "r", stdin);
freopen("transmit.out", "w", stdout);
scanf("%d%d%d", &n, &Q, &K);
for(int i = 1; i <= n; i++) scanf("%lld", &a[i]);
for(int i = 1; i < n; i++) { int u, v; scanf("%d%d", &u, &v); add_edge(u, v), add_edge(v, u); }
dfs(1);
for(int i = 0; i < K; i++) for(int j = 0; j < K; j++) e.a[i][j] = (i == j ? 0 : LLINF);
for(int i = 0; i < K; i++) for(int j = 0; j < K; j++) initial.a[i][j] = LLINF;
initial.a[0][K - 1] = 0;
for(int i = 1; i <= n; i++) {
for(int j = 0; j < K; j++) for(int k = 0; k < K; k++) trans[i].a[j][k] = LLINF;
trans[i].a[0][0] = a[i];
if(K >= 2) trans[i].a[0][1] = 0, trans[i].a[1][0] = a[i];
if(K >= 3) trans[i].a[1][2] = 0, trans[i].a[2][0] = a[i], trans[i].a[1][1] = (son[i] ? a[son[i]] : LLINF);
}
preprocess();
while(Q--) {
int u, v;
scanf("%d%d", &u, &v);
int f = lca(u, v);
int du = dep[u] - dep[f] + 1, dv = dep[v] - dep[f] + 1;
LL ans = mul(initial, mul(multiply1(u, du), multiply2(v, dv - 1))).a[0][0];
if(fa[f]) ans = std::min(ans, mul(initial, mul(multiply1(u, du + 1), multiply2(v, dv))).a[0][0]);
printf("%lld\n", ans);
}
return 0;
}
啊终于填完坑了