点分治

点分治一共学了三次。第一次是初三暑假,看着作业里的“【模板】点分治”感到望而生畏,稀里糊涂依葫芦画瓢交了一个AC,后面就忘光了。第二次是去年六月份,一知半解地写了一个非主流的写法,交上去自己现在都看不懂。第三次就是现在了,感觉主流做法其实很好理解,不知道为什么以前学得那么头疼。故写个博客加深记忆。

功能

解决许多以树上路径为对象的统计类问题。

流程

主流写法共有四大函数,get_rootget_discalcwork

void get_root(int x,int p,int &root,int all);//初始root=0,mxsiz[0]=INF
void get_dis(int x,int p,int dis,int fr);
void calc(int x);
void work(int x);

点分治的流程是:
每次选定一个分治中心 \(x\),需满足 \(x\) 是重心,然后把 \(x\) 标记为删除,形成若干个连通块,递归处理。
每次选择分治中心 \(x\) 后,统计连通块内的所有路径 \((u,v),u\ne x,v\ne x,b_u\ne b_v\) 的答案(通过各种方式,不需要 \(O(siz^2)\)\(b_u\) 表示删掉 \(x\)\(u\) 属于 \(x\) 的哪个邻居的连通块),以及所有路径 \((x,u),u\ne x\) 的答案。只包含一个点的路径的答案通常可以通过特判的方式解决。
get_root 的作用是找到当前连通块的分治中心 rootall 是当前连通块的大小。
get_dis 的作用是遍历连通块内每个 \(u\ne x\) 并将其存到一个数组 a[1~tot] 中、记录 d[u],b[u] 表示 dist(u,x)\(u\) 属于 \(x\) 的哪个邻居的管辖范围。
calc 最重要了,它通常需要对这 \(siz-1\) 个点在 \(O(siz\text{polylog} n)\) 的复杂度内统计不在 \(x\) 同一邻居管辖范围内的两点的点对的贡献。
work 的作用就是 dfs,每次调用 get_root 找到分裂出的连通块的分治中心,从这个分治中心递归下去 work,每到达一个分治中心就调用一次 calc

例:【JOISC2021】会合2

给出一棵大小为 \(N\) 的树。

对于树上的一个点集 \(S\),定义其权值为满足 \(\sum\limits_{u\in S} dis(u,x)\) 取到最小值的 \(x\) 的数量。

对于每一个 \(1\le i \le N\),求出当点集大小为 \(i\) 时的最大权值。

\(1\le N \le 2\times 10^5\)

看到 \(\sum dis\) 最小,想到带权重心。带权重心构成一条链。更确切地,当 \(|S|\) 为偶数时构成一条链,为奇数时构成一个单点(也就是 \(i\) 为奇数时输出 \(1\) 即可)。这条链只需要满足从链的中间某点把树吊起来后,链首 \(s\) 和链尾 \(t\)\(siz'\)\(\ge k\),就能用 \(dis(u,v)\) 更新 \(ans_{2k,2k-2,...,2}\)

现在我们只需要对树上每条路径 \((s,t)\),用 \(dis(s,t)\) 更新 \(ans[2\min(siz'_s,siz'_t)]\) 了。自然使用点分治解决。

calc 如何实现呢?
我们考虑从 \(a_1\to a_tot\) 枚举 \(u\),每次查询 \(siz'\ge siz'_u\) 的不与 \(u\)\(b\) 的点的 \(dis\) 最大值(用一棵以 \(siz'\) 为下标的 BIT 维护后缀 \(\max\) 维护)。从 \(a_tot\sim a_1\) 再做一遍。

细节:

  1. \(siz'\) 怎么求呢?先把所有点的 \(siz'\) 设为在原树中的 \(siz\),然后暴力跳 \(x\) 的祖先(直到碰到一个已删除的点为止),把经过的点的 \(siz'\) 设为 \(n-siz_{las}\)\(las\) 为上一个跳到的点,初始为 \(x\)
  2. \((x,u),u\ne x\) 类路径怎么统计?单独统计(别跟 \((u,v)\) 类一起统计)。当 \(b_u\ne fa_x\) 时,\(x\)\(siz'\)\(n-siz_{b_u}\),否则取 \(siz_x\)
//start coding at 9:52
#include <bits/stdc++.h>
using namespace std;
inline int read(){
	int x=0;char ch=getchar();
	while(ch<'0'||ch>'9')ch=getchar();
	while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
	return x;
}
const int N=2e5+5;
int n,tot,osiz[N],siz[N],ss[N],d[N],b[N],fa[N],mxsiz[N],a[N],res[N];
int c[N];
bool vis[N];
vector<int>G[N];
void odfs(int x,int p){
	fa[x]=p,osiz[x]=1;
	for(int y:G[x])if(y^p){
		odfs(y,x);
		osiz[x]+=osiz[y];
	}
}
void get_root(int x,int p,int &root,int all){
	siz[x]=1,mxsiz[x]=0;
	for(int y:G[x])if(y!=p&&!vis[y]){
		get_root(y,x,root,all);
		siz[x]+=siz[y];
		mxsiz[x]=max(mxsiz[x],siz[y]);
	}
	mxsiz[x]=max(mxsiz[x],all-siz[x]);
	if(mxsiz[x]<mxsiz[root])root=x;//cerr<<mxsiz[x]<<' '<<mxsiz[root]<<root;
//	cerr<<all<<'.';
}
void get_dis(int x,int p,int dep,int fr){
	d[x]=dep,b[x]=fr,a[++tot]=x;
	for(int y:G[x])if(y!=p&&!vis[y])get_dis(y,x,dep+1,fr);
}
void add(int x,int y){
	x=n-x+1;
	for(;x<=n;x+=x&-x)c[x]=max(c[x],y);
}
int ask(int x){
	x=n-x+1;
	int ret=-1e9;
	for(;x;x-=x&-x)ret=max(ret,c[x]);
	return ret;
}
void chexiao(int x){
	x=n-x+1;
	for(;x<=n;x+=x&-x)c[x]=-1e9;
}
void calc(int x){
	tot=0;
	for(int y:G[x])if(!vis[y])get_dis(y,x,1,y);
	for(int i=1;i<=tot;i++)ss[a[i]]=osiz[a[i]];
	for(int i=fa[x],las=x;i&&!vis[i];i=fa[i])ss[i]=n-osiz[las],las=i;
	int j=0;
	for(int i=1;i<=tot;i++){
		while(j+1<=i&&b[a[j+1]]!=b[a[i]])j++,add(ss[a[j]],d[a[j]]);
		res[ss[a[i]]]=max(res[ss[a[i]]],d[a[i]]+ask(ss[a[i]])+1);
	//	cerr<<a[i]<<'|'<<ss[a[i]]<<'|'<<j<<'.'<<ask(ss[a[i]])+d[a[i]]+1<<'\n';
	}
	for(int i=1;i<=j;i++)chexiao(ss[a[i]]);
	j=tot+1;
	for(int i=tot;i>=1;i--){
		while(j-1>=i&&b[a[j-1]]!=b[a[i]])j--,add(ss[a[j]],d[a[j]]);
		res[ss[a[i]]]=max(res[ss[a[i]]],d[a[i]]+ask(ss[a[i]])+1);
	}
	for(int i=j;i<=tot;i++)chexiao(ss[a[i]]);
	for(int i=1;i<=tot;i++){
		if(b[a[i]]==fa[x])res[min(ss[a[i]],osiz[x])]=max(res[min(ss[a[i]],osiz[x])],d[a[i]]+1);
		else res[min(ss[a[i]],n+1-osiz[b[a[i]]])]=max(res[min(ss[a[i]],n+1-osiz[b[a[i]]])],d[a[i]]+1);
	}
//	cerr<<x<<":\n";
//	for(int i=1;i<=tot;i++)cerr<<a[i]<<' '<<ss[a[i]]<<'\n';cerr<<'\n';
}
void work(int x){
	calc(x);
	vis[x]=1;
	for(int y:G[x])if(!vis[y]){
		int rt=0;
		get_root(y,x,rt,siz[y]);
		work(rt);
	}
}
int main(){
	n=read();
	for(int i=1,u,v;i<n;i++)u=read(),v=read(),G[u].emplace_back(v),G[v].emplace_back(u);
	odfs(1,0);
	mxsiz[0]=1e9;
	int rt=0;get_root(1,0,rt,n);//cerr<<mxsiz[1]<<'.';
	for(int i=1;i<=n;i++)c[i]=-1e9;
	work(rt);
	for(int i=n-1;i;i--)res[i]=max(res[i],res[i+1]);
	for(int i=1;i<=n;i++){
		if(i&1)puts("1");
		else cout<<max(1,res[i/2])<<'\n';
	}
	return 0;
}
posted @ 2023-04-20 13:05  pengyule  阅读(30)  评论(0编辑  收藏  举报