[cf 1140] G. Double Tree
题意
给两棵同构的树,将同构节点之间连边,构成一张图。所有边有边权。给出一些询问,求某两点间的最短路。
题解
设两棵树分别为\(T, T'\),同构节点们为\(x, x'\)。
注意到每个询问的答案一定是从\(u\)在某一棵树上走,然后走到另一棵树的同构节点,再在另外一棵树上走,这样的过程重复个若干次。
在一棵树上走一定是走简单路径;走到同构节点并非就是走那条直接相连的边,而是最短路。
先考虑最短路这么求?我们要求\(n\)对同构节点的最短路。
可以等价转化:
1.\(\forall_{e(x, x', w_x)}adde(0, x, w_x)\)
2.\(\forall_{e(x, y, w_1, w_2)} adde(x, y, w_1 + w_2)\)
然后会发现这是对的……很神奇,就可以直接一遍sssp就好啦。
然后可以直接把最短路当做边权了。
那如何处理一整个问题?
记录\(dp_{x, y, u, v}\)代表从树\(u\)的节点\(x\)向上跳\(2 ^ y\)步且最终到达第v棵树上的最短路。其中\(u, v\)取值都是\(\{0, 1\}\)。发现可以把\(dp_{x, y}\)看成一个\(2 * 2\)的矩阵。
为了方便,还要记录\(pd\)数组代表的是从上向下的最短路矩阵。
最后询问的时候倍增跳一跳,矩阵重定义运算一下,然后按顺序合并即可,注意合并的顺序。
复杂度\(O((n + Q) \log n)\)。
#include <bits/stdc++.h>
#define int long long
using namespace std;
typedef long long ll;
inline int read () {
static int x;
scanf("%lld", &x);
return x;
}
inline int readl () {
static ll x;
scanf("%lld", &x);
return x;
}
const int N = 3e5 + 10, M = 2e6 + 10, H = 19;
int n; ll D[N];
struct Graph {
int n, tot;
int lnk[N], nxt[M], son[M]; ll w[M];
void init (int _n) {
n = _n, tot = 1;
memset(lnk, 0, sizeof lnk);
}
void add (int x, int y, ll z) {
nxt[++tot] = lnk[x], lnk[x] = tot, son[tot] = y, w[tot] = z;
}
void adde (int x, int y, ll z) {
add(x, y, z), add(y, x, z);
}
void sssp () {
static ll dis[N];
static bool vis[N];
static queue <int> q;
memset(dis, 60, sizeof dis), dis[0] = 0;
memset(vis, 0, sizeof vis), vis[0] = 1;
for ( ; !q.empty(); q.pop()); q.push(0);
for ( ; !q.empty(); q.pop()) {
int x = q.front();
vis[x] = 0;
for (int j = lnk[x]; j; j = nxt[j])
if (dis[son[j]] > dis[x] + w[j]) {
dis[son[j]] = dis[x] + w[j];
if (!vis[son[j]]) vis[son[j]] = 1, q.push(son[j]);
}
}
for (int i = 1; i <= n; ++i) D[i] = dis[i];
}
} G;
struct Matrix {
ll a[2][2];
Matrix operator * (const Matrix &o) {
return {min(a[0][0] + o.a[0][0], a[0][1] + o.a[1][0]),
min(a[0][0] + o.a[0][1], a[0][1] + o.a[1][1]),
min(a[1][0] + o.a[0][0], a[1][1] + o.a[1][0]),
min(a[1][0] + o.a[0][1], a[1][1] + o.a[1][1])};
}
};
struct Tree {
int n, tot;
int lnk[N], nxt[N << 1], son[N << 1]; ll w1[N << 1], w2[N << 1];
int dep[N], fa[N][H + 1]; Matrix dp[N][H + 1], pd[N][H + 1];
void init (int _n) {
n = _n, tot = 1, dep[0] = 0;
memset(lnk, 0, sizeof lnk);
}
void add (int x, int y, ll z1, ll z2) {
nxt[++tot] = lnk[x], lnk[x] = tot, son[tot] = y, w1[tot] = z1, w2[tot] = z2;
}
void adde (int x, int y, ll z1, ll z2) {
add(x, y, z1, z2), add(y, x, z1, z2);
}
void dfs (int x, int p) {
fa[x][0] = p, dep[x] = dep[p] + 1;
for (int j = lnk[x]; j; j = nxt[j]) if (son[j] != p) {
dfs(son[j], x);
dp[son[j]][0] = {w1[j], min(w1[j] + D[x], D[son[j]] + w2[j]),
min(w1[j] + D[son[j]], D[x] + w2[j]), w2[j]};
pd[son[j]][0] = {w1[j], min(w1[j] + D[son[j]], D[x] + w2[j]),
min(w1[j] + D[x], D[son[j]] + w2[j]), w2[j]};
}
}
void build () {
for (int j = 1; j <= H; ++j)
for (int i = 1; i <= n; ++i) {
fa[i][j] = fa[fa[i][j - 1]][j - 1];
dp[i][j] = dp[i][j - 1] * dp[fa[i][j - 1]][j - 1];
pd[i][j] = pd[fa[i][j - 1]][j - 1] * pd[i][j - 1];
}
}
int lca (int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
int dif = dep[x] - dep[y];
for (int j = H; ~j; --j)
if (dif >> j & 1) x = fa[x][j];
if (x == y) return x;
for (int j = H; ~j; --j)
if (fa[x][j] != fa[y][j]) x = fa[x][j], y = fa[y][j];
return fa[x][0];
}
ll query (int x, int y, int u, int v) {
static Matrix ans1, ans2, ans;
ans1 = {0, D[x], D[x], 0}, ans2 = {0, D[y], D[y], 0};
int z = lca(x, y), dif;
dif = dep[x] - dep[z];
for (int j = H; ~j; --j) if (dif >> j & 1)
ans1 = ans1 * dp[x][j], x = fa[x][j];
dif = dep[y] - dep[z];
for (int j = H; ~j; --j) if (dif >> j & 1)
ans2 = pd[y][j] * ans2, y = fa[y][j];
ans = ans1 * ans2;
return ans.a[u][v];
}
} T;
signed main () {
n = read(), G.init(n), T.init(n);
for (int i = 1; i <= n; ++i) G.adde(0, i, readl());
for (int i = 1; i < n; ++i) {
int x = read(), y = read();
ll z1 = readl(), z2 = readl();
G.adde(x, y, z1 + z2), T.adde(x, y, z1, z2);
}
G.sssp();
T.dfs(1, 0);
T.build();
for (int _ = read(), x, y; _; --_) {
x = read() + 1, y = read() + 1;
printf("%lld\n", T.query(x >> 1, y >> 1, x & 1, y & 1));
}
return 0;
}
不知为何全搞成long long才过。