点分治学习笔记
点分治学习笔记
淀粉质一般用来处理树上的点对问题,代码复杂度适中 。
本文中通过以下例题来解释点分治,这是一道经典题 (poj1741):
给出一棵树,有边权,求树上路径长度不超过 \(k\) 的条数。
可以直接考虑一棵以节点 \(u\) 为根的树的答案,显然可以分成两种情况:
1.经过 u 的路径;
2.不经过 u 的路径。
接着从 \(u\) 开始 \(dfs\) 处理出子树上的点到 \(u\) 的距离,将所有的距离放进一个数组 \(c\) ,然后排序后用尺取法就可以得到经过 \(u\) 的路径条数,这样就解决了情况 \(1\) ??????
显然不是这样,还需要考虑一种情况:某两点的 \(lca\) 并不是 \(u\) ,但是他们到 \(u\) 的距离之和并没有超过 \(k\),只需要把 \(u\) 的子节点上的答案减掉就好了(可以看代码理解)。
每次求出经过 \(u\) 的路径,就可以将 \(u\) 从树中删除,因为它不再对答案有贡献;这样就将树分成了若干棵树,递归求解即可。
选择删除的点时可以选择重心,每次分出来的树大小最少减少为原来的一半,所以最多递归 \(logn\) 层。
这样复杂度就是 \(O(n\) \(log^2n)\) 还有一个 \(log\) 是在尺取时的排序操作。
代码如下:
#include <cstdio>
#include <cstring>
#include <algorithm>
inline int in() {
int x=0;char c=getchar();bool f=false;
while(c<'0'||c>'9') f|=c=='-', c=getchar();
while(c>='0'&&c<='9') x=(x<<1)+(x<<3)+(c^48), c=getchar();
return x;
}
const int N = 1e4+5, inf = 0x3f3f3f3f;
struct edge {
int next, to, w;
}e[N<<1];
int cnt=1, head[N], n, k, rt, tot, res, dep[N], siz[N], f[N], b[N];
bool vis[N];
inline void jb(int u, int v, int w) {
e[++cnt].next=head[u];
e[cnt].to=v;
e[cnt].w=w;
head[u]=cnt;
}
void dfs(int u, int fa) {
siz[u]=1, f[u]=0;
for(int i=head[u];i;i=e[i].next) {
int v=e[i].to;
if(v==fa||vis[v]) continue;
dfs(v, u);
f[u]=std::max(f[u], siz[v]);
siz[u]+=siz[v];
}
f[u]=std::max(f[u], tot-siz[u]);
if(f[u]<f[rt]) rt=u;
}
void get_dep(int u, int fa) {
b[++b[0]]=dep[u];
for(int i=head[u];i;i=e[i].next) {
int v=e[i].to;
if(v==fa||vis[v]) continue;
dep[v]=dep[u]+e[i].w;
get_dep(v, u);
}
}
inline int work(int u, int s) {
dep[u]=s, b[0]=0;
get_dep(u, -1);
std::sort(b+1, b+1+b[0]);
int ret=0;
for(int l=1, r=b[0];l<r;++l) {
while(l<r&&b[l]+b[r]>k) --r;
ret+=r-l;
}
return ret;
}
void pp(int u) {
res+=work(u, 0);
vis[u]=true;
for(int i=head[u];i;i=e[i].next) {
int v=e[i].to;
if(vis[v]) continue;
res-=work(v, e[i].w);
rt=0, tot=siz[v], dfs(v, u), pp(rt);
}
}
inline void init() {
cnt=1, res=0;
memset(head, 0, sizeof(head));
memset(vis, 0, sizeof(vis));
}
int main() {
while(true) {
init();
n=in(), k=in();
if(!n&&!k) break;
for(int i=1, x, y, z;i<n;++i) {
x=in(), y=in(), z=in();
jb(x, y, z), jb(y, x, z);
}
tot=n;
f[rt=0]=inf;
dfs(1, -1);
pp(rt);
printf("%d\n", res);
}
return 0;
}
放几道练习题: