[算法学习] 换根dp

换根dp

一般来说,我们做题的树都是默认 \(1\) 为根的。但是有些题目需要计算以每个节点为根时的内容。
朴素的暴力:以每个点 \(u\) 作为 \(root\) 暴力dfs下去,复杂度\(O(n^2)\)
正确的做法:换根dp,复杂度\(O(n)\)

执行步骤

  1. 第一次扫描,先默认 \(root=1\) ,跑一遍 \(dfs\)
  2. 第二次扫描,从 \(root=1\) 开始,每次从 \(u\)\(v\) 节点时,计算根从 \(u\) 转移到 \(v\) 时的贡献变化。
    很显然,换根dp是在两个\(dfs\)中完成的,下面我们介绍一下如何运用它。

例题1 Accumulation Degree

题目链接:South Central China 2008 Accumulation Degree

Description

给你一颗有 \(n\) 个节点的树,每一条边连接 \(u_i\)\(v_i\),流量为 \(fl_i\) ,你需要找出一个点作为 \(root\),并最大化从该点出发到所有叶子节点的流量最大值。
多组数据。(PS:题意读不懂的可以结合题目中的图理解,类似网络流的流法)
数据范围 \(1 \le n\le 200000\),并且 \(\sum n \le 200000\)
时间限制 \(1000\ ms\)

Solution

我们先默认这棵树以 \(1\) 为根,跑一次 \(dfs\)
定义 \(flow[i]\) 表示以 \(i\) 为根的子树中流量最大值
那么,\(u\) 节点从儿子 \(v\) 得到的流量为:
1.若\(v\)叶子节点,那么\(flow[u] += flow[v]\)(可以直接流过来);
2.若\(v\)非叶子节点,那么\(flow[u] += min(flow[v], fl(u, v))\)\(u\)\(v\)相连的边有流量限制)。
这样,我们得到了以 \(1\) 为根时的答案,记为 \(f[1]\),它的值等于 \(flow[1]\)
考虑如何换根
\(u\) 为根转移到儿子 \(v\) 为根, \(f[v]\) 包括两部分:一部分是从 \(v\) 流向自己的子树,一部分是从 \(v\) 往父节点走。
那么贡献的变化是第二部分造成的,原本的贡献是 \(flow[u] - min(flow[v], fl(u, v))\),现在加上 \(u\)\(v\) 这条边的流量限制,所以新的贡献是 \(min(fl(u, v), flow[u] - min(flow[v], fl(u, v)))\)
注意如果 \(u\) 的度为 \(1\),则需要特殊处理。
再来一个 \(dfs\) 转移即可。
复杂度 \(O(n)\),可以通过本题。

Code

例题2 STA-Station

题目链接:POI2008 STA-Station

Description

给你一颗有 \(n\) 个节点的树,你需要找出一个点作为 \(root\) ,并最大化 \(\sum_{i=1}^{n} dep_i\)
其中 \(dep_i\) 表示以 \(root\) 为根时,\(i\)节点的深度。
数据范围 \(1\le n\le 10^6\)
时间限制 \(1000\ ms\)

Solution

我们先默认这颗树以 \(1\) 为根,跑一次 \(dfs\),记录\(dep[i]\)\(size[i]\)
接下来,定义 \(f_i\) 表示以 \(i\) 为根时的 \(dep[i]\) 之和。
显然,\(f[1] = \sum_{i=1}^{n} dep[i]\)
当我们从 \(u\) 转移到儿子 \(v\) 时,以 \(v\) 为根的子树内的所有节点 \(dep\) 值都减一,以外的所有节点 \(dep\) 值都加一。
于是有: \(f[v] = f[u] - size[v] + (n - size[v]) = f[u] + n - 2 * size[v]\)
答案即为 \(max_{i=1}^{n} f[i]\)\(i\)
复杂度 \(O(n)\)卡卡常可以通过本题。

Code

这个题目卡\(vector\),能把用\(STL\)的完美卡飞。所以我改成前向星了呜呜呜。

// Author: wlzhouzhuan
#include <bits/stdc++.h>
using namespace std;
  
#define ll long long
#define ull unsigned long long
#define rint register int
#define rep(i, l, r) for (rint i = l; i <= r; i++)
#define per(i, l, r) for (rint i = l; i >= r; i--)
#define mset(s, _) memset(s, _, sizeof(s))
#define pb push_back
#define pii pair <int, int>
#define mp(a, b) make_pair(a, b)
#define Each(i) for (rint i = head[u]; i; i = edge[i].nxt)
inline int read() {
  int x = 0, neg = 1; char op = getchar();
  while (!isdigit(op)) { if (op == '-') neg = -1; op = getchar(); }
  while (isdigit(op)) { x = 10 * x + op - '0'; op = getchar(); }
  return neg * x;
}
inline void print(int x) {
  if (x < 0) { putchar('-'); x = -x; }
  if (x >= 10) print(x / 10);
  putchar(x % 10 + '0');
}


const int N = 1000005;
struct Edge {
  int to, nxt;
} edge[N << 1];
int head[N], tot;
void add(int u, int v) {
  edge[++tot] = {v, head[u]};
  head[u] = tot;
}
int n;

ll f[N];
int sz[N], dep[N];
void dfs1(int u, int fa) {
  sz[u] = 1;
  dep[u] = dep[fa] + 1;
  Each(i) {
    int v = edge[i].to;
    if (v == fa) continue;
    dfs1(v, u);
    sz[u] += sz[v];
  }
}
void dfs2(int u, int fa) {
  Each(i) {
    int v = edge[i].to;
    if (v == fa) continue;
    f[v] = f[u] + n - 2ll * sz[v];
    dfs2(v, u);
  }
}
int main() {
  ios :: sync_with_stdio(false); cin.tie(0);
  cin >> n;
  for (int i = 1; i < n; i++) {
    int u, v;
    cin >> u >> v;
    add(u, v), add(v, u);
  } 
  dfs1(1, 0);
  for (int i = 1; i <= n; i++) f[1] += dep[i];
  dfs2(1, 0);
  cout << max_element(f + 1, f + n + 1) - f << '\n'; 
  return 0;
} 
posted @ 2020-04-06 17:31  wlzhouzhuan  阅读(1383)  评论(3编辑  收藏  举报