点分治 学习笔记
板子题
题目传送门
给定一棵 \(n\) 个节点的树,每条边有边权,求出树上两点距离小于等于 \(k\) 的点对数量。
\(n\le 4\times 10^4\)
算法解析
显然我们发现如果计算从每个节点开始的点对数量是 \(O(n^2)\) 的,显然是不行的,但是我们发现这是一个计数题,所以我们可以做点分治。
我们发现如果我们选取一个节点作为根,那么所有的节点就会分为两种:过根的路径和没有过根的路径,没有过根的路径我们可以把根删去再对每一棵子树计算答案,直到只剩下一个节点。
如果只计算过根的路径的话,首先我们先通过 DP 求出所有节点到根的距离,然后用排序用双指针法求出答案,再减去两端再同一棵子树的路径数量,就可以算出过根的路径总和;当然也可以用线段树来计算,这里不过多叙述,单次处理复杂度 \(O(n\log n)\)。
但是我们发现,如果我们随机选一个节点最为根的话,如果我们每次都选了叶子节点,就会导致我们要去掉很多次才能算出答案,最坏情况下需要递归 \(n\) 次才可以求解,这样的复杂度是 \(O(n^2\log n)\) ,显然这不是最优的。
我们发现,如果每次将根选作这棵树的 重心 ,那么效率将会大大提升,因为当我们选取树的重心作为根的时候,去掉这个节点之后剩下的最大联通块的节点数是最小的,所以这样就会让一个树分成更小的部分,从而使分割的次数更少。所以说我们需要选择树的重心作为根来计算,只需要最多递归 \(\log n\) 次。这样算法复杂度就是 \(O(n\log^2n)\) 了。
为了方便,删除节点并不是真正的删除,显然只需要打一个标记即可,对应再代码里面就是 mark
数组。
代码
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
#define maxn 40039
using namespace std;
//#define debug
typedef int Type;
inline Type read(){
Type sum=0;
int flag=0;
char c=getchar();
while((c<'0'||c>'9')&&c!='-') c=getchar();
if(c=='-') c=getchar(),flag=1;
while('0'<=c&&c<='9'){
sum=(sum<<1)+(sum<<3)+(c^48);
c=getchar();
}
if(flag) return -sum;
return sum;
}
int n,m,u,v,w;
int head[maxn],nex[maxn<<1],to[maxn<<1],c[maxn<<1],kkk;
#define add(x,y,z) nex[++kkk]=head[x];\
head[x]=kkk; to[kkk]=y; c[kkk]=z;
int root,num,minx,siz[maxn],dis[maxn],mark[maxn];
void getnum(int x,int pre){
if(mark[x]) return; num++;
for(int i=head[x];i;i=nex[i])
if(to[i]!=pre) getnum(to[i],x);
return;
}
void getroot(int x,int pre){//找重心
if(mark[x]) return;
int maxx=0; siz[x]=1;
for(int i=head[x];i;i=nex[i])
if(to[i]!=pre){
getroot(to[i],x);
siz[x]+=siz[to[i]];
maxx=max(maxx,siz[to[i]]);
}
if(max(maxx,num-siz[x])<minx){
minx=max(maxx,num-siz[x]);
root=x;
}
return;
}
void getdis(int x,int pre){
if(mark[x]) return;
for(int i=head[x];i;i=nex[i])
if(to[i]!=pre){
dis[to[i]]=dis[x]+c[i];
getdis(to[i],x);
}
return;
}
int cnt,mem[maxn];
void find(int x,int pre){
if(mark[x]) return;
mem[++cnt]=dis[x];
for(int i=head[x];i;i=nex[i])
if(to[i]!=pre) find(to[i],x);
}
int calc(int x){//计算
cnt=0; find(x,-1);
sort(mem+1,mem+cnt+1);
int ans=0,l=1,r=cnt;
while(l<=r)
if(mem[l]+mem[r]<=m) ans+=r-l,l++;
else r--;
return ans;
}
int solve(int x){//得到答案
num=0; getnum(x,0); if(num==1) return 0;
minx=0x7f7f7f7f; getroot(x,0); dis[root]=0;
getdis(root,-1); int ans=calc(root); mark[root]=1;
for(int i=head[root];i;i=nex[i])
if(!mark[to[i]]) ans-=calc(to[i]);
for(int i=head[root];i;i=nex[i])
if(!mark[to[i]]) ans+=solve(to[i]);
return ans;
}
int main(){
//freopen("1.in","r",stdin);
//freopen(".out","w",stdout);
n=read();
if(n==0&&m==0) return 0;
for(int i=1;i<n;i++){
u=read(); v=read(); w=read();
add(u,v,w) add(v,u,w)
}
m=read();
printf("%d\n",solve(1));
return 0;
}