Loading

P8820 [CSP-S 2022] 数据传输 题解

思路

考完发现比较模板的一道题目。

由于需要树上不断的进行 \(dp\)

考虑使用倍增和广义矩乘。

首先发现需要求出的值为最小值,所以可以考虑使用 \(min,+\) 矩乘。

发现这一道题比较麻烦的一点是 \(k=3\) 时,它可以跳到一个不属于从 \(s\)\(t\) 的简单路径上的一个点。

但对于这一点也同样可以直接在矩阵上面修改一下即可。

我们需要存一个 \(up_i\)\(down_i\) 数组,表示所需的倍增转移矩阵。

一个是朝上跳,一个来朝下跳。

考虑如何来设计这样一个矩阵。

\[\begin{vmatrix} val_i&0&inf\\ val_i&min_i&0\\ val_i&inf&inf \end{vmatrix} \]

比较轻松的发现,这样一个矩阵就满足了我们的转移需求。

之后求直接套倍增板子就可以了。

Code

#include <bits/stdc++.h>
#define int long long
using namespace std;

const int N = 200005;

int n, q, k, dep[N], a[N], minn[N], fa[N][18];

int cnt, head[N];

inline int read()
{
    int asd = 0 , qwe = 1; char zxc;
    while(!isdigit(zxc = getchar())) if(zxc == '-') qwe = -1;
    while(isdigit(zxc)) asd = asd * 10 + zxc - '0' , zxc = getchar();
    return asd * qwe;
}

struct edge
{
    int to, nxt, val;
} e[N * 2];

struct Mat
{
    int a[3][3];
    Mat() { memset(a, 0x3f, sizeof(a)); }
    inline Mat operator*(const Mat &y)
    {
        Mat sum;
        for (int i = 0; i <= 2; i++)
            for (int j = 0; j <= 2; j++)
                for (int k = 0; k <= 2; k++)
                    sum.a[i][j] = min(sum.a[i][j], a[i][k] + y.a[k][j]);
        return sum;
    }
} dp[N], up[N][18], down[N][18];

inline void add(int x, int y, int z = 1)
{
    e[++cnt] = {y, head[x], z}, head[x] = cnt;
    e[++cnt] = {x, head[y], z}, head[y] = cnt;
}

inline void dfs(int now, int ff)
{
    dep[now] = dep[ff] + 1 , fa[now][0] = ff;
    up[now][0] = down[now][0] = dp[fa[now][0]];
    for (int i = 1; i <= 17; i++)
    {
        fa[now][i] = fa[fa[now][i - 1]][i - 1];
        up[now][i] = up[now][i - 1] * up[fa[now][i - 1]][i - 1];
        down[now][i] = down[fa[now][i - 1]][i - 1] * down[now][i - 1];
    }
    for(int i = head[now];i;i = e[i].nxt)
        if (e[i].to != ff) dfs(e[i].to, now);
}

inline int ask(int x, int y)
{
    if (dep[x] < dep[y])
        swap(x, y);
    Mat u, d = dp[y];
    u.a[0][0] = a[x];
    for (int i = 17; i >= 0; i--)
        if (dep[fa[x][i]] >= dep[y])
            u = u * up[x][i], x = fa[x][i];
    if (x == y) return u.a[0][0];
    for (int i = 17; i >= 0; i--)
        if (fa[x][i] != fa[y][i])
            u = u * up[x][i], d = down[y][i] * d,
            x = fa[x][i], y = fa[y][i];
    return (u * dp[fa[x][0]] * d).a[0][0];
}

signed main()
{
    memset(minn, 0x3f, sizeof(minn));
    n = read() , q = read() , k = read();
    for (int i = 1; i <= n; i++)
        a[i] = read();
    for (int i = 1; i < n; i++)
    {
        int x = read() , y = read();
        minn[x] = min(minn[x], a[y]);
        minn[y] = min(minn[y], a[x]);
        add(x, y);
    }
    for (int i = 1; i <= n; i++)
    {
        dp[i].a[0][0] = a[i];
        if (k != 1) dp[i].a[1][0] = a[i], dp[i].a[0][1] = 0;
        if (k == 3) dp[i].a[2][0] = a[i], dp[i].a[1][2] = 0, dp[i].a[1][1] = minn[i];
    }
    dfs(1, 0);
    for (int i = 1, x, y; i <= q; i++)
    {
        x = read() , y = read();
        printf("%lld\n", ask(x, y));
    }
    return 0;
}
posted @ 2022-11-01 15:42  JiaY19  阅读(333)  评论(0编辑  收藏  举报