点分治
点分治
点分治是用于解决树上路径问题的一中算法,是一种带优化的暴力。
先来一道题看看点分治是啥:
题目大意:给定一颗带权无根树,问有多少条路径之和为3的倍数。(\(n \le 2e4\))
首先这道题肯定是可以DP做的,但我们要用点分治做。想一想,这道题用暴力怎么做,是不是枚举两个点,暴力算,复杂度\(O(n^3)\)。考虑怎么优化一下,预处理每个点到1的距离,再枚举两个点,求这两个点的\(lca\),复杂度\(O(n^2logn)\)。
考虑找到一个比较好的根,把所有点到它的距离算出来,我们找一定经过这个根的路径。怎么找这个根比较优秀呢?我们知道有个东西叫树的重心。
树的重心:树上所有点到这个点的距离的最大值最小。
如果我们找的是树的重心,那么子树就会被分的尽量平均。我们找完所有经过这个点的路径后,把这个点删去,再对各个子树进行同样的操作,类似于分治的思想,所以这种方法叫点分治。
它的复杂度是\(O(nlog^2n)\)的,重心有个性质:每一颗子树的大小都不超过\(\frac{2}{n}\),如果超过了\(\frac{2}{n}\),显然有一颗子树小于\(\frac{2}{n} -2\),这与重心的定义不符,我们可以找到更好的重心。这样递归下去最多是\(logn\)层。
如何实现:
找重心
totsize = max_siz[0] = n; root = 0;
void get_root(int x, int fa) {
siz[x] = 1; max_siz[x] = 0;
for(int i = head[x]; i; i = e[i].nxt) {
int y = e[i].to; if(y == fa || vis[y]) continue;
get_root(y, x); siz[x] += siz[y]; max_siz[x] = max(max_siz[x], siz[y]);
}
max_siz[x] = max(max_siz[x], totsize - siz[x]);
if(max_siz[root] > max_siz[x]) root = x;
}
搞个简单的树形DP就好了。
\(siz\)代表子树大小,\(maxsiz\)表示最大子树,\(root\)是重心,\(totsize\)表示这颗树的节点数。
一个要注意的地方是:一个点的儿子也包括它上面那个点,是平常说的父亲。
分治
void solve(int x) {
/*
dosomething
*/
vis[x] = 1;
for(int i = head[x]; i; i = e[i].nxt) {
int y = e[i].to; if(vis[y]) continue;
root = 0; totsize = siz[y];
get_root(y, 0); solve(root);
}
}
注意\(totsize\)要跟着变。
那对于上面那道题只需修改一下\(solve\)函数就好了。
#include <iostream>
#include <cstdio>
#include <cctype>
using namespace std;
inline long long read() {
long long s = 0, f = 1; char ch;
while(!isdigit(ch = getchar())) (ch == '-') && (f = -f);
for(s = ch ^ 48;isdigit(ch = getchar()); s = (s << 1) + (s << 3) + (ch ^ 48));
return s * f;
}
const int N = 2e4 + 5;
int n, cnt, ans, root, totsize;
int t[N], d[N], vis[N], siz[N], max_siz[N], head[N];
struct edge { int to, nxt, val; } e[N << 1];
void add(int x, int y, int z) {
e[++cnt].nxt = head[x]; head[x] = cnt; e[cnt].to = y; e[cnt].val = z;
}
void get_dis(int x, int fa) {
t[d[x]]++;
for(int i = head[x]; i; i = e[i].nxt) {
int y = e[i].to; if(y == fa || vis[y]) continue;
d[y] = (d[x] + e[i].val) % 3;
get_dis(y, x);
}
}
int calc(int x, int val) {
t[0] = t[1] = t[2] = 0;
d[x] = val;
get_dis(x, 0);
return t[0] * t[0] + t[1] * t[2] * 2;
}
void get_root(int x, int fa) {
// cout << x << endl;
siz[x] = 1; max_siz[x] = 0;
for(int i = head[x]; i; i = e[i].nxt) {
int y = e[i].to; if(y == fa || vis[y]) continue;
get_root(y, x); siz[x] += siz[y]; max_siz[x] = max(max_siz[x], siz[y]);
}
max_siz[x] = max(max_siz[x], totsize - siz[x]);
if(max_siz[root] > max_siz[x]) root = x;
}
void solve(int x) {
ans += calc(x, 0); vis[x] = 1;
for(int i = head[x]; i; i = e[i].nxt) {
int y = e[i].to; if(vis[y]) continue;
ans -= calc(y, e[i].val);
root = 0; totsize = siz[y];
get_root(y, 0); solve(root);
}
}
int gcd(int x, int y) {
return y == 0 ? x : gcd(y, x % y);
}
int main() {
n = read();
for(int i = 1, x, y, z;i <= n - 1; i++) {
x = read(); y = read(); z = read() % 3; add(x, y, z); add(y, x, z);
}
totsize = max_siz[0] = n; root = 0;
get_root(1, 0); solve(root);
int g = gcd(ans, n * n);
printf("%d/%d", ans / g, n * n / g);
return 0;
}
具体解释一下\(calc\)函数:\(t_0, t_1, t_2\)表示有多少条%3之后为0, 1, 2的道路。\(d\)就是一个点到重心的距离。
\(t_1 * t_2 * 2 + t_0 * t_0\)就是经过当前重心的合法路径条数。。。吗?是不是有一些点算重复了,我们需要减去同一颗子树到重心距离为3的倍数的路径,这就完了。
再看一道模板题:P3806 【模板】点分治1
题目大意:给定一棵有\(n\)个点的树,询问树上距离为\(k\)的点对是否存在。
这道题和刚刚那道题很像,更改一下\(calc\)函数就好了,我们还是处理出来每个点到重心的距离,然后用二分查找一个值,看是否可以凑成\(k\),最后在判断一下是否在同一棵子树内就好了。
#include <iostream>
#include <cstdio>
#include <cctype>
#include <algorithm>
using namespace std;
inline long long read() {
long long s = 0, f = 1; char ch;
while(!isdigit(ch = getchar())) (ch == '-') && (f = -f);
for(s = ch ^ 48;isdigit(ch = getchar()); s = (s << 1) + (s << 3) + (ch ^ 48));
return s * f;
}
const int N = 1e4 + 5, K = 1e8 + 5;
int n, m, cnt, tot, root, totsize;
int ki[N], dis[N], ans[K], siz[N], vis[N], max_siz[N], head[N];
struct edge { int to, nxt, val; } e[N << 1];
struct node {
int d, rt;
node() {}
node(int x, int y) { d = x; rt = y; }
} a[K];
int cmp(node x, node y) {
return x.d < y.d;
}
void add(int x, int y, int z) {
e[++cnt].nxt = head[x]; head[x] = cnt; e[cnt].to = y; e[cnt].val = z;
}
void get_root(int x, int fa) {
siz[x] = 1; max_siz[x] = 0;
for(int i = head[x]; i; i = e[i].nxt) {
int y = e[i].to; if(y == fa || vis[y]) continue;
get_root(y, x); siz[x] += siz[y];
max_siz[x] = max(max_siz[x], siz[y]);
}
max_siz[x] = max(max_siz[x], totsize - siz[x]);
if(max_siz[root] > max_siz[x]) root = x;
}
int find(int x) {
int l = 1, r = tot, res = 0;
while(l <= r) {
int mid = (l + r) >> 1;
if(a[mid].d < x) l = mid + 1;
else res = mid, r = mid - 1;
}
return res;
}
void get_dis(int x, int fa, int rt) {
a[++tot] = node(dis[x], rt);
for(int i = head[x]; i; i = e[i].nxt) {
int y = e[i].to; if(y == fa || vis[y]) continue;
dis[y] = dis[x] + e[i].val;
get_dis(y, x, rt);
}
}
void calc(int x, int v) {
dis[x] = v; tot = 0;
for(int i = head[x]; i; i = e[i].nxt) {
int y = e[i].to; if(vis[y]) continue;
dis[y] = dis[x] + e[i].val; get_dis(y, x, y);
}
a[++tot] = node(0, 0); //这里为啥要再加个0点,因为如果不加,就无法选上第一个数,底下那里就会break掉
sort(a + 1, a + tot + 1, cmp);
for(int i = 1;i <= m; i++) {
if(ans[i]) continue;
int l = 1;
while(l <= tot && ki[i] > a[l].d + a[tot].d) l++;
while(l <= tot && !ans[i]) {
if(ki[i] - a[l].d < a[l].d) break;
int tmp = find(ki[i] - a[l].d);
while(a[tmp].d + a[l].d == ki[i] && a[tmp].rt == a[l].rt) tmp++;
if(a[tmp].d + a[l].d == ki[i]) ans[i] = 1;
l++;
}
}
}
void solve(int x) {
calc(x, 0);
vis[x] = 1;
for(int i = head[x]; i; i = e[i].nxt) {
int y = e[i].to; if(vis[y]) continue;
root = 0; totsize = siz[y]; max_siz[0] = n;
get_root(y, 0); solve(root);
}
}
int main() {
n = read(); m = read();
for(int i = 1, x, y, z;i <= n - 1; i++) x = read(), y = read(), z = read(), add(x, y, z), add(y, x, z);
max_siz[0] = totsize = n; root = 0;
for(int i = 1;i <= m; i++) ki[i] = read();
get_root(1, 0); solve(root);
for(int i = 1;i <= m; i++) {
if(ans[i]) puts("AYE"); else puts("NAY");
}
return 0;
}
动态点分治
咕咕咕