「题解」洛谷 P4103 [HEOI2014]大工程

题目

P4103 [HEOI2014]大工程

思路

虚树 + dp。

共有 \(n-1\) 条边,\(k\) 个询问点,用 \((u_i,v_i,w_i)\) 表示第 \(i\) 条边是从 \(u_i\)\(v_i\) 花费为 \(w_i\) 的一条边。那么用 \(sz\) 维护每个点的子树中询问点的个数,那么第一问的答案统计每条边的贡献就行了,式子如下:

\[\sum_{i = 1}^{n - 1} (sz_{v_i} \times (k - sz_{v_i}) \times w_i) \]

第二三问类似,这里以最小值为例,设 \(mn_i\) 表示以 \(i\) 为根的子树中询问点到 \(i\) 的最小距离,\(f[i]\) 表示以 \(i\) 为根的子树中两个询问点之间的最小距离,设 \(i\) 的儿子有 \(a,b,\dots,e\),在更新到了儿子 \(c\) 时,可以用 \(mn[i] + w + mn[c]\) 来更新 \(f[i]\)

\(q\) 次树形 \(dp\) 的复杂度为 \(O(nq)\),是无法承受的复杂度,发现 \(\sum k \le 2 \times n\),因此可以构建虚树来 \(dp\)

Code

#include <cstdio>
#include <cstring>
#include <string>
#include <iostream>
#include <algorithm>

#define inf 1e9
#define M 1000001
typedef long long ll;
inline int max(int a, int b) { return a > b ? a : b; }
inline int min(int a, int b) { return a < b ? a : b; }
inline void read(int &T) {
  int x = 0; bool f = 0; char c = getchar();
  while (c < '0' || c > '9') { if (c == '-') f = !f; c = getchar(); }
  while (c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar(); }
  T = f ? -x : x;
}

ll c, ans;
bool asked[M];
int n, m, k, sc, cnt, s[M], h[M], pthn, minn, maxx, head[M];
int f[M], g[M], sz[M], mn[M], mx[M], lg[M], dfn[M], dep[M], fa[M][21];
struct Edge {
  int nxt, to, w;
}pth[M << 1];

void add(int frm, int to, int w) {
  pth[++pthn].to = to, pth[pthn].nxt = head[frm];
  pth[pthn].w = w, head[frm] = pthn;
}

void dfs(int u, int father) {
  dfn[u] = ++cnt, fa[u][0] = father, dep[u] = dep[father] + 1;
  for (int i = head[u]; i; i = pth[i].nxt) {
    int x = pth[i].to;
    if (x != father) dfs(x, u);
  }
}

int lca(int x, int y) {
  if (dep[x] < dep[y]) std::swap(x, y);
  while (dep[x] > dep[y]) {
    x = fa[x][lg[dep[x] - dep[y]] - 1];
  }
  if (x == y) return x;
  for (int k_ = lg[dep[x]] - 1; k_ >= 0; --k_) {
    if (fa[x][k_] != fa[y][k_]) {
      x = fa[x][k_], y = fa[y][k_];
    }
  }
  return fa[x][0];
}

int dis(int x, int y) {
  int l = lca(x, y);
  return dep[x] + dep[y] - 2 * dep[l];
}

inline bool cmp(int a, int b) { return dfn[a] < dfn[b]; }

void build() {
  std::sort(h + 1, h + k + 1, cmp), s[sc = 1] = 1, sz[1] = 0;
  pthn = g[1] = mx[1] = head[1] = 0, mn[1] = f[1] = inf;
  for (int i = 1; i <= k; ++i) {
    if (h[i] == 1) { mn[1] = 0; continue; }
    int l = lca(h[i], s[sc]);
    if (l != s[sc]) {
      while (dfn[l] < dfn[s[sc - 1]]) {
        int u = s[sc - 1], v = s[sc];
        add(u, v, dis(u, v)), --sc;
      }
      if (dfn[l] > dfn[s[sc - 1]]) {
        head[l] = sz[l] = 0;
        f[l] = mn[l] = inf, g[l] = mx[l] = 0;
        add(l, s[sc], dis(l, s[sc])), s[sc] = l;
      } else add(l, s[sc], dis(l, s[sc])), --sc;
    }
    mn[h[i]] = head[h[i]] = sz[h[i]] = 0;
    f[h[i]] = inf, g[h[i]] = mx[h[i]] = 0, s[++sc] = h[i];
  }
  for (int i = 1; i < sc; ++i) {
    add(s[i], s[i + 1], dis(s[i], s[i + 1]));
  }
}

void solve1(int u) {
  asked[u] ? sz[u] = 1 : sz[u] = 0;
  for (int i = head[u]; i; i = pth[i].nxt) {
    int x = pth[i].to;
    if (x != fa[u][0]) {
      solve1(x), sz[u] += sz[x];
      f[u] = min(f[u], mn[u] + mn[x] + pth[i].w);
      g[u] = max(g[u], mx[u] + mx[x] + pth[i].w);
      if (mx[u] != 0 || asked[u]) maxx = max(maxx, g[u]);
      minn = min(minn, f[u]);
      mn[u] = min(mn[u], mn[x] + pth[i].w);
      mx[u] = max(mx[u], mx[x] + pth[i].w);
    }
  }
}

void solve2(int u) {
  for (int i = head[u]; i; i = pth[i].nxt) {
    int x = pth[i].to;
    if (x != fa[u][0]) {
      ans += 1ll * sz[x] * (k - sz[x]) * pth[i].w;
      solve2(x);
    }
  }
}

int main() {
  read(n);
  for (int i = 1, u, v; i < n; ++i) {
    read(u), read(v);
    add(u, v, 1), add(v, u, 1);
  }
  dfs(1, 0);
  for (int i = 1; i <= n; ++i) {
    lg[i] = lg[i - 1] + ((1 << lg[i - 1]) == i);
  }
  for (int j = 1; (1 << j) <= n; ++j) {
    for (int i = 1; i <= n; ++i) {
      fa[i][j] = fa[fa[i][j - 1]][j - 1];
    }
  }
  read(m);
  for (int i = 1; i<= m; ++i) {
    read(k), ans = 0, minn = inf, maxx = 0;
    c = 1ll * k * (k - 1) / 2;
    for (int j = 1; j <= k; ++j) read(h[j]), asked[h[j]] = true;
    build(), solve1(1), solve2(1);
    for (int j = 1; j <= k; ++j) asked[h[j]] = false;
    printf("%lld %d %d\n", ans, minn, maxx);
  }
  return 0;
}
posted @ 2021-01-20 22:20  yu__xuan  阅读(92)  评论(0编辑  收藏  举报