P8820 [CSP-S 2022] 数据传输 题解
思路
考完发现比较模板的一道题目。
由于需要树上不断的进行 \(dp\)。
考虑使用倍增和广义矩乘。
首先发现需要求出的值为最小值,所以可以考虑使用 \(min,+\) 矩乘。
发现这一道题比较麻烦的一点是 \(k=3\) 时,它可以跳到一个不属于从 \(s\) 到 \(t\) 的简单路径上的一个点。
但对于这一点也同样可以直接在矩阵上面修改一下即可。
我们需要存一个 \(up_i\) 和 \(down_i\) 数组,表示所需的倍增转移矩阵。
一个是朝上跳,一个来朝下跳。
考虑如何来设计这样一个矩阵。
\[\begin{vmatrix}
val_i&0&inf\\
val_i&min_i&0\\
val_i&inf&inf
\end{vmatrix}
\]
比较轻松的发现,这样一个矩阵就满足了我们的转移需求。
之后求直接套倍增板子就可以了。
Code
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 200005;
int n, q, k, dep[N], a[N], minn[N], fa[N][18];
int cnt, head[N];
inline int read()
{
int asd = 0 , qwe = 1; char zxc;
while(!isdigit(zxc = getchar())) if(zxc == '-') qwe = -1;
while(isdigit(zxc)) asd = asd * 10 + zxc - '0' , zxc = getchar();
return asd * qwe;
}
struct edge
{
int to, nxt, val;
} e[N * 2];
struct Mat
{
int a[3][3];
Mat() { memset(a, 0x3f, sizeof(a)); }
inline Mat operator*(const Mat &y)
{
Mat sum;
for (int i = 0; i <= 2; i++)
for (int j = 0; j <= 2; j++)
for (int k = 0; k <= 2; k++)
sum.a[i][j] = min(sum.a[i][j], a[i][k] + y.a[k][j]);
return sum;
}
} dp[N], up[N][18], down[N][18];
inline void add(int x, int y, int z = 1)
{
e[++cnt] = {y, head[x], z}, head[x] = cnt;
e[++cnt] = {x, head[y], z}, head[y] = cnt;
}
inline void dfs(int now, int ff)
{
dep[now] = dep[ff] + 1 , fa[now][0] = ff;
up[now][0] = down[now][0] = dp[fa[now][0]];
for (int i = 1; i <= 17; i++)
{
fa[now][i] = fa[fa[now][i - 1]][i - 1];
up[now][i] = up[now][i - 1] * up[fa[now][i - 1]][i - 1];
down[now][i] = down[fa[now][i - 1]][i - 1] * down[now][i - 1];
}
for(int i = head[now];i;i = e[i].nxt)
if (e[i].to != ff) dfs(e[i].to, now);
}
inline int ask(int x, int y)
{
if (dep[x] < dep[y])
swap(x, y);
Mat u, d = dp[y];
u.a[0][0] = a[x];
for (int i = 17; i >= 0; i--)
if (dep[fa[x][i]] >= dep[y])
u = u * up[x][i], x = fa[x][i];
if (x == y) return u.a[0][0];
for (int i = 17; i >= 0; i--)
if (fa[x][i] != fa[y][i])
u = u * up[x][i], d = down[y][i] * d,
x = fa[x][i], y = fa[y][i];
return (u * dp[fa[x][0]] * d).a[0][0];
}
signed main()
{
memset(minn, 0x3f, sizeof(minn));
n = read() , q = read() , k = read();
for (int i = 1; i <= n; i++)
a[i] = read();
for (int i = 1; i < n; i++)
{
int x = read() , y = read();
minn[x] = min(minn[x], a[y]);
minn[y] = min(minn[y], a[x]);
add(x, y);
}
for (int i = 1; i <= n; i++)
{
dp[i].a[0][0] = a[i];
if (k != 1) dp[i].a[1][0] = a[i], dp[i].a[0][1] = 0;
if (k == 3) dp[i].a[2][0] = a[i], dp[i].a[1][2] = 0, dp[i].a[1][1] = minn[i];
}
dfs(1, 0);
for (int i = 1, x, y; i <= q; i++)
{
x = read() , y = read();
printf("%lld\n", ask(x, y));
}
return 0;
}