「题解」洛谷 P4103 [HEOI2014]大工程
题目
思路
虚树 + dp。
共有 \(n-1\) 条边,\(k\) 个询问点,用 \((u_i,v_i,w_i)\) 表示第 \(i\) 条边是从 \(u_i\) 到 \(v_i\) 花费为 \(w_i\) 的一条边。那么用 \(sz\) 维护每个点的子树中询问点的个数,那么第一问的答案统计每条边的贡献就行了,式子如下:
\[\sum_{i = 1}^{n - 1} (sz_{v_i} \times (k - sz_{v_i}) \times w_i)
\]
第二三问类似,这里以最小值为例,设 \(mn_i\) 表示以 \(i\) 为根的子树中询问点到 \(i\) 的最小距离,\(f[i]\) 表示以 \(i\) 为根的子树中两个询问点之间的最小距离,设 \(i\) 的儿子有 \(a,b,\dots,e\),在更新到了儿子 \(c\) 时,可以用 \(mn[i] + w + mn[c]\) 来更新 \(f[i]\)。
做 \(q\) 次树形 \(dp\) 的复杂度为 \(O(nq)\),是无法承受的复杂度,发现 \(\sum k \le 2 \times n\),因此可以构建虚树来 \(dp\)。
Code
#include <cstdio>
#include <cstring>
#include <string>
#include <iostream>
#include <algorithm>
#define inf 1e9
#define M 1000001
typedef long long ll;
inline int max(int a, int b) { return a > b ? a : b; }
inline int min(int a, int b) { return a < b ? a : b; }
inline void read(int &T) {
int x = 0; bool f = 0; char c = getchar();
while (c < '0' || c > '9') { if (c == '-') f = !f; c = getchar(); }
while (c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar(); }
T = f ? -x : x;
}
ll c, ans;
bool asked[M];
int n, m, k, sc, cnt, s[M], h[M], pthn, minn, maxx, head[M];
int f[M], g[M], sz[M], mn[M], mx[M], lg[M], dfn[M], dep[M], fa[M][21];
struct Edge {
int nxt, to, w;
}pth[M << 1];
void add(int frm, int to, int w) {
pth[++pthn].to = to, pth[pthn].nxt = head[frm];
pth[pthn].w = w, head[frm] = pthn;
}
void dfs(int u, int father) {
dfn[u] = ++cnt, fa[u][0] = father, dep[u] = dep[father] + 1;
for (int i = head[u]; i; i = pth[i].nxt) {
int x = pth[i].to;
if (x != father) dfs(x, u);
}
}
int lca(int x, int y) {
if (dep[x] < dep[y]) std::swap(x, y);
while (dep[x] > dep[y]) {
x = fa[x][lg[dep[x] - dep[y]] - 1];
}
if (x == y) return x;
for (int k_ = lg[dep[x]] - 1; k_ >= 0; --k_) {
if (fa[x][k_] != fa[y][k_]) {
x = fa[x][k_], y = fa[y][k_];
}
}
return fa[x][0];
}
int dis(int x, int y) {
int l = lca(x, y);
return dep[x] + dep[y] - 2 * dep[l];
}
inline bool cmp(int a, int b) { return dfn[a] < dfn[b]; }
void build() {
std::sort(h + 1, h + k + 1, cmp), s[sc = 1] = 1, sz[1] = 0;
pthn = g[1] = mx[1] = head[1] = 0, mn[1] = f[1] = inf;
for (int i = 1; i <= k; ++i) {
if (h[i] == 1) { mn[1] = 0; continue; }
int l = lca(h[i], s[sc]);
if (l != s[sc]) {
while (dfn[l] < dfn[s[sc - 1]]) {
int u = s[sc - 1], v = s[sc];
add(u, v, dis(u, v)), --sc;
}
if (dfn[l] > dfn[s[sc - 1]]) {
head[l] = sz[l] = 0;
f[l] = mn[l] = inf, g[l] = mx[l] = 0;
add(l, s[sc], dis(l, s[sc])), s[sc] = l;
} else add(l, s[sc], dis(l, s[sc])), --sc;
}
mn[h[i]] = head[h[i]] = sz[h[i]] = 0;
f[h[i]] = inf, g[h[i]] = mx[h[i]] = 0, s[++sc] = h[i];
}
for (int i = 1; i < sc; ++i) {
add(s[i], s[i + 1], dis(s[i], s[i + 1]));
}
}
void solve1(int u) {
asked[u] ? sz[u] = 1 : sz[u] = 0;
for (int i = head[u]; i; i = pth[i].nxt) {
int x = pth[i].to;
if (x != fa[u][0]) {
solve1(x), sz[u] += sz[x];
f[u] = min(f[u], mn[u] + mn[x] + pth[i].w);
g[u] = max(g[u], mx[u] + mx[x] + pth[i].w);
if (mx[u] != 0 || asked[u]) maxx = max(maxx, g[u]);
minn = min(minn, f[u]);
mn[u] = min(mn[u], mn[x] + pth[i].w);
mx[u] = max(mx[u], mx[x] + pth[i].w);
}
}
}
void solve2(int u) {
for (int i = head[u]; i; i = pth[i].nxt) {
int x = pth[i].to;
if (x != fa[u][0]) {
ans += 1ll * sz[x] * (k - sz[x]) * pth[i].w;
solve2(x);
}
}
}
int main() {
read(n);
for (int i = 1, u, v; i < n; ++i) {
read(u), read(v);
add(u, v, 1), add(v, u, 1);
}
dfs(1, 0);
for (int i = 1; i <= n; ++i) {
lg[i] = lg[i - 1] + ((1 << lg[i - 1]) == i);
}
for (int j = 1; (1 << j) <= n; ++j) {
for (int i = 1; i <= n; ++i) {
fa[i][j] = fa[fa[i][j - 1]][j - 1];
}
}
read(m);
for (int i = 1; i<= m; ++i) {
read(k), ans = 0, minn = inf, maxx = 0;
c = 1ll * k * (k - 1) / 2;
for (int j = 1; j <= k; ++j) read(h[j]), asked[h[j]] = true;
build(), solve1(1), solve2(1);
for (int j = 1; j <= k; ++j) asked[h[j]] = false;
printf("%lld %d %d\n", ans, minn, maxx);
}
return 0;
}