点分治
点分治是树分治的一种形式,通常用来求满足某种要求的路径数量。
引入
有 \(n\) 个数,问是否存在一个 \(l, r\) 使得区间和为 \(k\),强行用分治做,可以将数组分成两半,递归后处理左边 \(l\) 右边 \(r\),然后就用前缀和加 \(map\) 加归并的并做就可以了。
思路
首先考虑一个暴力:我们对于每一个节点都可以进行查找答案,我们可以暴力的查找答案,对于每一个子树按顺序统计归并,不过这样子会爆炸,所以我们需要一种更加高效的方法
考虑上一个节点(为根)和下一个节点(为根)之间有什么关系,我们想到,如果一次答案在这两点中统计了多次,我们就会有计算冗余,于是我们最佳的方案就是下一个节点计算答案的过程中不算到上一个节点。不过这样的时间复杂度还是有问题。
考虑排布顺序,我们发现引入中要求的是和为 \(k\),所以我们人为设定一个答案范围,这个答案范围就是我们规范的计算位置,对于每个根我们的计算都只在范围内计算,那答案范围我们应该如何定?
说实话,是个人都应该想到了二进制,你想不明白自己画个图吧,最简单的图,都不用点分治
那么好,我们发现答案范围类似二进制后,我们变应该考虑这个根排布在哪,考虑到一个范围可以同时管上下,于是我们可以在上一个根的每一个子树的直径上找到中点,那个中点将会是我们的下一个根(说人话:子树上的重心)
为什么如此,考虑到我们设定了答案范围(别误会,这是个虚拟的)假设我们在当前根的答案范围内找到了上一个根,这个时候不符合我们的预期,也就是说上一个根最好离当前根有答案范围的距离,考虑到我们的答案范围和二进制很像,所以就是子树直径上的中心(子树上的重心)
code
#include <iostream>
#include <vector>
using namespace std;
const int MaxN = 50010;
int cnt[MaxN], sz[MaxN], st[MaxN], tot, n, k, ans;
vector<int> g[MaxN];
bool vis[MaxN];
int find_fatbigest(int x, int fa) {
sz[x] = 1;
int maxs = 0, res = -1;
for (int i : g[x]) {
if (i == fa || vis[i]) continue;
res = find_fatbigest(i, x);
if (res != -1) {
return res;
}
sz[x] += sz[i], maxs = max(maxs, sz[i]);
}
maxs = max(maxs, n - sz[x]);
if (maxs * 2 <= n) {
res = x;
sz[fa] = n - sz[x];
}
return res;
}
void G(int x, int fa, int sum) {
if (sum > k) {
return;
}
st[++tot] = sum;
ans += cnt[k - sum] + (sum == k);
for (int i : g[x]) {
if (i == fa || vis[i]) continue;
G(i, x, sum + 1);
}
}
void DFS(int x) {
for (int i : g[x]) {
if (vis[i]) continue;
int tmp = tot;
G(i, x, 1);
for (int i = tmp + 1; i <= tot; i++) {
cnt[st[i]]++;
}
}
for (int i = 1; i <= tot; i++) {
cnt[st[i]]--;
}
tot = 0;
vis[x] = 1;
for (int i : g[x]) {
if (vis[i]) continue;
n = sz[i];
DFS(find_fatbigest(i, x));
}
}
int main() {
cin >> n >> k;
for (int i = 1, u, v; i < n; i++) {
cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
DFS(find_fatbigest(1, 0));
cout << ans << endl;
return 0;
}
P3806 【模板】点分治 1
考虑离线,将所有的 \(k\) 弄下来,然后统计答案是统计 \(m\) 次,因为 \(m\) 小的可怜,所以能过。
code
#include <iostream>
#include <vector>
using namespace std;
using pii = pair<int, int>;
const int MaxN = 1e4 + 10, MaxM = 1e8 + 10;
struct S {
bool cnt[MaxM];
int sz[MaxN], st[MaxN], k[MaxN], tot, n, m, ans;
vector<pii> g[MaxN];
bool vis[MaxN], flag[MaxM];
S() {
tot = n = m = ans = 0;
for (int i = 0; i < MaxN; i++) {
cnt[i] = sz[i] = st[i] = vis[i] = k[i] = 0;
}
}
int find_fatbigest(int x, int fa) {
sz[x] = 1;
int maxs = 0, res = -1;
for (auto i : g[x]) {
if (i.first == fa || vis[i.first]) continue;
res = find_fatbigest(i.first, x);
if (res != -1) {
return res;
}
sz[x] += sz[i.first], maxs = max(maxs, sz[i.first]);
}
maxs = max(maxs, n - sz[x]);
if (maxs * 2 <= n) {
res = x;
sz[fa] = n - sz[x];
}
return res;
}
void G(int x, int fa, int sum) {
st[++tot] = sum;
for (int i = 1; i <= m; i++) {
if (sum <= k[i]) {
flag[k[i]] |= cnt[k[i] - sum] + (sum == k[i]);
}
}
for (auto i : g[x]) {
if (i.first == fa || vis[i.first]) continue;
G(i.first, x, sum + i.second);
}
}
void DFS(int x) {
for (auto i : g[x]) {
if (vis[i.first]) continue;
int tmp = tot;
G(i.first, x, i.second);
for (int i = tmp + 1; i <= tot; i++) {
cnt[st[i]] = 1;
}
}
for (int i = 1; i <= tot; i++) {
cnt[st[i]] = 0;
}
tot = 0;
vis[x] = 1;
for (auto i : g[x]) {
if (vis[i.first]) continue;
n = sz[i.first];
DFS(find_fatbigest(i.first, x));
}
}
void solve(int len, int q, int x[], int a[][3]) {
n = len, m = q;
for (int i = 1; i <= m; i++) {
k[i] = x[i];
}
for (int i = 1; i < n; i++) {
g[a[i][0]].push_back({a[i][1], a[i][2]});
g[a[i][1]].push_back({a[i][0], a[i][2]});
}
DFS(find_fatbigest(1, 0));
}
};
int a[MaxN][3], k[MaxN], n, m;
S t;
int main() {
ios::sync_with_stdio(0), cin.tie(0);
cin >> n >> m;
for (int i = 1; i < n; i++) {
cin >> a[i][0] >> a[i][1] >> a[i][2];
}
for (int i = 1; i <= m; i++) {
cin >> k[i];
}
t.solve(n, m, k, a);
for (int i = 1; i <= m; i++) {
cout << (t.flag[k[i]] ? "AYE" : "NAY") << '\n';
}
return 0;
}