点分治 学习笔记

板子题

题目传送门
给定一棵 \(n\) 个节点的树,每条边有边权,求出树上两点距离小于等于 \(k\) 的点对数量。
\(n\le 4\times 10^4\)

算法解析

显然我们发现如果计算从每个节点开始的点对数量是 \(O(n^2)\) 的,显然是不行的,但是我们发现这是一个计数题,所以我们可以做点分治。
我们发现如果我们选取一个节点作为根,那么所有的节点就会分为两种:过根的路径和没有过根的路径,没有过根的路径我们可以把根删去再对每一棵子树计算答案,直到只剩下一个节点。
如果只计算过根的路径的话,首先我们先通过 DP 求出所有节点到根的距离,然后用排序用双指针法求出答案,再减去两端再同一棵子树的路径数量,就可以算出过根的路径总和;当然也可以用线段树来计算,这里不过多叙述,单次处理复杂度 \(O(n\log n)\)
但是我们发现,如果我们随机选一个节点最为根的话,如果我们每次都选了叶子节点,就会导致我们要去掉很多次才能算出答案,最坏情况下需要递归 \(n\) 次才可以求解,这样的复杂度是 \(O(n^2\log n)\) ,显然这不是最优的。
我们发现,如果每次将根选作这棵树的 重心 ,那么效率将会大大提升,因为当我们选取树的重心作为根的时候,去掉这个节点之后剩下的最大联通块的节点数是最小的,所以这样就会让一个树分成更小的部分,从而使分割的次数更少。所以说我们需要选择树的重心作为根来计算,只需要最多递归 \(\log n\) 次。这样算法复杂度就是 \(O(n\log^2n)\) 了。
为了方便,删除节点并不是真正的删除,显然只需要打一个标记即可,对应再代码里面就是 mark 数组。

代码

#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
#define maxn 40039
using namespace std;
//#define debug
typedef int Type;
inline Type read(){
	Type sum=0;
	int flag=0;
	char c=getchar();
	while((c<'0'||c>'9')&&c!='-') c=getchar();
	if(c=='-') c=getchar(),flag=1;
	while('0'<=c&&c<='9'){
		sum=(sum<<1)+(sum<<3)+(c^48);
		c=getchar();
	}
	if(flag) return -sum;
	return sum;
}
int n,m,u,v,w;
int head[maxn],nex[maxn<<1],to[maxn<<1],c[maxn<<1],kkk;
#define add(x,y,z) nex[++kkk]=head[x];\
head[x]=kkk; to[kkk]=y; c[kkk]=z;
int root,num,minx,siz[maxn],dis[maxn],mark[maxn];
void getnum(int x,int pre){
	if(mark[x]) return; num++;
	for(int i=head[x];i;i=nex[i])
	    if(to[i]!=pre) getnum(to[i],x);
	return;
}
void getroot(int x,int pre){//找重心
	if(mark[x]) return;
	int maxx=0; siz[x]=1;
	for(int i=head[x];i;i=nex[i])
	    if(to[i]!=pre){
	    	getroot(to[i],x);
	    	siz[x]+=siz[to[i]];
	    	maxx=max(maxx,siz[to[i]]);
		}
	if(max(maxx,num-siz[x])<minx){
		minx=max(maxx,num-siz[x]);
		root=x;
	}
	return;
}
void getdis(int x,int pre){
	if(mark[x]) return;
	for(int i=head[x];i;i=nex[i])
	    if(to[i]!=pre){
	    	dis[to[i]]=dis[x]+c[i];
	    	getdis(to[i],x);
		}
	return;
}
int cnt,mem[maxn];
void find(int x,int pre){
	if(mark[x]) return;
	mem[++cnt]=dis[x];
	for(int i=head[x];i;i=nex[i])
	    if(to[i]!=pre) find(to[i],x);
}
int calc(int x){//计算
	cnt=0; find(x,-1);
	sort(mem+1,mem+cnt+1);
	int ans=0,l=1,r=cnt;
	while(l<=r)
	    if(mem[l]+mem[r]<=m) ans+=r-l,l++;
	    else r--;
	return ans;
}
int solve(int x){//得到答案
	num=0; getnum(x,0); if(num==1) return 0;
	minx=0x7f7f7f7f; getroot(x,0); dis[root]=0;
	getdis(root,-1); int ans=calc(root); mark[root]=1;
	for(int i=head[root];i;i=nex[i])
	    if(!mark[to[i]]) ans-=calc(to[i]);
	for(int i=head[root];i;i=nex[i])
	    if(!mark[to[i]]) ans+=solve(to[i]);
	return ans;
}
int main(){
	//freopen("1.in","r",stdin);
	//freopen(".out","w",stdout);
	n=read();
	if(n==0&&m==0) return 0;
	for(int i=1;i<n;i++){
		u=read(); v=read(); w=read();
	    add(u,v,w) add(v,u,w)
	}
	m=read();
	printf("%d\n",solve(1));
	return 0;
}
posted @ 2021-08-24 08:51  jiangtaizhe001  阅读(32)  评论(0编辑  收藏  举报