点分治学习笔记

一、概述

前置知识:树的重心。

1. 经典应用 1

假设我们要统计一棵有 \(n\) 个节点的树上所有点对之间距离是 \(k\) 的有多少对。注意树上的边有长度。

\(n,k\le 10^6\)

一个朴素的算法是遍历树上的所有点对,处理出距离(也就是链的长度)。

时间复杂度 \(O(n^2)\)

考虑优化。由于只有一次查询,直接求出所有的点对距离太没有必要了。

所以考虑进行分治。

算法流程如下:

  1. 选择树的重心 \(x\),它将整棵树分成若干部分,计算所有链中两个端点分别在 \(x\) 的两棵不同子树的贡献。

    1. 具体来说,遍历 \(x\) 的所有子树,开 2 个桶 \(a\)\(b\)\(a_i\) 负责装 \(x\) 当前子树里,一个端点为 \(x\) 的长度为 \(i\) 的链有多少条。\(b_i\) 负责装 \(x\) 当前子树之前遍历过的所有子树里,一个端点为 \(x\) 的长度为 \(i\) 的链有多少条。
    2. 初始时 \(a_0\) 为 1,因为可能有一些长度为 \(k\) 的链其中一个端点就是 \(x\)
    3. 处理当前子树时,我们枚举 \(b\) 里的每一个不为 \(0\) 的值,其下标为 \(p\),那么答案要加上 \(b_p\times a_{k-p}\)
    4. 当我们处理完当前子树后,将 \(a\) 桶倒进 \(b\) 桶里,并且将 \(a\) 桶清空。这样可以做到不重不漏。
  2. 删掉点 \(x\),继续递归地考虑 \(x\) 的所有子树。

正确性是因为一个在经过其中一级重心的链,不会在上级重心中被算过,也不会在下级重心再被算到。

让我们来分析一下时间复杂度。

首先,由于重心的性质,每次删除全树的重心,最大的子树大小至多为原来的一半。所以一棵有 \(n\) 个节点的树,将被分治 \(\log n\) 层,每一层所处理的所有子树大小之和小于 \(n\)

所以如果我们可以实现对于每一棵大小为 \(s\) 的子树 \(O(s)\) 处理,我们就可以在总 \(O(n\log n)\) 的时间复杂度内解决这个问题。

处理桶 \(a\) 时,每一次都要用数组记录一个端点为 \(x\),另一个端点为当前子树中的点的 \(s\) 条链的长度(时间复杂度 \(O(s)\)),然后再把这些链的长度在桶里对应的计数器加一。这样,我们计算贡献和清空桶时直接遍历这些数组就知道桶里的哪些计数器不为 \(0\) 了。所以计算贡献和清空桶的复杂度也是 \(O(s)\)

那么我们就得到了一个 \(O(n\log n)\) 的做法。

加强版:如果这一题的 \(k\le 1145141919810\) 呢?可以考虑用 2 个 gp_hash_tablecc_hash_table(pbds)代替数组作为桶 \(a\)\(b\),或者手写哈希。用 set 或者 map 的话会多一个 \(\log\),十分不值得。

2. 经典应用 2

假设我们要统计一棵有 \(n\) 个节点的树上所有点对之间距离不超过 \(k\) 的有多少对。注意树上的边有长度。

\(n\le 2\times 10^5,k\le 1145141919810\)

注意到这一题和上一题类似,我们照样点分治就可以了,但是在计算贡献的时候出了一些问题。

我们是要求长度小于等于 \(k\) 的链,那么在统计贡献的时候就必须计算 \(b_p\times\sum\limits_{q=0}\limits^{k-p}a_q\)。那么显然,用桶的时间复杂度是 \(O(n)\) 的。所以为了复杂度不退化,我们可以考虑搞一下容斥:

  1. 用一个数组记录 \(x\) 的子树内的所有点距离 \(x\) 的距离,然后排个序,用双指针法计算出所有方法数之和。
  2. 以上计算可能会统计一些不可能的贡献,比如链的两个的端点在 \(x\) 的同一子树内,所以我们枚举 \(x\) 的所有子节点 \(y\),再用答案减去关于 \(y\) 的上述结果。

这样不用考虑 \(k\) 的范围了。但是处理每棵子树时要把距离数组排个序,也就是 \(O(s\log s)\),所以总时间复杂度是 \(O(n\log^2 n)\)。当然常数很小,容易知道跑不满。

有一些细节,代码实现详见下文。

二、例题

1. 【模板】点分治1

注意这一题有多次询问,而且每次找到的重心都完全相同,所以可以递归一次,一起处理。

因为这题只考虑可行性,所以用 C++ 20 的 unordered_set 作为桶。代码不到 1.2 kb。

点击查看代码
#include<bits/stdc++.h>
#include<unordered_set>
using namespace std;
typedef long long ll;
const ll o=20010;
unordered_set<ll>C,D;
ll n,m,c=0,r,a[o],q[o],nxt[o],h[o],t[o],v[o],s[o],p[o],d[o];
inline void add(ll x,ll y,ll z){
	nxt[++c]=h[x];h[x]=c;t[c]=y;v[c]=z;
}
void R(ll x,ll f,ll T){
	s[x]=1;p[x]=0;
	for(ll i=h[x],y;i;i=nxt[i])
		if(!d[y=t[i]]&&y!=f)R(y,x,T),s[x]+=s[y],p[x]=max(p[x],s[y]);
	if((p[x]=max(p[x],T-s[x]))<p[r])r=x;
}
void W(ll x,ll f,ll ds){
	D.insert(ds);
	for(ll i=h[x],y;i;i=nxt[i])
		if(!d[y=t[i]]&&y!=f)W(y,x,ds+v[i]);
}
void Q(ll x,ll f){
	C.insert(0);d[x]=1;
	for(ll i=h[x],y;i;i=nxt[i])
		if(!d[y=t[i]]&&y!=f){
			W(y,x,v[i]);
			for(auto j:D)
				for(ll k=1;k<=m;k++)
					if(q[k]>=j)a[k]|=C.contains(q[k]-j);
			for(auto j:D)C.insert(j);
			D.clear();
		}
	C.clear();
	for(ll i=h[x],y;i;i=nxt[i])
		if(!d[y=t[i]]&&y!=f)p[r=0]=1e9,R(y,x,s[y]),R(r,0,s[y]),Q(r,x);
}
int main(){
	scanf("%lld%lld",&n,&m);
	for(ll i=1,x,y,z;i<n;i++){
		scanf("%lld%lld%lld",&x,&y,&z);
		add(x,y,z);add(y,x,z);
	}
	for(ll i=1;i<=m;i++)scanf("%lld",&q[i]);
	p[r=0]=1e9;R(1,0,n);R(r,0,n);Q(r,0);
	for(ll i=1;i<=m;i++)puts(a[i]?"AYE":"NAY");
	return 0;
}

2. Tree

这就是经典应用 2。所以可以用容斥。

但是我做这道题的时候 Too young too simple,sometimes naive,所以手写了一个平衡树。时间复杂度也是 \(O(n\log^2 n)\)

点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll o=80010;
inline ll read(){
	ll x=0,f=1;char ch=getchar();
	while(ch>'9'||ch<'0'){if(ch=='-')f=0;ch=getchar();}
	while('0'<=ch&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
	return f?x:-x;
}
struct FHQTreap{
	struct bst{ll l,r,s,v,p;}t[o];
	ll rt=0,cnt=0;
	inline void C(){for(ll i=1;i<=cnt;i++)t[i]=bst{0,0,0,0,0};rt=cnt=0;}
	inline void N(ll x){t[++cnt]=bst{0,0,1,x,rand()};}
	inline void U(ll x){t[x].s=t[t[x].l].s+t[t[x].r].s+1;}
	inline void S(ll x,ll k,ll &l,ll &r){
		if(!x){l=r=0;return;}
		if(t[x].v<=k){l=x;S(t[x].r,k,t[x].r,r);}
		else{r=x;S(t[x].l,k,l,t[x].l);}
		U(x);
	}
	inline ll M(ll l,ll r){
		if(!l||!r)return l+r;
		if(t[l].p<=t[r].p){t[l].r=M(t[l].r,r);U(l);return l;}
		else{t[r].l=M(l,t[r].l);U(r);return r;}
	}
	inline void I(ll x){ll l,r;S(rt,x,l,r);N(x);rt=M(M(l,cnt),r);}
	inline ll G(ll x){ll l,r,p=0;S(rt,x,l,r);p=t[l].s;rt=M(l,r);return p;}
}C,D; 
ll n,m,c=0,r,nxt[o],h[o],t[o],v[o],s[o],p[o],d[o],q,ans=0;
inline void add(ll x,ll y,ll z){nxt[++c]=h[x];h[x]=c;t[c]=y;v[c]=z;}
inline void R(ll x,ll f,ll T){
	s[x]=1;p[x]=0;
	for(ll i=h[x],y;i;i=nxt[i])
		if(!d[y=t[i]]&&y!=f)R(y,x,T),s[x]+=s[y],p[x]=max(p[x],s[y]);
	if(p[r]>(p[x]=max(p[x],T-s[x])))r=x;
}
inline void W(ll x,ll f,ll Z){
	D.I(Z);
	for(ll i=h[x],y;i;i=nxt[i])
		if(!d[y=t[i]]&&y!=f)W(y,x,Z+v[i]);
}
inline void Q(ll x,ll f){
	C.I(0);d[x]=1;
	for(ll i=h[x],y;i;i=nxt[i])
		if(!d[y=t[i]]&&y!=f){
			W(y,x,v[i]);
			for(ll j=1;j<=D.cnt;j++)
				if(q>=D.t[j].v)ans+=C.G(q-D.t[j].v);
			for(ll j=1;j<=D.cnt;j++)C.I(D.t[j].v);
			D.C();
		}
	C.C();
	for(ll i=h[x],y;i;i=nxt[i])
		if(!d[y=t[i]]&&y!=f)p[r=0]=1e9,R(y,x,s[y]),Q(r,x);
}
int main(){
	srand(time(0));n=read();
	for(ll i=1,x,y,z;i<n;i++){
		x=read();y=read();z=read();
		add(x,y,z);add(y,x,z);
	}
	q=read();p[r=0]=1e9;R(1,0,n);Q(r,0);cout<<ans<<'\n';
	return 0;
}

容斥写法

点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int _=80010;
int cnt=0,nxt[_],to[_],v[_],h[_],ans=0,n,k,del[_],sz[_],p[_],dis[_],rt;
inline void add(int x,int y,int z){
	nxt[++cnt]=h[x];to[cnt]=y;v[cnt]=z;h[x]=cnt;
}
inline void getroot(int x,int fa,int tot){
	sz[x]=1;p[x]=0;
	for(int i=h[x],y;i;i=nxt[i])
		if(!del[y=to[i]]&&y!=fa){
			getroot(y,x,tot);
			sz[x]+=sz[y];
			p[x]=max(p[x],sz[y]);
		}
	p[x]=max(p[x],tot-sz[x]);
	if(p[x]<p[rt])rt=x;
}
inline void getdis(int x,int fa,int dist){
	dis[++dis[0]]=dist;
	for(int i=h[x],y;i;i=nxt[i])
		if(!del[y=to[i]]&&y!=fa)getdis(y,x,dist+v[i]);
}
inline int cal(int x,int dist){
	dis[0]=0;
	getdis(x,0,dist);
	sort(dis+1,dis+dis[0]+1);
	int res=0,l=1,r=dis[0];
	while(l<r)
		if(dis[l]+dis[r]<=k)res+=r-l,l++;
		else r--;
	return res;
}
inline void solve(int x){
	del[x]=1;
	ans+=cal(x,0);
	for(int i=h[x],y;i;i=nxt[i])
		if(!del[y=to[i]]){
			ans-=cal(y,v[i]);
			rt=0;
			getroot(y,x,sz[y]);
			solve(rt);
		}
} 
int main(){
	scanf("%d",&n);
	for(int i=1,x,y,z;i<n;i++){
		scanf("%d%d%d",&x,&y,&z);
		add(x,y,z);add(y,x,z);
	}
	scanf("%d",&k);
	p[rt=0]=1e9;
	getroot(1,0,n);
	solve(rt);
	printf("%d\n",ans);
	return 0;
}

3. [国家集训队]聪聪可可

我们维护两个大小为 3 的桶,\(C\)\(D\),代表子树内长度 \(\mod3=i\) 的链有多少条,然后统计的时候就是 \(D_2\times C_1+D_1\times C_2+D_0\times C_0+D_0\)(可能链有一端就是 \(x\)),然后把 \(D\) 倒进 \(C\) 即可。

点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll o=80010;
inline ll read(){
	ll x=0,f=1;char ch=getchar();
	while(ch>'9'||ch<'0'){if(ch=='-')f=0;ch=getchar();}
	while('0'<=ch&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
	return f?x:-x;
}
ll nxt[o],h[o],t[o],v[o],d[o],s[o],p[o],C[3],D[3],n,c=0,r,ans=0;
inline void out(ll x){
	ll X=x*2+n,Y=n*n,G=__gcd(X,Y);
	cout<<X/G<<'/'<<Y/G<<'\n';
}
inline void add(ll x,ll y,ll z){nxt[++c]=h[x];h[x]=c;t[c]=y;v[c]=z;}
inline void R(ll x,ll f,ll T){
	s[x]=1;p[x]=0;
	for(ll i=h[x],y;i;i=nxt[i])
		if(!d[y=t[i]]&&y!=f)R(y,x,T),s[x]+=s[y],p[x]=max(p[x],s[y]);
	if(p[r]>(p[x]=max(p[x],T-s[x])))r=x;
}
inline void W(ll x,ll f,ll T){
	D[T%3]++;
	for(ll i=h[x],y;i;i=nxt[i])
		if(!d[y=t[i]]&&y!=f)W(y,x,T+v[i]);
}
inline void Q(ll x,ll f){
	d[x]=1;
	for(ll i=h[x],y;i;i=nxt[i])
		if(!d[y=t[i]]&&y!=f){
			W(y,x,v[i]);
			ans+=D[0]+D[0]*C[0]+D[1]*C[2]+D[2]*C[1];
			for(ll j=0;j<3;j++)C[j]+=D[j],D[j]=0;
		}
	C[0]=C[1]=C[2]=0;
	for(ll i=h[x],y;i;i=nxt[i])
		if(!d[y=t[i]]&&y!=f)p[r=0]=1e9,R(y,x,s[y]),Q(r,x);
}
int main(){
	n=read();
	for(ll i=1,x,y,z;i<n;i++){
		x=read();y=read();z=read();
		add(x,y,z);add(y,x,z);
	}
	p[r=0]=1e9;R(1,0,n);Q(r,0);
	out(ans);
	return 0;
}

4. [IOI2011]Race

这一题一看就是点分治,但是怎么维护呢?首先 \(k\le 10^6\),我们就不用写哈希了,直接用数组维护即可。然后就是要求最小边数量。于是我弄了两个数组 \(C\)\(D\) 代表子树中长度为 \(i\) 的链最少多少条边。然后用数组 \(E\) 记录桶里哪些长度不是初始值。由于要求链上最少有多少条边,所以一开始 \(C\)\(D\) 都是无穷大。计算链长度的时候,另外记录边的条数,将长度扔进 \(E\) 里,再扔进 \(D\) 里(即更新同样长度的边数的最小值)。然后清空就是遍历 \(E\) 数组,将 \(C\)\(D\) 桶的对应位置赋值成无穷大。

点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll o=400010,z=1000010,I=1000000000;
inline ll read(){
	ll x=0,f=1;char ch=getchar();
	while(ch>'9'||ch<'0'){if(ch=='-')f=0;ch=getchar();}
	while('0'<=ch&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
	return f?x:-x;
}
ll n,k,S,K,A=I,c=0,r,p[o],nxt[o],h[o],t[o],v[o],s[o],d[o],C[z],D[z],E[o];
inline void add(ll x,ll y,ll z){nxt[++c]=h[x];h[x]=c;t[c]=y;v[c]=z;}
inline void R(ll x,ll f,ll T){
	s[x]=1;p[x]=0;
	for(ll i=h[x],y;i;i=nxt[i])
		if(!d[y=t[i]]&&y!=f)R(y,x,T),s[x]+=s[y],p[x]=max(p[x],s[y]);
	if(p[r]>(p[x]=max(p[x],T-s[x])))r=x;
}
inline void W(ll x,ll f,ll T,ll N){
    if(T>k)return;
	E[++K]=T;D[T]=min(D[T],N);
	for(ll i=h[x],y;i;i=nxt[i])
		if(!d[y=t[i]]&&y!=f&&T+v[i]<=k)W(y,x,T+v[i],N+1);
}
inline void Q(ll x,ll f){
	d[x]=1;S=K=0;
	for(ll i=h[x],y;i;i=nxt[i])
		if(!d[y=t[i]]&&y!=f&&v[i]<=k){
			W(y,x,v[i],1);
			for(ll j=S+1;j<=K;j++)
				if(E[j]<=k)A=min(A,D[E[j]]+C[k-E[j]]);
			for(ll j=S+1;j<=K;j++)C[E[j]]=min(C[E[j]],D[E[j]]),D[E[j]]=I; 
			S=K;
		}
	for(ll i=1;i<=K;i++)C[E[i]]=I;
	for(ll i=h[x],y;i;i=nxt[i])
		if(!d[y=t[i]]&&y!=f)p[r=0]=I,R(y,x,s[y]),Q(r,x);
}
int main(){
	n=read();k=read();
	for(ll i=1,x,y,z;i<n;i++){
		x=read()+1;y=read()+1;z=read();
		add(x,y,z);add(y,x,z);
	}
	for(ll i=1;i<=k;i++)C[i]=D[i]=I;
	p[r=0]=I;R(1,0,n);Q(r,0);
	cout<<(A>=I?-1:A)<<'\n';
	return 0;
}

5. 2013ACM/ICPC亚洲区南京站现场赛 D Tree

这一题就是要求树上乘积为 \(k\) 的链里字典序最小的那一个。细节很多。而且还没有题解对拍,你甚至不知道自己错在哪里。

首先注意到模数 \(10^6+3\) 是一个质数。容易知道 \((10^6+3)^2\) 不在 int 范围内,要开 long long

然后所有点权都小于 \(10^6+3\) 并且不为 \(0\),而且输出要求 \(a<b\),所以链长度不能为零,所以得出一个结论:\(k=0\) 时无解(必须判断,因为输入的时候 \(k\) 可能等于 0)。

所以我们就可以求出 \(1\)\(10^6+2\) 的逆元,然后开始点分治了。

由于这道题是点有权值,所以在计算一个点 \(x\) 的贡献时,统计子树成绩时,不能包含点 \(x\),否则我们计算经过 \(x\) 的链时,\(x\) 的权值会算两次。

要求字典序最小,那么在乘积的同等条件下,两个端点都要尽量小。那么桶一开始要赋值无穷大。

这里用 \(d_i\) 记录了当前节点 \(x\) 当前子树内的一个端点是 \(x\) 的所有经过节点的权值(不包含 \(x\) 的)乘积为 \(i\) 的链的另外一个端点编号最小是多少。\(c_i\) 则是记录当前节点 \(x\) 的之前所有子树的上述值。\(C\) 记录了桶 \(c\) 里不为无穷大的位置。\(D\) 记录了桶 \(d\) 里不为无穷大的位置。

最后记得统计的时候要加上 \(a_x\)

点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll _=100010,mod=1000003,I=1000000000;
inline ll read(){
    ll x=0;char ch=getchar();
    while(ch>'9'||ch<'0'){if(ch==EOF)exit(0);ch=getchar();}
    while('0'<=ch&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
    return x;
}
inline void print(ll x){
    if(x>=10)print(x/10);
    putchar(x%10+48);
}
inline void write(ll x){
    if(x<0)putchar('-'),x=-x;
    print(x);
    putchar(10);
}
inline ll ksm(ll a,ll b){
    ll r=1;
    while(b){if(b&1)r=r*a%mod;a=a*a%mod;b>>=1;}
    return r;
}
ll s[_],p[_],c[mod],d[mod],n,k,r,q[mod],z[_],a[_],C[_],D[_];
pair<ll,ll>ans;
vector<ll>v[_];
inline pair<ll,ll>MIN(pair<ll,ll>x,pair<ll,ll>y){
    if(x.first>x.second)swap(x.first,x.second);
    if(y.first>y.second)swap(y.first,y.second);
    return x<y?x:y;
}
inline void R(ll x,ll f,ll T){
    s[x]=1;p[x]=0;
    for(ll i=0,y;i<v[x].size();i++)
        if(!z[y=v[x][i]]&&y!=f)R(y,x,T),s[x]+=s[y],p[x]=max(p[x],s[y]);
    if((p[x]=max(p[x],T-s[x]))<p[r])r=x;
}
inline void W(ll x,ll f,ll T){
    D[++D[0]]=T;d[T]=min(d[T],x);
    for(ll i=0,y;i<v[x].size();i++)
        if(!z[y=v[x][i]]&&y!=f)W(y,x,T*a[y]%mod);
}
inline void S(ll x,ll f){
    z[x]=1;C[++C[0]]=1;c[1]=x;
    for(ll i=0,y;i<v[x].size();i++)
        if(!z[y=v[x][i]]&&y!=f){
            W(y,x,a[y]);
            for(ll j=1;j<=D[0];j++)
                if(c[k*q[D[j]*a[x]%mod]%mod]<I)ans=MIN(ans,make_pair(d[D[j]],c[k*q[D[j]*a[x]%mod]%mod]));
            for(ll j=1;j<=D[0];j++)C[++C[0]]=D[j],c[D[j]]=min(c[D[j]],d[D[j]]),d[D[j]]=I;
            D[0]=0;
        }
    for(ll i=1;i<=C[0];i++)c[C[i]]=I;
    C[0]=0;
    for(ll i=0,y;i<v[x].size();i++)
        if(!z[y=v[x][i]]&&y!=f)r=0,R(y,x,s[y]),R(r,0,s[y]),S(r,x);
}
int main(){
    for(ll i=1;i<mod;i++)q[i]=ksm(i,mod-2),c[i]=d[i]=I;
    while(1){
        n=read();k=read();
        for(ll i=1;i<=n;i++)a[i]=read(),v[i].clear(),z[i]=0;
        for(ll i=1,x,y;i<n;i++){
            x=read();y=read();
            v[x].push_back(y);
            v[y].push_back(x);
        }
        if(!k){puts("No solution");continue;}
        ans=make_pair(I,I);p[r=0]=I;R(1,0,n);R(r,0,n);S(r,0);
        if(ans==make_pair(I,I))puts("No solution");
        else print(ans.first),putchar(32),print(ans.second),putchar(10);
    }
    return 0;
}
posted @ 2023-04-30 15:39  lrxQwQ  阅读(17)  评论(0编辑  收藏  举报