点分治详解(附图附例题)
介绍
点分治, 作为一种统计带权树简单路径长度的暴力分治算法, 其分治方法非常的巧妙, 可以将暴力的 \(O(n^2)\) 优化到 \(O(nlogn)\)
先看问题:
在一个带权树上, 统计两个点的简单路径长度不超过 \(k\) 的路径个数
这就是 模板题1 POJ1741
首先还是考虑如何使用暴力求出, 很明显的我们直接对树上的每个点做一遍 \(dfs\) 即可, 这样的时间复杂度是 \(O(n^2)\)
太慢了, 有没有什么方法可以进行优化?
我们考虑树上的一个点 \(t\), 那么对于该点来说, 可以把问题简单的分为两大类, 第一类是经过点 \(t\) 的路径的点, 另一类是不经过该点的路径的点.
关于第二大类, 也就是不经过 \(t\) 的点, 我们考虑把该点删除之后, 原图会变成若干个无根树, 然后其点一定会在这些子树的某个路径上, 我们可以递归求出
然后考虑计算第一类问题: 可以对该点进行一遍 \(dfs\), 求出每个点到其的距离, 然后对这些距离排序, 这样我们就有了一个有序的距离数组, 那么使用双指针就可以求出此时符合条件的所有可能
然后你会发现, 这些路径中很明显有一些是有问题的, 举例子来说, 当前的 \(k = 7\), 此时有这样一种情况:
从根节点 \(R\) 到 节点 \(X\), \(Y\) 的距离均为 \(3\), 那么此时这个路径已经被纳入合法路径了, 实则不然, 仔细观察发现这条路径实际上是: \(R → X → R → Y\) 实际的长度是 \(9\), 实际上这样是因为它不是一个简单路径,所以不合法,我们需要去除这个不合法的路径, 从整体上来看, 同一子树内的两个节点组成的路径均不合法(即绿色色路径).
如何去除? 我们考虑从子树入手, 我们直接计算以子节点为根的子树, 我们此时我们有 \(R → S\) 这条边, 也就是说, 我们可以把原来的 \(k\) 变成 \(k + W_\text{R → S}\) (即蓝色路径), 这样我们就可以把这些不合法的路径通通去除了
这样, 我们就在 \(O(nlogn)\) 的时间复杂度内求一个点的合法方案得到了, 接下来
为什么要这样做? 如果我们还是一个点一个点的这样求, 那么还徒增了一个排序, 硬是把复杂度优化到了 \(O(n^2logn)\)
很明显我们肯定不可以再一个一个去求了, 我们在上述方法中提到了一点, 也就是在递归的过程中, 要删去之前求出过的节点, 那么会分裂成若干个无根树, 也就是说, 我们需要尽可能的减少这些树的数量, 以达到最优, 那么此时我们最重要的地方出来了: 重心, 没错, 每次都用重心做根节点, 分裂的子树大小不会超过 \(\frac{n}{2}\),我们只需要 \(logn\) 层即可完成递归!
递归复杂度: \(logn\), 每个点的方案复杂度: \(nlogn\), 总复杂度: \(O(nlog^2n)\)
以下是例题代码:
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 2e4 + 10, mod = 1e9 + 7;
struct node{
int v, w;
};
int n, m, k;
int sz[N], dp[N], dist[N], now, res, Tsize, cnt;
vector<node> g[N];
bool vis[N];
void dfs1(int u, int fa){
sz[u] = 1, dp[u] = 0;
for(int i = 0; i < g[u].size(); i++){
int x = g[u][i].v;
if(x == fa || vis[x]) continue;
dfs1(x, u);
sz[u] += sz[x];
dp[u] = max(dp[u], sz[x]);
}
dp[u] = max(dp[u], Tsize - sz[u]);
if(dp[u] < dp[now]) now = u;
}
void dfs2(int u, int fa, int d){
dist[++cnt] = d;
for(int i = 0; i < g[u].size(); i++){
int x = g[u][i].v, w = g[u][i].w;
if(x == fa || vis[x]) continue;
dfs2(x, u, d + w);
}
}
int dfs3(int u, int d){
cnt = 0, dfs2(u, 0, d);
sort(dist + 1, dist + 1 + cnt);
int l = 1, r = cnt, ans = 0;
while(l <= r){
while(r && dist[l] + dist[r] > k) r--;
if(l > r) break;
ans += r - l + 1;
l++;
}
return ans;
}
void dfs(int u){
res += dfs3(u, 0);
vis[u] = true;
for(int i = 0; i < g[u].size(); i++){
int x = g[u][i].v, w = g[u][i].w;
if(!vis[x]){
// 这里实现的步骤是去除子树里面不合法的合并路径
res -= dfs3(x, w);
// 其实这里加一步这个才是正确的一般点分治, 但是不加也可以, 具体证明: https://liu-cheng-ao.blog.uoj.ac/blog/2969
// Tsize = n, dfs1(x, 0);
now = 0, Tsize = sz[x], dfs1(x, 0);
dfs(now);
}
}
}
signed main()
{
std::ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
while(cin >> n >> k && n != 0 && k != 0){
for(int i = 1; i <= n + 100 ; i++) g[i].clear();
memset(vis, false, sizeof vis);
for(int i = 1; i < n; i++){
int a, b, c; cin >> a >> b >> c;
node x = {b, c}, y = {a, c};
g[a].push_back(x), g[b].push_back(y);
}
res = 0;
dp[now = 0] = 1e9, Tsize = n, dfs1(1, 0);
dfs(now);
cout << res - n << '\n';
}
return 0;
}
例题:
【模板】点分治 1
模板点分治1
题目描述
给定一棵有 \(n\) 个点的树,询问树上距离为 \(k\) 的点对是否存在。
输入格式
第一行两个数 \(n,m\)。
第 \(2\) 到第 \(n\) 行,每行三个整数 \(u, v, w\),代表树上存在一条连接 \(u\) 和 \(v\) 边权为 \(w\) 的路径。
接下来 \(m\) 行,每行一个整数 \(k\),代表一次询问。
输出格式
对于每次询问输出一行一个字符串代表答案,存在输出 AYE
,否则输出 NAY
。
样例 #1
样例输入 #1
2 1
1 2 2
2
样例输出 #1
AYE
提示
数据规模与约定
- 对于 \(30\%\) 的数据,保证 \(n\leq 100\)。
- 对于 \(60\%\) 的数据,保证 \(n\leq 1000\),\(m\leq 50\) 。
- 对于 \(100\%\) 的数据,保证 \(1 \leq n\leq 10^4\),\(1 \leq m\leq 100\),\(1 \leq k \leq 10^7\),\(1 \leq u, v \leq n\),\(1 \leq w \leq 10^4\)。
做法:
本题求的是是否存在距离等于 \(k\) 的路径, 那么我们在求答案的时候用二分更为方便, 接下来要注意本题最好不要每次询问都去跑一边, 很可能会 T, 因为算法常数较大, 这里直接统计即可, 复杂度 \(O(mnlog^2n\))
代码:
#include <bits/stdc++.h>
using namespace std;
const int N = 2e4 + 10, mod = 1e9 + 7;
struct node{
int u, w;
};
vector<node> g[N];
int sz[N], dp[N], dist[N], wait[N], ok[N];
int now, Tsize, cnt, k, res, n, m;
bool vis[N];
void dfs1(int u, int fa){
sz[u] = 1, dp[u] = 0;
for(auto it : g[u]){
int x = it.u;
if(x == fa || vis[x]) continue;
dfs1(x, u);
sz[u] += sz[x];
dp[u] = max(dp[u], sz[x]);
}
dp[u] = max(dp[u], Tsize - sz[u]);
if(dp[u] < dp[now]) now = u;
}
void dfs2(int u, int fa, int d){
dist[++cnt] = d;
for(auto it : g[u]){
int x = it.u, w = it.w;
if(x == fa || vis[x]) continue;
dfs2(x, u, d + w);
}
}
void dfs3(int u, int d, bool f){
cnt = 0, dfs2(u, 0, d);
sort(dist + 1, dist + 1 + cnt);
for(int j = 1; j <= m; j++){
for(int i = 1; i < cnt; i++){
int x = lower_bound(dist + i + 1, dist + 1 + cnt, wait[j] - dist[i]) - dist;
if(x && x <= cnt && dist[x] + dist[i] == wait[j] && x != i){
if(f) ok[j]++;
else ok[j]--;
}
}
}
}
void dfs(int u){
dfs3(u, 0, true), vis[u] = true;
for(auto it : g[u]){
int x = it.u, w = it.w;
if(vis[x]) continue;
dfs3(x, w, false);
Tsize = n, dfs1(x, 0);
now = 0, Tsize = sz[x], dfs1(x, 0);
dfs(now);
}
}
signed main()
{
std::ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
cin >> n >> m;
for(int i = 1; i < n; i++){
int a, b, c; cin >> a >> b >> c;
g[a].push_back({b, c}), g[b].push_back({a, c});
}
for(int i = 1; i <= m; i++) cin >> wait[i];
dp[now = 0] = 1e9, Tsize = n, dfs1(1, 0);
dfs(now);
for(int i = 1; i <= m; i++){
if(ok[i]) cout << "AYE" << '\n';
else cout << "NAY" << '\n';
}
return 0;
}
ABC359_G
Sum of Tree Distance
给你一棵有 \(N\) 个顶点的树。 \(i\) 这条边双向连接顶点 \(u_i\) 和 \(v_i\) 。
此外,还给出了一个整数序列 \(A=(A_1,\ldots,A_N)\) 。
在此,定义 \(f(i,j)\) 如下:
- 如果是 \(A_i = A_j\) ,那么 \(f(i,j)\) 就是从顶点 \(i\) 移动到顶点 \(j\) 所需的最小边数。如果是 \(A_i \neq A_j\) ,那么就是 \(f(i,j) = 0\) 。
计算下面表达式的值:
\(\displaystyle \sum_{i=1}^{N-1}\sum_{j=i+1}^N f(i,j)\) .
** 数据限制 **
- \(2 \leq N \leq 2 \times 10^5\)
- \(1 \leq u_i, v_i \leq N\)
- \(1 \leq A_i \leq N\)
- 输入图是一棵树。
- 所有输入值均为整数。
样例 #1
样例输入 #1
4
3 4
4 2
1 2
2 1 1 2
样例输出 #1
4
样例 #2
样例输入 #2
8
8 6
3 8
1 4
7 8
4 5
3 4
8 2
1 2 2 2 3 1 1 3
样例输出 #2
19
思路:
由于是树上距离, 考虑使用点分治或树上差分, 这里使用点分治, 由于点分治是一定能够保证遍历所有的节点的, 所以我们在枚举到每一棵以 \(u\) 为根的树之后, 计入两个全局变量 \(sum\) 和 \(cnt\), 这里的 \(sum_{A_v}\) 是指所有以 \(A_v\) 为颜色的节点到根节点的距离, \(cnt_{A_v}\) 表示所有相同颜色的个数, 那么在遍历一个子节点之前就有: \(res += sum_{A_v} + dist_v * cnt_{A_v}\), 也就是加上之前遍历过的点到 \(v\) 点的距离和。
代码:
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 2e5 + 10, mod = 1e9 + 7;
int sz[N], dep[N], dist[N], dp[N], cnt[N], sum[N], A[N];
int now, n, m, Tsize, res;
bool vis[N];
vector<int> g[N];
vector<int> a, b;
void dfs1(int u, int fa){
sz[u] = 1, dp[u] = 0;
for(auto x : g[u]){
if(x == fa || vis[x]) continue;
dfs1(x, u);
sz[u] += sz[x];
dp[u] = max(dp[u], sz[x]);
}
dp[u] = max(dp[u], Tsize - dp[u]);
if(dp[u] < dp[now]) now = u;
}
void dfs2(int u, int fa, int d){
a.push_back(u);
dep[u] = d;
for(auto x : g[u]){
if(x == fa || vis[x]) continue;
dfs2(x, u, d + 1);
}
}
void dfs3(int u, int d){
a.clear(), dfs2(u, 0, d);
for(auto x : a){
res += sum[A[x]] + dep[x] * cnt[A[x]];
}
}
void dfs(int u){
vis[u] = true, cnt[A[u]] ++;
for(auto x : g[u]){
if(vis[x]) continue;
dfs3(x, 1);
for(auto y : a){
sum[A[y]] += dep[y], cnt[A[y]] ++;
b.push_back(y);
}
}
for(auto x : b){
sum[A[x]] -= dep[x], cnt[A[x]] --;
}
b.clear(), cnt[A[u]] --;
for(auto x : g[u]){
if(vis[x]) continue;
now = 0, Tsize = sz[x], dfs1(x, 0);
dfs(now);
}
}
signed main()
{
std::ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
int n; cin >> n;
for(int i = 1; i < n; i++){
int a, b; cin >> a >> b;
g[a].push_back(b), g[b].push_back(a);
}
for(int i = 1; i <= n; i++) cin >> A[i];
dp[now = 0] = 1e9, Tsize = n, dfs1(1, 0);
dfs(1);
cout << res << '\n';
return 0;
}