点分治入门

点分治,顾名思义,对树上的节点进行分治。其实就是把一棵树拆成几棵分别处理。

要把一棵树分成几棵怎么做?显然选一个节点做根节点,分出它的子树就行。不难发现,这个点的选取对时间复杂度影响很大,去一个极端例子:一条链
image

如果选取1,5号节点,时间复杂度是O(n),而如果选取3号节点,时间复杂度是O(log n)影响整体时间复杂度的是分出的最大的子树,所以目标是使最大子树最小。

最大子树最小,不就是找重心吗?树的重心的定义就是其所有的子树中最大的子树节点数最少。怎么求?直接dfs,下面是代码:

void get_rt(int now,int fa){//now为当前根节点
	siz[now]=1;//子树大小
	son[now]=0;//最大子树的大小
	for(int i=head[now];i;i=nxt[i]){
		int to=nbr[i];
		if(to==fa||vis[to]) continue;//vis下面会用到,访问过了就不用再访问
		get_rt(to,now);
		siz[now]+=siz[to];
		son[now]=max(son[now],siz[to]);
	}
	son[now]=max(son[now],size-siz[now]);//当前节点若作为根节点,那么其父节点也是它的儿子
	if(son[now]<maxn){
		maxn=son[now];
		root=now;//更改重心
	}
}

下面是点分治的正式部分,先放代码:

void divide(int now){
	vis[now]=true;
	solve(now,1,0);//当前节点的答案 
	for(int i=head[now];i;i=nxt[i]){
		int to=nbr[i];
		if(vis[to]) continue;//访问过 
		solve(to,-1,edge[i]);//去重,既经过了now也经过v[now][i].to 
		maxn=0x3f3f3f3f,root=0,size=siz[to];//更新
		get_rt(to,0);//重新找重心 
		divide(root);
	}
}

先解释一下solve里的1,-1是为了方便处理第二个solve里的去重,计算答案的时候加-1。具体的solve因题而异。

计算now的贡献不是只要处理now就可以了,给一棵树你就明白为什么要去重了:
image
这个图里计算1的贡献时肯定会有一条1 2 5 6的路径,也有一条1 2 5的路径。最后算答案的时候肯定要拿两条链合并的,不然像2 1 3这种贡献就没法计算。那么1 2 5 6可以和1 2 5合并吗?显然不能,毕竟都在根节点的一棵子树里。至于solve的第三个参数。你考虑现在要在1的答案里减去子树2的答案。其实从12已经有了一个边权,长度从1->2的长度开始计算,也就是代码里的edge[i]

现在你已经会了点分治的基本操作,来看一道板题:

eg1:

P3806 【模板】点分治1
题意:给定一棵有n个点的树,询问树上距离为k的点对是否存在。

先看数据加强前的代码,主要就是solve函数不同,我们找一个桶存当前点所有的链的长度,双层循环匹配任意两条链。

void query(int now,int fa,int use){//dfs找出所有链的长度
	 stk[++top]=use;
	for(int i=head[now];i;i=nxt[i]){
		int to=nbr[i];
		if(vis[to] || to==fa) continue;
		query(to,now,use+edge[i]);
	}
}

void solve(int now,int f,int use){
	top=0;
	query(now,0,use);
	for(int i=1;i<top;i++){
		for(int j=i+1;j<=top;j++){
			ans[stk[i]+stk[j]]+=f;//ans数组不能用bool,还有删除
		}
	}
}

下面是完整的代码:

Talk is cheap, show me your code.(点我看代码)
#include<bits/stdc++.h>
using namespace std;

const int N=10005;
int ans[10000005];
int nbr[N<<1],head[N],nxt[N<<1],edge[N<<1];
int siz[N],son[N],stk[N];bool vis[N];
int n,m,size,maxn,root,tot,top;

inline int read()
{
char ch=getchar();bool f=0;int x=0;
for(;!isdigit(ch);ch=getchar())if(ch=='-')f=1;
for(;isdigit(ch);ch=getchar())x=(x<<1)+(x<<3)+(ch^48);
if(f==1)x=-x;return x;
}

inline void add(int from,int to,int val){
	nbr[++tot]=to,nxt[tot]=head[from],head[from]=tot,edge[tot]=val;
	nbr[++tot]=from,nxt[tot]=head[to],head[to]=tot,edge[tot]=val;
}

void get_rt(int now,int fa){
	//cout<<now<<" "<<fa<<endl;
	siz[now]=1;//子树大小
	son[now]=0;//最大子树的大小
	for(int i=head[now];i;i=nxt[i]){
		int to=nbr[i];
		if(to==fa||vis[to]) continue;
		get_rt(to,now);
		siz[now]+=siz[to];
		son[now]=max(son[now],siz[to]);
	}
	son[now]=max(son[now],size-siz[now]);
	if(son[now]<maxn){
		maxn=son[now];
		root=now;//更改重心
	}
}

void query(int now,int fa,int use){
	 stk[++top]=use;
	for(int i=head[now];i;i=nxt[i]){
		int to=nbr[i];
		if(vis[to] || to==fa) continue;
		query(to,now,use+edge[i]);
	}
}

void solve(int now,int f,int use){
	top=0;
	query(now,0,use);
	for(int i=1;i<top;i++){
		for(int j=i+1;j<=top;j++){
			ans[stk[i]+stk[j]]+=f;
		}
	}
}

void divide(int now){
	vis[now]=true;
	solve(now,1,0);//当前节点的答案 
	for(int i=head[now];i;i=nxt[i]){
		int to=nbr[i];
		if(vis[to]) continue;//访问过 
		solve(to,-1,edge[i]);//去重,既经过了now也经过v[now][i].to 
		maxn=0x3f3f3f3f,root=0,size=siz[to];//更新
		get_rt(to,0);//重新找重心 
		divide(root);
	}
}

int main(){
	freopen("test.in","r",stdin);
	//freopen("test.out","w",stdout);
	n=read();
	m=read();
	for(int i=1;i<n;i++){
		int u=read(),v=read(),w=read();
		add(u,v,w);
	}
	root=0;
	maxn=0x3f3f3f3f;
	size=n;
	get_rt(1,0);
	divide(root);
	//cout<<"UES"<<endl;
	while(m--){
		int k;
		k=read();
		puts(ans[k]?"AYE":"NAY");
	}
	return 0;
}

有一个细节就是树的所有边权权值总和达到了1e8,所以ans数组要开到1e8

然后你惊讶的发现:

image

这是必然的,本来solve就是n2级别的,还带一个log的常数,肯定跑不了1e4

你考虑这样一件事,就是说:时间复杂度到底为什么会多,不就是两条链长度之和大量的重复,这是完全不必要的,因为题目只在乎有没有长度为k的点对,只要解决去重的问题就好办了。

考虑如果没有去重,可以根据到当前节点的距离排序,可以类似双指针解决。当前两条链的长度之和短了左端点右移,长了右端点左移,把询问离线下来复杂度一次log n总复杂度O(nlogn)

只剩最后一个问题,怎么去重?只要多维护一个数组维护每一个在当前节点子树内的节点在当前节点的哪一个子树内就可以了。

点击查看代码
#include<bits/stdc++.h>
using namespace std;

const int N=10005;
const int M=105;
int ans[10000005];
int nbr[N<<1],head[N],nxt[N<<1],edge[N<<1];
int siz[N],son[N],stk[N];bool vis[N];
int n,m,size,root,tot,top;
int a[N],d[N],b[N];
bool ok[N];
int que[M];

inline int read()
{
char ch=getchar();bool f=0;int x=0;
for(;!isdigit(ch);ch=getchar())if(ch=='-')f=1;
for(;isdigit(ch);ch=getchar())x=(x<<1)+(x<<3)+(ch^48);
if(f==1)x=-x;return x;
}

inline void add(int from,int to,int val){
	nbr[++tot]=to,nxt[tot]=head[from],head[from]=tot,edge[tot]=val;
	nbr[++tot]=from,nxt[tot]=head[to],head[to]=tot,edge[tot]=val;
}

bool cmp(int x,int y){
	return d[x]<d[y];
}

void get_rt(int now,int fa){
	//cout<<now<<" "<<fa<<endl;
	siz[now]=1;//子树大小
	son[now]=0;//最大子树的大小
	for(int i=head[now];i;i=nxt[i]){
		int to=nbr[i];
		if(to==fa||vis[to]) continue;
		get_rt(to,now);
		siz[now]+=siz[to];
		son[now]=max(son[now],siz[to]);
	}
	son[now]=max(son[now],size-siz[now]);
	if(!root || son[now]<son[root]){
		root=now;//更改重心
	}
}

void query(int now,int fa,int use,int from){
	a[++top]=now;//在当前节点子树内的节点 
	d[now]=use;//距离 
	b[now]=from;//所在子树 
	for(int i=head[now];i;i=nxt[i]){
		int to=nbr[i];
		if(vis[to] || to==fa) continue;
		query(to,now,use+edge[i],from);
	}
}

void solve(int now){
	top=0;
	a[++top]=now;//别忘了加自己 
	d[now]=0;
	b[now]=now;
	for(int i=head[now];i;i=nxt[i]){
		int to=nbr[i];
		if(vis[to]) continue;
		query(to,now,edge[i],to);
	}
	sort(a+1,a+1+top,cmp);//排序 
	for(int i=1;i<=m;i++){
		int l=1,r=top;
		if(ok[i]) continue;
		while(l<r){
			if(d[a[l]]+d[a[r]]>que[i]) r--;//右端点左移 
			else if(d[a[l]]+d[a[r]]<que[i]) l++;//左端点右移 
			else if(b[a[l]]==b[a[r]]){//在同一棵子树里 
				if(d[a[r]]==d[a[r-1]]) r--;
				else l++;
			}
			else{
				ok[i]=true;//第i个询问结果为真 
				break;
			}
		}
	}
}

void divide(int now){
	vis[now]=true;
	solve(now);//当前节点的答案 
	for(int i=head[now];i;i=nxt[i]){
		int to=nbr[i];
		if(vis[to]) continue;//访问过 
		root=0,size=siz[to];//更新
		get_rt(to,0);//重新找重心 
		divide(root);
	}
}

int main(){
	//freopen("test.in","r",stdin);
	//freopen("test.out","w",stdout);
	n=read();
	m=read();
	for(int i=1;i<n;i++){
		int u=read(),v=read(),w=read();
		add(u,v,w);
	}
	root=0;
	size=n;
	//cout<<"UES"<<endl;
	for(int i=1;i<=m;i++){
		que[i]=read();
		if(!que[i]) ok[i]=true;
	}
	get_rt(1,0);
	divide(root);
	for(int i=1;i<=m;i++){
		if(ok[i]) cout<<"AYE"<<endl;
		else cout<<"NAY"<<endl;
	}
	return 0;
}

eg2:

P4178 Tree
题意:给定一棵n个节点的树,每条边有边权,求出树上两点距离小于等于k的点对数量。

跟板题很像,只不过把恰好等于k变成了小于等于k。能不能从板题的正解改过来呢?作者没有想法,但是仔细想一想,板题的60分做法好像很好改。首先这题也适用减去不合法的路径的方法。其次查询答案变得很好操作了。因为单调,直接双指针做完了。

点击查看代码
#include<bits/stdc++.h>
#define int long long
using namespace std;

int ans;
const int N=40005;
int nbr[N<<1],head[N],nxt[N<<1],edge[N<<1];
int siz[N],son[N],stk[N];bool vis[N];
int n,m,size,root,tot,maxn,top;
bool ok[N];

inline int read()
{
char ch=getchar();bool f=0;int x=0;
for(;!isdigit(ch);ch=getchar())if(ch=='-')f=1;
for(;isdigit(ch);ch=getchar())x=(x<<1)+(x<<3)+(ch^48);
if(f==1)x=-x;return x;
}

inline void add(int from,int to,int val){
	nbr[++tot]=to,nxt[tot]=head[from],head[from]=tot,edge[tot]=val;
	nbr[++tot]=from,nxt[tot]=head[to],head[to]=tot,edge[tot]=val;
}

void get_rt(int now,int fa){
	//cout<<now<<" "<<fa<<endl;
	siz[now]=1;//子树大小
	son[now]=0;//最大子树的大小
	for(int i=head[now];i;i=nxt[i]){
		int to=nbr[i];
		if(to==fa||vis[to]) continue;
		get_rt(to,now);
		siz[now]+=siz[to];
		son[now]=max(son[now],siz[to]);
	}
	son[now]=max(son[now],size-siz[now]);
	if(son[now]<maxn){
		maxn=son[now];
		root=now;//更改重心
	}
}

void query(int now,int fa,int use){
	stk[++top]=use;
	for(int i=head[now];i;i=nxt[i]){
		int to=nbr[i];
		if(vis[to] || to==fa) continue;
		query(to,now,use+edge[i]);
	}
}

void solve(int now,int f,int use){
	top=0;
	query(now,0,use);
	sort(stk+1,stk+1+top);//排序 
	int l=1,r=top;
	while(l<r){
		if(stk[l]+stk[r]<=m){
			ans+=(f*(r-l));
			l++;
		}
		else r--;
	}
}

void divide(int now){
	vis[now]=true;
	solve(now,1,0);//当前节点的答案 
	for(int i=head[now];i;i=nxt[i]){
		int to=nbr[i];
		if(vis[to]) continue;//访问过 
		solve(to,-1,edge[i]);
		maxn=0x3f3f3f3f,root=0,size=siz[to];//更新
		get_rt(to,0);//重新找重心 
		divide(root);
	}
}

signed main(){
	//freopen("test.in","r",stdin);
	//freopen("test.out","w",stdout);
	n=read();
	for(int i=1;i<n;i++){
		int u=read(),v=read(),w=read();
		add(u,v,w);
	}
	root=0;
	size=n;
	maxn=0x3f3f3f3f;
	m=read();
	get_rt(1,0);
	divide(root);
	cout<<ans;
	return 0;
}
posted @   星河倒注  阅读(61)  评论(2编辑  收藏  举报
相关博文:
阅读排行:
· winform 绘制太阳,地球,月球 运作规律
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· AI与.NET技术实操系列(五):向量存储与相似性搜索在 .NET 中的实现
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
点击右上角即可分享
微信分享提示