树分治学习笔记(未完成)
前言
点分治不应该算数据结构,它的本质是分治的思想。
问题引入
对于一个序列 \(a\),求是否存在 \((l, r)\) 使得 \(\sum\limits_{i=l}^{r}a_i=k\)。\(n\le 10^6,|a_i|\le 10^9\)。
本题显然是有其它的做法的,由于学的是点分治,所以考虑分治做法。
首先对 \(a\) 求前缀和,记这个数组为 \(s\)。
假设当前的分治区间为 \([l, r]\),分治的左儿子和右儿子维护的区间分别为 \([l,mid],(mid, r]\)。当它的子问题已经完成求解时,考虑如何合并,即计算满足条件的 \([l', r'](l\le mid < r)\)。这时候将 \(s_l\sim s_{mid}\) 和 \(s_{mid+1}\sim s_r\) 分别排序,然后就可以维护两个指针 \(i\in[l, mid], j\in(mid,r]\),\(j\) 表示的是 \(\ge s_i+k\) 的最小的 \(j\),归并 \(s\) 数组可以做到 \(\mathcal{O}(n\log n)\)。
算法介绍
令当前的分治的树为 \(T\)。参考于普通的序列分治,考虑将其分为更小规模进行求解,为保证复杂度,可以选取树的重心。即 \(T\) 的所有儿子的子树为 \(T_1, T_2, T_3,\dots,T_c\) 时,对于所有 \(T_i\) 均选取它的重心进行求解,然后按一定顺序合并起来即可。
例题
I. P3806 【模板】点分治 1
给定一棵大小为 \(n\) 的树,边带权,\(m\) 次询问是否存在长度为 \(k\) 的路径。\(n\le 10^4,m\le 100,k\le 10^7\)。
板子题,可以在分治时开桶统计,也可以维护两个指针。
代码:
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define vi vector < int >
#define eb emplace_back
#define pii pair < int, int >
#define fi first
#define se second
#define TIME 1e3 * clock() / CLOCKS_PER_SEC
int Mbe;
mt19937_64 rng(35);
constexpr int N = 1e4 + 10;
int n, m, rt;
int q[110], ans[110], sz[N], mx[N], vis[N];
vector < pii > e[N];
void findroot(int u, int f, int num) {
sz[u] = 1, mx[u] = 0;
for(auto i : e[u]) {
int v = i.fi, w = i.se;
if(v == f || vis[v]) continue;
findroot(v, u, num);
sz[u] += sz[v];
mx[u] = max(mx[u], sz[v]);
}
mx[u] = max(mx[u], num - sz[u]);
if(mx[u] < mx[rt]) rt = u;
}
vector < pii > gd;
void getdep(int u, int fath, int anc, int d) {
gd.eb(pii(d, anc));
for(auto i : e[u]) {
int v = i.fi;
if(v == fath || vis[v]) continue;
getdep(v, u, anc, d + i.se);
}
}
void divide(int u) {
vis[u] = 1;
for(auto i : e[u]) {
int v = i.fi;
if(vis[v]) continue;
getdep(v, u, v, i.se);
}
gd.eb(pii(0, u));
sort(gd.begin(), gd.end());
for(int i = 1; i <= m; ++i) {
int l = 0, r = gd.size() - 1;
while(l < r && !ans[i]) {
if(gd[l].fi + gd[r].fi > q[i]) --r;
else if(gd[l].fi + gd[r].fi < q[i]) ++l;
else {
ans[i] |= gd[l].se != gd[r].se;
if(gd[l].fi == gd[l + 1].fi) ++l;
else --r;
}
}
}
gd.clear();
for(auto i : e[u]) {
int v = i.fi;
if(vis[v]) continue;
rt = 0;
findroot(v, u, sz[v]);
divide(rt);
}
}
int Med;
int main() {
fprintf(stderr, "%.3lf MB\n", (&Mbe - &Med) / 1048576.0);
ios :: sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin >> n >> m;
for(int i = 1; i < n; ++i) {
int u, v, w;
cin >> u >> v >> w;
e[u].eb(pii(v, w));
e[v].eb(pii(u, w));
}
for(int i = 1; i <= m; ++i) cin >> q[i];
mx[0] = N;
findroot(1, 0, n);
divide(rt);
for(int i = 1; i <= m; ++i) cout << (ans[i] ? "AYE" : "NAY") << "\n";
cerr << TIME << "ms\n";
return 0;
}
II. P2634 [国家集训队] 聪聪可可
给定一棵大小为 \(n\) 的树,边带权,求有多少条路径的长度为 \(3\) 的倍数。\(n\le 2\times10^4\)。
这题有很多种解法,但是我们使用点分治解决。
因为求的是 \(3\) 的倍数的数量,可以直接开一个大小为 \(3\) 的桶,记录一下每条路径到当前重心 \(r\) 的路径余数为 \(i\) 的数量,拼起来为 \(3\) 的倍数的路径的方案 \((1,2),(2,1),(0,0)\) 三种,先全部拼起来,然后减掉子树内的路径互相拼的情况即可。
代码:
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define vi vector < int >
#define eb emplace_back
#define pii pair < int, int >
#define fi first
#define se second
#define TIME 1e3 * clock() / CLOCKS_PER_SEC
int Mbe;
mt19937_64 rng(35);
ll gcd(ll a, ll b) {
while(b) swap(b, a %= b);
return a;
}
constexpr int N = 2e4 + 10;
int n, rt;
ll ans;
int vis[N], mx[N], sz[N];
int head[N], cnt_e;
struct edge {
int to, w, nxt;
} e[N << 1];
void adde(int u, int v, int w) {
++cnt_e, e[cnt_e].to = v, e[cnt_e].w = w, e[cnt_e].nxt = head[u], head[u] = cnt_e;
}
void findroot(int u, int fath, int num) {
mx[u] = 0, sz[u] = 1;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fath || vis[v]) continue;
findroot(v, u, num);
sz[u] += sz[v];
mx[u] = max(mx[u], sz[v]);
}
mx[u] = max(mx[u], num - sz[u]);
if(mx[u] < mx[rt]) rt = u;
}
ll t[3];
void getdep(int u, int fath, int dep) {
++t[dep % 3];
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fath || vis[v]) continue;
getdep(v, u, dep + e[i].w);
}
}
ll calc(int u, int d) {
t[0] = t[1] = t[2] = 0;
getdep(u, 0, d);
return t[1] * t[2] * 2 + t[0] * t[0];
}
void divide(int u) {
vis[u] = 1;
ans += calc(u, 0);
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(vis[v]) continue;
ans -= calc(v, e[i].w);
}
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(vis[v]) continue;
rt = 0;
findroot(v, u, sz[v]);
divide(rt);
}
}
int Med;
int main() {
fprintf(stderr, "%.3lf MB\n", (&Mbe - &Med) / 1048576.0);
ios :: sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin >> n;
for(int i = 1; i < n; ++i) {
int u, v, w;
cin >> u >> v >> w;
adde(u, v, w);
adde(v, u, w);
}
mx[0] = N;
findroot(1, 0, n);
divide(rt);
ll g = gcd(ans, n * 1ll * n);
cout << ans / g << "/" << n * 1ll * n / g << "\n";
cerr << TIME << "ms\n";
return 0;
}
III. P4149 [IOI2011] Race
给一棵边带权的树,求一条边数最少且路径长度等于 \(k\) 的路径,输出这条路径的长度。\(n\le 2\times 10^5,k\le 10^6\)。
因为 \(k\le 10^6\),可以直接开一个桶,每次先扫描一棵子树更新答案后再将这个子树的信息加到桶里面即可,桶维护的是长度为 \(i\) 的到当前重心 \(r\) 路径的边数最小值。桶的清空可以用一个栈记录下变动的位置。
代码:
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define vi vector < int >
#define eb emplace_back
#define pii pair < int, int >
#define fi first
#define se second
#define TIME 1e3 * clock() / CLOCKS_PER_SEC
int Mbe;
mt19937_64 rng(35);
constexpr int N = 2e5 + 10, M = 1e6 + 10, inf = 0x3f3f3f3f;
int n, m, rt, ans = inf;
int mx[N], sz[N], vis[N];
int head[N], cnt_e;
struct edge {
int to, w, nxt;
} e[N << 1];
void adde(int u, int v, int w) {
++cnt_e, e[cnt_e].to = v, e[cnt_e].w = w, e[cnt_e].nxt = head[u], head[u] = cnt_e;
}
void findroot(int u, int fath, int num) {
mx[u] = 0, sz[u] = 1;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fath || vis[v]) continue;
findroot(v, u, num);
sz[u] += sz[v];
mx[u] = max(mx[u], sz[v]);
}
mx[u] = max(mx[u], num - sz[u]);
if(mx[u] < mx[rt]) rt = u;
}
int buc[M], stk[M], tp;
void calc(int u, int fath, int d, int val) {
if(val > m) return;
ans = min(ans, buc[m - val] + d);
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fath || vis[v]) continue;
calc(v, u, d + 1, val + e[i].w);
}
}
void getbuc(int u, int fath, int d, int val) {
if(val > m) return;
if(buc[val] == inf) stk[++tp] = val;
buc[val] = min(buc[val], d);
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fath || vis[v]) continue;
getbuc(v, u, d + 1, val + e[i].w);
}
}
void divide(int u) {
vis[u] = 1;
buc[0] = 0, stk[++tp] = 0;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(vis[v]) continue;
calc(v, u, 1, e[i].w);
getbuc(v, u, 1, e[i].w);
}
while(tp) buc[stk[tp--]] = inf;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(vis[v]) continue;
rt = 0;
findroot(v, u, sz[v]);
divide(rt);
}
}
int Med;
int main() {
fprintf(stderr, "%.3lf MB\n", (&Mbe - &Med) / 1048576.0);
ios :: sync_with_stdio(0);
cin.tie(0); cout.tie(0);
memset(buc, inf, sizeof(buc));
cin >> n >> m;
for(int i = 1; i < n; ++i) {
int u, v, w;
cin >> u >> v >> w;
++u, ++v;
adde(u, v, w);
adde(v, u, w);
}
mx[0] = N;
findroot(1, 0, n);
divide(rt);
cout << (ans == inf ? -1 : ans) << "\n";
cerr << TIME << "ms\n";
return 0;
}