点分治小结
算法介绍
点分治,顾名思义,是一种对点进行分治的数据结构。(树上的点)
多用于在树上进行有限制的路径计数。
比如:求树上长度小于$ k$ 的简单路径条数。\((n \leq 10000)\)
直接做肯定是补星的。所以就需要点分治这种东西了。
需要统计的路径肯定有这么两种:
- 1.经过根节点$ root $的路径
- 2.不经过根节点\(root\)的路径
显然第二种路径对于某个节点\(u\),属于第一种路径。所以分治解决即可。
我们来考虑第一种情况如何解决。
处理出一个数组\(d\),表示从当前根节点\(u\),到各个子节点的距离。
那么我们要统计的显然就是\(d[u]+d[v]\leq k\)的路径\((u,v)\)的个数。
这个东西可以通过在dfs求这个数组时顺便把所有的\(d\)值记录下来,排序之后让他们具有单调性。
然后双指针扫一下就好(合法状态就是\(d[l]+d[r]\leq k\))那么当指针在\(l\)时,对答案的贡献就是\(r-l\)(不能重复选自己,所以不+1)
然后现在考虑一种情况。当\(u,v\)都在当前根节点的同一个子树里面。这样子的话,路径\((u,v)\)如果经过根节点就不是一条简单路径了(重边)。如何解决呢?
容斥的思想!
对于每个子树,分别处理它其中的子节点的d值,给答案减去就行了!
代码大概就长这个样子
void dfs(int u) {
vis[u] = 1;
ans += solve(u, 0); //所有情况
for(int i = head[u]; i; i = e[i].nxt) {
if(vis[e[i].to]) continue;
int v = e[i].to;
ans -= solve(v, e[i].v); //减掉不合法情况
//下面是找重心的代码,后面会解释为什么要找重心
now_sz = inf, root = 0; sz = siz[v];
find_root(v, 0);
dfs(root);
}
}
先不管为什么要找重心。我们总结一下算法流程:
- 1.找一个根节点root
- 2.对root计算出d数组并计算答案
- 3.把root删了,对root的各个子树执行流程1,2
复杂度是多少呢?粗略估计一下是\(O(Tnlogn) \),\(T\)是树的层数。(这里有个\(log\)是因为用了排序)
显然我们要让这个树优美一点,身材圆润一点,不能瘦成一条链,不然复杂度就变成\(O(n^2logn) \)了。
那这个根节点怎么找呢?树的重心!
将重心当做根节点,可以保证树是\(log\)层的!
那么复杂度就是$O(nlog^2n) \(了!(如果不使用排序的话(比如一些题是用到的桶),那么复杂度是\)O(nlogn)$)
还有就是关于点分治这里的重心有两种找法。一种就是上面那样的,另外一种就是改了一句
sz = siz[v];
->sz = siz[v] > siz[u] ? totsiz - siz[u] : siz[v];
实际上第二种才是对的,因为v可能在上次处理siz数组时是u的父亲(这是一棵无根树!)
但是复杂度并不会退化qwq,有神仙证明了。链接
例题:
POJ1741 tree
真正的模板题。就是我上面提到的那个问题。
直接点分一下就好了。每次将距离排序一下,然后双指针扫一扫,每次合法答案就是r-l,容斥一下将不合法情况减去即可。注意找重心不要写错不然复杂度就炸了。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define inf 0x3f3f3f3f
#define ll long long
#define N 100010
inline void in(int &x) {
x = 0; int f = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-') f = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
x = x * 10 + c - '0';
c = getchar();
}
x *= f;
}
int n, k, d[N], cnt, head[N], ans;
int vis[N], siz[N];
struct edge {
int to, nxt, v;
}e[N<<1];
void ins(int u, int v, int w) {
e[++cnt] = (edge) {v, head[u], w};
head[u] = cnt;
}
int now_sz = inf, root = 0, sz;
void find_root(int u, int fa) {
siz[u] = 1;
int res = 0;
for(int i = head[u]; i; i = e[i].nxt) {
if(vis[e[i].to] || e[i].to == fa) continue;
int v = e[i].to;
find_root(v, u);
siz[u] += siz[v];
res = max(res, siz[v]);
}
res = max(res, sz - siz[u]);
if(res < now_sz) now_sz = res, root = u;
}
int a[N], tot;
void get_dis(int u, int fa) {
a[++tot] = d[u];
for(int i = head[u]; i; i = e[i].nxt) {
if(vis[e[i].to] || e[i].to == fa) continue;
int v = e[i].to;
d[v] = d[u] + e[i].v;
get_dis(v, u);
}
}
int solve(int u, int dis) {
d[u] = dis; tot = 0;
get_dis(u, u);
sort(a + 1, a + tot + 1);
int l = 1, r = tot, res = 0;
for(; l < r; ++l) {
while(l < r && a[l] + a[r] > k) --r;
if(l < r) res += r - l;
}
return res;
}
void dfs(int u) {
vis[u] = 1;
ans += solve(u, 0);
for(int i = head[u]; i; i = e[i].nxt) {
if(vis[e[i].to]) continue;
int v = e[i].to;
ans -= solve(v, e[i].v);
now_sz = inf, root = 0; sz = siz[v];
find_root(v, 0);
dfs(root);
}
}
int main() {
while(~scanf("%d%d", &n, &k) && n && k) {
ans = 0; cnt = 0;
memset(head, 0, sizeof(head));
memset(vis, 0, sizeof(vis));
for(int i = 1; i < n; ++i) {
int u, v, w; in(u), in(v), in(w);
ins(u, v, w), ins(v, u, w);
}
dfs(1);
printf("%d\n", ans);
}
}
BZOJ2152: 聪聪可可
求倍数为3的路径数。
考虑\(mod\ 3\)意义下的路径,为0显然可以互相拼起来,贡献是\(sum[0]^2\)。1和2可以互相拼,而且起点终点互换,所以贡献是\(sum[1]*sum[2]*2\),点分治计算这两个即可。总方案数是\(n^2\),所以答案就是\(\frac{sum}{n^2}\)
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define inf 0x3f3f3f3f
#define ll long long
#define N 100010
inline void in(int &x) {
x = 0; int f = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-') f = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
x = x * 10 + c - '0';
c = getchar();
}
x *= f;
}
int n, k, d[N], cnt, head[N], ans;
int vis[N], siz[N], sum[3];
struct edge {
int to, nxt, v;
}e[N<<1];
void ins(int u, int v, int w) {
e[++cnt] = (edge) {v, head[u], w};
head[u] = cnt;
}
int now_siz, sz, root;
void find_root(int u, int fa) {
siz[u] = 1; int res = 0;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa || vis[v]) continue;
find_root(v, u);
siz[u] += siz[v];
res = max(res, siz[v]);
}
res = max(res, sz - siz[u]);
if(res < now_siz) now_siz = res, root = u;
}
void get_dis(int u, int fa) {
sum[d[u]%3]++;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(vis[v] || v == fa) continue;
d[v] = d[u] + e[i].v;
get_dis(v, u);
}
}
int solve(int u, int dis) {
d[u] = dis; sum[0] = sum[1] = sum[2] = 0;
get_dis(u, u);
return sum[0] * sum[0] + sum[1] * sum[2] * 2;
}
void dfs(int u) {
ans += solve(u, 0);
vis[u] = 1;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(vis[v]) continue;
ans -= solve(v, e[i].v);
now_siz = inf; sz = siz[v]; root = 0;
find_root(v, u);
dfs(root);
}
}
int main() {
in(n);
for(int i = 1; i < n; ++i) {
int u, v, w; in(u), in(v), in(w);
ins(u, v, w), ins(v, u, w);
}
now_siz = inf; root = 0; sz = n;
find_root(1, 1);
dfs(root);
int now = n * n, g = __gcd(now, ans);
printf("%d/%d\n", ans / g, now / g);
}
LuoguP3806 【模板】点分治1
注意这题数据很水...
求长度为k的路径是否存在。多次询问(询问数\(\leq 100\))
这题效率有点奇怪...
自己估算了一下是\(O(mnlog^2n)\)。
对长度正好k的话,其实用个桶标记就好了,实际上和小于k没多大区别的。
考虑先将询问离线,然后在点分治过程中对所有答案进行判定。处理出d[]表示到节点i到当前根的距离。那么照例是拼路径,但是现在不是求方案总数而是求有没有这个方案,看起来不能容斥了。但是实际上可以的:考虑先对根u solve一遍,给所有询问加上这次的结果,然后对每个子节点计算一遍,给所有询问减掉这次的结果就好了。
具体的话看看代码吧
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define inf 0x3f3f3f3f
#define ll long long
#define N 100010
#define lim 10000000
inline void in(int &x) {
x = 0; int f = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-') f = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
x = x * 10 + c - '0';
c = getchar();
}
x *= f;
}
int top, n, m, d[N], cnt, head[N], ans[110];
int vis[N], siz[N], q[110], st[N], s[10000010];
struct edge {
int to, nxt, v;
}e[N<<1];
void ins(int u, int v, int w) {
e[++cnt] = (edge) {v, head[u], w};
head[u] = cnt;
}
int now_sz = inf, root, sz;
void find_root(int u, int fa) {
siz[u] = 1; int res = 0;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa || vis[v]) continue;
find_root(v, u);
res = max(res, siz[v]);
siz[u] += siz[v];
}
res = max(res, sz - siz[u]);
if(res < now_sz) now_sz = res, root = u;
}
void get_dis(int u, int fa) {
st[++top] = d[u];
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa || vis[v]) continue;
d[v] = d[u] + e[i].v;
get_dis(v, u);
}
}
void solve(int u, int dis, int op) {
top = 0; d[u] = dis; get_dis(u, 0);
for(int i = 1; i <= top; ++i) if(st[i] <= lim) s[st[i]]++;
for(int i = 1; i <= m; ++i) {
for(int j = 1; j <= top; ++j) if(q[i] >= st[j]) ans[i] += s[q[i] - st[j]] * op;
}
for(int i = 1; i <= top; ++i) if(st[i] <= lim) s[st[i]]--;
}
void dfs(int u) {
vis[u] = 1;
solve(u, 0, 1);
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(vis[v]) continue;
top = 0; d[v] = e[i].v;
solve(v, e[i].v, -1);
now_sz = inf, root = 0, sz = siz[v];
find_root(v, u);
dfs(root);
}
}
int main() {
in(n), in(m);
for(int i = 1; i < n; ++i) {
int u, v, w; in(u), in(v), in(w);
ins(u, v, w), ins(v, u, w);
}
for(int i = 1; i <= m; ++i) in(q[i]);
sz = n; now_sz = inf; root = 0;
find_root(1, 1); dfs(root);
for(int i = 1; i <= m; ++i) puts(ans[i] ? "AYE" : "NAY");
}
CF161D Distance in Tree
求长度等于k的路径数...就很烦....这种一般都要分类讨论
需要分类讨论一下,同样是套路点分然后开个桶,然后分\(k-v[i]=v[i]\)和不等两种情况,显然相等的话答案就是\(cnt[v[i]]*(cnt[v[i]]-1)/2\).不相等的话用乘法原理考虑一下,\(cnt[v[i]]*cnt[k-v[i]]\),注意每次统计完之后就要把cnt清空。
#include <bits/stdc++.h>
#define ll long long
#define inf 0x3f3f3f3f
#define il inline
namespace io {
#define in(a) a = read()
#define out(a) write(a)
#define outn(a) out(a), putchar('\n')
#define I_int ll
inline I_int read() {
I_int x = 0, f = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-') f = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
x = x * 10 + c - '0';
c = getchar();
}
return x * f;
}
char F[200];
inline void write(I_int x) {
if (x == 0) return (void) (putchar('0'));
I_int tmp = x > 0 ? x : -x;
if (x < 0) putchar('-');
int cnt = 0;
while (tmp > 0) {
F[cnt++] = tmp % 10 + '0';
tmp /= 10;
}
while (cnt > 0) putchar(F[--cnt]);
}
#undef I_int
}
using namespace io;
using namespace std;
#define N 100010
int n, k;
int cnt, head[N], vis[N], d[N];
struct edge {
int to, nxt;
}e[N<<1];
void ins(int u, int v) {
e[++cnt] = (edge) {v, head[u]};
head[u] = cnt;
}
int siz[N], now_sz = inf, root, sz;
void find_root(int u, int fa) {
siz[u] = 1; int res = 0;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa || vis[v]) continue;
find_root(v, u);
siz[u] += siz[v];
res = max(res, siz[v]);
}
res = max(res, sz - siz[u]);
if(res < now_sz) now_sz = res, root = u;
}
int top, st[N], s[N];
void get_dis(int u, int fa) {
st[++top] = d[u]; if(d[u] <= k) ++s[d[u]];
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa || vis[v]) continue;
d[v] = d[u] + 1;
get_dis(v, u);
}
}
ll solve(int u, int dis) {
d[u] = dis; top = 0; get_dis(u, 0);
ll ans = 0;
for(int i = 1; i <= top; ++i)
if(st[i] <= k) {
if(st[i] * 2 == k) ans += 1ll * s[st[i]] * (s[st[i]] - 1) / 2ll;
else ans += 1ll * s[k - st[i]] * s[st[i]];
s[st[i]] = s[k - st[i]] = 0;
}
return ans;
}
ll ans = 0;
void dfs(int u) {
vis[u] = 1; ans += solve(u, 0);
int totsiz = sz;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(vis[v]) continue;
ans -= solve(v, 1);
sz = siz[v] > siz[u] ? totsiz - siz[u] : siz[v];
now_sz = inf; root = 0;
find_root(v, 0);
dfs(root);
}
}
int main() {
in(n), in(k);
for(int i = 1; i < n; ++i) {
int u = read(), v = read();
ins(u, v), ins(v, u);
}
now_sz = inf; sz = n; root = inf;
find_root(1, 0);
dfs(root);
outn(ans);
}