【算法】树分治
1. 算法简介
树分治(Tree division),是处理树上路径类问题的算法。树分治又可以分为点分治与边分治。
考虑这样一个问题:给定一棵有 \(n\) 个点的树,询问树上距离为 \(k\) 的点对是否存在。
暴力的做法就是枚举两个点然后计算距离,统计答案。这样显然 \(O(n^2)\) 的。
我们发现瓶颈在于枚举的过程:我们希望快速地知道树上的路径信息,而不在乎路径上的端点。
这时候就需要使用树分治算法来优化时间。
2. 点分治
点分治是树分治的一种。
大家可能看出来了,上述的例题就是 P3806 【模板】点分治 1。
对于一棵树而言,树上的路径无外乎两种:一种是经过根节点的,另一种是不经过根节点的。(前提是有根树,无根树可转为有根树)
对于经过根节点的路径,想要知道其路径信息是很容易的。但不经过根节点的路径就很难维护了(即所在子树相同)。以当前为根的树很难维护其子树路径的信息。这时候我们便可以删去当前根节点,分裂成许多以儿子节点为子树根的新树。
分裂之前:
分裂之后:
由于每一个节点都当过根节点,这样,树上的所有路径等能被统计到。
但我们会发现,当树为链时,分治的时间复杂度依然为 \(O(n)\) ,没有达到优化时间的目的。
这是,按照原来老老实实从根节点开始分治的方法已经不适用,这时候我们需要找到一个合适的点,使得分治之后,时间复杂度趋近于 \(O(\log n)\)
这个点就是树的重心。
2.1 求重心
树的中心定义为:其所有的子树中最大的子树节点数最少。当删去此点时,生成的多棵新树会趋于平衡,这也会让点分治的时间复杂度趋于 \(O(\log n)\)。
找中心,便一边 \(dfs\) 整棵树即可。
设 \(maxs_x\) 表示 \(x\) 节点的最大的子树大小,\(siz_x\) 表示 \(x\) 节点的子树大小,\(rt\) 为选出的根。
考虑一个节点在原树上的位置。值得注意的是,当此节点不为根节点时,其子树包括其父辈之外的所有节点,像这样:
蓝色圈出部分为 'now' 节点的所有子树。
Code:
void getrt(int x, int fa) {
siz[x] = 1, maxs[x] = 0;
for (int i = h[x]; i; i = e[i].nx) {
int y = e[i].v;
if(y == fa || vis[y]) continue;
getrt(y, x);
siz[x] += siz[y];
if(maxs[x] < siz[y]) maxs[x] = siz[y];
}
maxs[x] = max(maxs[x], sum - siz[x]);
if(maxs[rt] > maxs[x]) rt = x;
}
注意!!!!,当分裂成多个子树之后,分治到子树时,则需重新找重心。来保证程序的时间复杂度。
由于重新找重心是在分治的过程中完成的,故总时间复杂度不会超过 \(O(n\log n)\)。
2.2 分治
找到了重心之后,便可以以重心为根进行分治。
设 \(vis_x\) bool 数组表示 \(x\) 节点是否被“删除”(删除的节点不能被再次遍历,也不能再次进行答案统计)。
由于每次进入下一层分治时要重新找重心,故分治时要及时把 \(maxs_rt\) 设置为 \(siz_y\)(最大不会超过 \(siz_y\)),根节点编号也要设置为 \(0\)。
找完之后,以新重心为根,继续分治,统计答案;
Code:
void divide(int x) {
vis[x] = f[0] = 1;
solve(x);
for (int i = h[x]; i; i = e[i].nx) {
int y = e[i].v;
if(vis[y]) continue;//遍历过的点包含其父亲节点
maxs[rt = 0] = sum = siz[y];
getrt(y, 0);
divide(rt);
}
}
当你在统计答案时想要使用 \(siz\) 时,直接在 'getrt(y, 0);' 补一句 'getrt(rt, 0)' 即可
2.3 答案统计
这里因题而异,拿模板题举例。
设 \(f_x\) 表示 \(x\) 在当前状态内是否出现,\(tmp_i\) 存储有多少种不同的路径长度。
先可以把新树的点到根的距离求出来:
Code:
void getdis(int x, int fa) {
tmp[++cnt] = dis[x];
for (int i = h[x]; i; i = e[i].nx) {
int y = e[i].v;
if(y == fa || vis[y]) continue;
dis[y] = dis[x] + e[i].w;
getdis(y, x);
}
}
然后统计答案
Code:
void solve(int x) {
int H = 1, t = 0;
for (int i = h[x]; i; i = e[i].nx) {
int y = e[i].v;
if(vis[y]) continue;
dis[y] = e[i].w;
cnt = 0;
getdis(y, x);
For(j,1,cnt) {
For(k,1,m) {
if(Q[k] >= tmp[j]) ans[k] |= f[Q[k] - tmp[j]];
}
}
For(j,1,cnt) {
q[++t] = tmp[j];
f[tmp[j]] = 1;
}
}
while(H <= t) {
f[q[H]] = 0;
H++;
}
}
3. 点分治例题
3.1 P4178 Tree
Proble
给定一棵有 \(n\) 个节点的树,每条边有边权,求出树上两点距离小于等于 \(k\) 的点对数量。
Solve
点分治模板。
考虑统计答案时,令路径长度数组为 \(tmp\),当新出现边权 \(tmp_x\) 时,看已有的边权里是否出现 \(tmp_y\) 使得 \(tmp_x + tmp_y \le k\)。推到得 \(tmp_y <= k - tmp_x\),所以将区间 \([0,k-tmp_x]\) 计入答案,再单点 \(tmp_x\) 加一即可。
树状数组可维护。
Code
#include <bits/stdc++.h>
#define int long long
#define rint register int
#define For(i,l,r) for(rint i=l;i<=r;++i)
#define FOR(i,r,l) for(rint i=r;i>=l;--i)
#define MOD 1000003
#define mod 1000000007
#define inf 0x3f3f3f3f3f3f3f3f
using namespace std;
namespace Read {
template <typename T>
inline void read(T &x) {
x=0;T f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
x*=f;
}
template <typename T, typename... Args>
inline void read(T &t, Args&... args) {
read(t), read(args...);
}
}
using namespace Read;
void print(int x){
if(x<0){putchar('-');x=-x;}
if(x>9){print(x/10);putchar(x%10+'0');}
else putchar(x+'0');
return;
}
const int N = 4e4 + 10;
struct Node {
int v, w, nx;
} e[N << 1];
int n, h[N], k, tot, rt, sum, siz[N], maxs[N], t[N], dis[N], tmp[N], cnt, ans;
bool vis[N];
void add(int u, int v, int w) {
e[++tot] = (Node) {v, w, h[u]};
h[u] = tot;
}
void getrt(int x, int fa) {
siz[x] = 1, maxs[x] = 0;
for (int i = h[x]; i; i = e[i].nx) {
int y = e[i].v;
if(y == fa || vis[y]) continue;
getrt(y, x);
siz[x] += siz[y];
if(maxs[x] < siz[y]) maxs[x] = siz[y];
}
maxs[x] = max(maxs[x], sum - siz[x]);
if(maxs[rt] > maxs[x]) rt = x;
}
void getdis(int x, int fa) {
tmp[++cnt] = dis[x];
for (int i = h[x]; i; i = e[i].nx) {
int y = e[i].v;
if(y == fa || vis[y]) continue;
dis[y] = dis[x] + e[i].w;
getdis(y, x);
}
}
int lb(int x) {
return x & -x;
}
int qry(int x) {
int Ans = 0;
for (int i = x; i; i -= lb(i)) {
Ans += t[i];
}
return Ans;
}
void upd(int x, int z) {
for (int i = x; i <= k + 1; i += lb(i)) {
t[i] += z;
}
}
void solve(int x) {
for (int i = h[x]; i; i = e[i].nx) {
int y = e[i].v;
if(vis[y]) continue;
dis[y] = e[i].w;
cnt = 0;
getdis(y, x);
For(j,1,cnt) {
if(tmp[j] <= k) ans += qry(k - tmp[j] + 1);
}
For(j,1,cnt) {
if(tmp[j] <= k) upd(tmp[j] + 1, 1);
}
}
memset(t, 0, sizeof t);
upd(1, 1);
}
void divide(int x) {
vis[x] = 1;
solve(x);
for (int i = h[x]; i; i = e[i].nx) {
int y = e[i].v;
if(vis[y]) continue;
maxs[rt = 0] = n; sum = siz[y];
getrt(y, 0);
divide(rt);
}
}
signed main() {
read(n);
For(i,1,n-1) {
int u, v, w;
read(u, v, w);
add(u, v, w);
add(v, u, w);
}
read(k);
maxs[0] = sum = n;
getrt(1, 0);
upd(1, 1);
divide(rt);
cout << ans << '\n';
return 0;
}
3.2 P4149 [IOI2011] Race
Problem
给一棵树,每条边有权。求一条简单路径,权值和等于 \(k\),且边的数量最小。
Solve
点分治模板
在模板题的基础上多加一个记录深度,每次记录深度时取最大值,然后找到权值之和为 \(k\) 的路径就用深度更新答案。
当 \(k=tmp_j\) 时,无需再找路径进行拼接,直接更新。
存贮时要判断 \(tmp_j\le k\),直接存 \(tmp_j\) 可能会爆掉。
Code
#include <bits/stdc++.h>
#define ll long long
#define rint register int
#define For(i,l,r) for(rint i=l;i<=r;++i)
#define FOR(i,r,l) for(rint i=r;i>=l;--i)
#define MOD 1000003
#define mod 1000000007
#define inf 0x3f3f3f3f
using namespace std;
namespace Read {
template <typename T>
inline void read(T &x) {
x=0;T f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
x*=f;
}
template <typename T, typename... Args>
inline void read(T &t, Args&... args) {
read(t), read(args...);
}
}
using namespace Read;
void print(int x){
if(x<0){putchar('-');x=-x;}
if(x>9){print(x/10);putchar(x%10+'0');}
else putchar(x+'0');
return;
}
const int N = 2e5 + 10, M = 2e6 + 10;
struct Node {
int v, w, nx;
} e[N << 1];
int n, k, h[N], tot, maxs[N], f[M], tmp[M], rt, sum, cnt, siz[N], ans = inf, q[M], dep[M], dis[N];
bool vis[N];
void add(int u, int v, int w) {
e[++tot] = (Node){v, w, h[u]};
h[u] = tot;
}
void getrt(int x, int fa) {
maxs[x] = 0, siz[x] = 1;
for (int i = h[x]; i; i = e[i].nx) {
int y = e[i].v;
if(y == fa || vis[y]) continue;
getrt(y, x);
siz[x] += siz[y];
if(maxs[x] < siz[y]) maxs[x] = siz[y];
}
maxs[x] = max(maxs[x], sum - siz[x]);
if(maxs[rt] > maxs[x]) rt = x;
}
void getdis(int x, int fa, int dp) {
tmp[++cnt] = dis[x];
dep[cnt] = dp;
for (int i = h[x]; i; i = e[i].nx) {
int y = e[i].v;
if(y == fa || vis[y]) continue;
dis[y] = dis[x] + e[i].w;
getdis(y, x, dp + 1);
}
}
void solve(int x) {
int H = 1, t = 0;
for (int i = h[x]; i; i = e[i].nx) {
int y = e[i].v;
if(vis[y]) continue;
dis[y] = e[i].w;
cnt = 0;
getdis(y, x, 1);
For(j,1,cnt) {
if(k >= tmp[j] && f[k - tmp[j]] != inf) {
ans = min(ans, dep[j] + f[k - tmp[j]]);
}
if(k == tmp[j]) {
ans = min(ans, dep[j]);
}
}
For(j,1,cnt) {
if(k >= tmp[j]) {
f[tmp[j]] = min(f[tmp[j]], dep[j]);
q[++t] = tmp[j];
}
}
}
while(H <= t) {
f[q[H]] = inf; H++;
}
}
void divide(int x) {
vis[x] = 1, f[0] = inf;
solve(x);
for (int i = h[x]; i; i = e[i].nx) {
int y = e[i].v;
if(vis[y]) continue;
maxs[rt = 0] = n, sum = siz[y];
getrt(y, 0);
divide(rt);
}
}
signed main() {
read(n, k);
For(i,1,n-1) {
int u, v, w;
read(u, v, w);
u++, v++;
add(u, v, w);
add(v, u, w);
}
memset(f, 0x3f, sizeof f);
maxs[0] = sum = n;
getrt(1, 0);
divide(rt);
if(ans != inf) cout << ans << '\n';
else puts("-1");
return 0;
}