[BZOJ3611] [Heoi2014]大工程(DP + 虚树)
$dp[i][0]$表示节点i到子树中的所有点的距离之和
$dp[i][1]$表示节点i到子树中最近距离的点的距离
$dp[i][2]$表示节点i到子树中最远距离的点的距离
建好虚树后dp即可。
因为对于虚树掌握的还不是很熟,有些细节还是要注意。
虚树中可能会加入一些lca节点,这些节点在dp的时候是不应该统计的。
对于本题来说,别忘记考虑某一节点不同子树中点对的组合。
#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> #define N 2000010 #define LL long long using namespace std; LL ans1, ans2, dp[N][3]; int n, cnt, rp, m, top, T; int head[N], to[N], nex[N], val[N], dis[N], size[N], dfn[N], deep[N], f[N][21], q[N], s[N], flag[N]; inline int read() { int x = 0, f = 1; char ch = getchar(); for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = -1; for(; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + ch - '0'; return x * f; } inline void add(int x, int y) { to[cnt] = y; nex[cnt] = head[x]; head[x] = cnt++; } inline void dfs1(int u) { int i, v; dfn[u] = ++rp; deep[u] = deep[f[u][0]] + 1; for(i = 0; f[u][i]; i++) f[u][i + 1] = f[f[u][i]][i]; for(i = head[u]; ~i; i = nex[i]) { v = to[i]; if(!dfn[v]) { f[v][0] = u; dis[v] = dis[u] + 1; dfs1(v); } } head[u] = -1; } inline int calc_lca(int x, int y) { int i; if(deep[x] < deep[y]) swap(x, y); for(i = 20; i >= 0; i--) if(deep[f[x][i]] >= deep[y]) x = f[x][i]; if(x == y) return x; for(i = 20; i >= 0; i--) if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i]; return f[x][0]; } inline bool cmp(int x, int y) { return dfn[x] < dfn[y]; } inline void dfs2(int u) { int i, v; size[u] = flag[u]; dp[u][1] = 1e9, dp[u][0] = dp[u][2] = 0; for(i = head[u]; ~i; i = nex[i]) { v = to[i]; dfs2(v); size[u] += size[v]; ans1 = min(ans1, dp[u][1] + dp[v][1] + dis[v] - dis[u]); ans2 = max(ans2, dp[u][2] + dp[v][2] + dis[v] - dis[u]); dp[u][0] += dp[v][0] + 1ll * size[v] * (m - size[v]) * (dis[v] - dis[u]); dp[u][1] = min(dp[u][1], dis[v] - dis[u] + dp[v][1]); dp[u][2] = max(dp[u][2], dis[v] - dis[u] + dp[v][2]); } if(flag[u]) { ans1 = min(ans1, dp[u][1]); ans2 = max(ans2, dp[u][2]); dp[u][1] = 0; } head[u] = -1; } inline void solve() { int i, lca; m = read(); top = cnt = 0; for(i = 1; i <= m; i++) q[i] = read(), flag[q[i]] = 1; sort(q + 1, q + m + 1, cmp); for(i = 1; i <= m; i++) { if(!top) { s[++top] = q[i]; continue; } lca = calc_lca(s[top], q[i]); while(dfn[lca] < dfn[s[top]]) { if(dfn[lca] >= dfn[s[top - 1]]) { add(lca, s[top]); if(s[--top] != lca) s[++top] = lca; break; } add(s[top - 1], s[top]), top--; } s[++top] = q[i]; } while(top > 1) add(s[top - 1], s[top]), top--; ans2 = 0; ans1 = 1ll * 1e9 * 1e9; dfs2(s[1]); printf("%lld %lld %lld\n", dp[s[1]][0], ans1, ans2); for(i = 1; i <= m; i++) flag[q[i]] = 0; } int main() { int i, x, y; n = read(); memset(head, -1, sizeof(head)); for(i = 1; i < n; i++) { x = read(); y = read(); add(x, y); add(y, x); } dfs1(1); T = read(); while(T--) solve(); return 0; }