分治专题
1. 根号分治
1.1. 算法简介
根号分治,就是在预处理与询问的复杂度之间寻找平衡的一个算法。通常以根号作为问题规模的分界线,规模小于根号的询问可以 \(n\sqrt n\) 预处理求出,而回答一次规模为 \(B\geq n\) 的询问的时间只需要 \(\dfrac n B\leq \sqrt n\),那么整个题目就可以做到 \(n\sqrt n\)。
1.2. 例题
I. CF797E Array Queries
题意简述:给出 \(\{a_i\}\),多次询问给出 \(p,k\),求至少执行多少次 \(p\gets p+a_p+k\) 才能使 \(p>n\)。
注意到如果 \(k>\sqrt n\) 那么答案必定不大于 \(\sqrt n\),那么对于所有位置预处理出所有 \(k\leq \sqrt n\) 的答案,若 \(k>\sqrt n\) 直接暴力查询即可。时间复杂度 \(\mathcal{O}(n\sqrt n)\)。
/*
Powered by C++11.
Author : Alex_Wei.
*/
#include <bits/stdc++.h>
using namespace std;
const int N=1e5+5;
const int B=333;
int n,m,b,a[N],s[N][B];
int main(){
cin>>n,b=sqrt(n);
for(int i=1;i<=n;i++)cin>>a[i];
for(int i=1;i<=b;i++)for(int j=n;j;j--)
s[j][i]=j+a[j]+i>n?1:(s[j+a[j]+i][i]+1);
cin>>m;
for(int i=1;i<=m;i++){
int p,k; cin>>p>>k;
if(k<=b)cout<<s[p][k]<<endl;
else{
int ans=0;
while(p<=n)ans++,p+=a[p]+k;
cout<<ans<<endl;
}
}
return 0;
}
II. *CF1039D You Are Given a Tree
题意简述:给出一棵树,对每个 \(k\in[1,n]\),求出最多能找出多少条没有公共点的至少经过 \(k\) 个点的链。
注意到若 \(k>\sqrt n\) 则答案一定不大于 \(\sqrt n\)(怎么和上一题那么像,笑)。那么对于 \(1\leq k\leq \sqrt n\),直接暴力树形 DP。然后再枚举 \(1\leq ans\leq \sqrt n\),不过这次枚举的是链的条数,即答案。显然答案单调不升,于是二分出答案为 \(ans\) 的 \(k\) 的区间即可(实际上不需要右端点,只需要左端点)。
树形 DP 求链上经过的点 \(k\) 的答案:该部分比较类似 赛道修建,不过也有一些区别:因为一个点只能被一条链经过(而不是赛道修建中的一条边),于是分两种情况讨论:记 \(mx_1,mx_2\) 为 \(i\) 的儿子所传入的最长的两条链(所经过点的个数),若 \(mx_1+mx_2+1\geq k\),那么显然是将 \(i\) 与它的两个儿子配成一条链,答案加 \(1\);否则将 \(mx+1\) 传上去到 \(fa_i\) 即可。这样一次 DP 就是 \(\mathcal{O}(n)\) 的。
因此,总时间复杂度为 \(\mathcal{O}(n\sqrt n\log n)\)。
卡常技巧:将每个节点的父亲预处理出来,然后按照 dfs 序排序,可以直接循环树形 DP,不需要 dfs。
/*
Powered by C++11.
Author : Alex_Wei.
*/
//#pragma GCC optimize(3)
#include <bits/stdc++.h>
using namespace std;
#define mem(x,v) memset(x,v,sizeof(x))
namespace IO{
char buf[1<<21],*p1=buf,*p2=buf,obuf[1<<23],*O=obuf;
#ifdef __WIN32
#define gc getchar()
#else
#define gc (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<20,stdin),p1==p2)?EOF:*p1++)
#endif
#define pc(x) (*O++=x)
#define flush() fwrite(obuf,O-obuf,1,stdout)
inline int read(){
int x=0,sign=0; char s=gc;
while(!isdigit(s))sign|=s=='-',s=gc;
while(isdigit(s))x=(x<<1)+(x<<3)+(s-'0'),s=gc;
return sign?-x:x;
}
void print(ll x) {if(x>9)print(x/10); pc(x%10+'0');}
}
using namespace IO;
const int N=1e5+5;
int ed,ed2,hd[N],nxt[N<<1],to[N<<1];
pii nw[N];
void add(int u,int v){
nxt[++ed]=hd[u],hd[u]=ed,to[ed]=v;
}
int n,p,cnt,ans[N];
void dfs0(int id,int f){
for(int i=hd[id];i;i=nxt[i]){
if(to[i]==f)continue;
nw[++ed2]={id,to[i]},dfs0(to[i],id);
}
}
int dfs(int id){
int mx=0,mx2=0;
for(int i=hd[id];i;i=nxt[i]){
int v=dfs(to[i]);
if(v>=p){cnt++; return 0;}
if(v>=mx)mx2=mx,mx=v;
else if(v>=mx2)mx2=v;
} if(mx+mx2+1>=p){cnt++; return 0;}
return mx+1;
} int run(int x){
cnt=0,p=x,dfs(1);
return cnt;
}
int main(){
cin>>n;
for(int i=1;i<n;i++){
int a=read(),b=read();
add(a,b),add(b,a);
} int m=sqrt(n*log2(n));
dfs0(1,0),mem(hd,0),mem(nxt,0),ed=0;
for(int i=1;i<n;i++)add(nw[i].fi,nw[i].se);
for(int i=1;i<=m;i++)ans[i]=run(i);
for(int i=1,pre=n+1;i<=n/m+1;i++){
int l=1,r=pre;
while(l<r){
int m=(l+r>>1)+1;
if(run(m)>=i)l=m;
else r=m-1;
} for(int j=l+1;j<pre;j++)ans[j]=i-1; pre=l+1;
} for(int i=1;i<=n;i++)cout<<ans[i]<<endl;
return flush(),0;
}
2. cdq 分治
人类智慧算法。
在一个序列中,需要计算满足某些限制的点对 \((i,j)\)(这里 \(i,j\) 都表示位置)对答案的贡献,通常这些点对都有 \(\mathcal{O}(n^2)\) 个。cdq 分治的核心思想就是将所有需要计算贡献的点对 \((i,j)\) 分成三类:第一类 \(i,j\in[1,mid]\);第二类 \(i,j\in(mid,n]\);第三类 \(i\in[1,mid],j\in(mid,n]\)。这样一来就可以先递归处理第一、二类点对的答案,再运用一些方法快速求第三类的答案。
2.1. 例题
I. P3810 【模板】三维偏序(陌上花开)
题意简述:对每个 \(d\),求使 \(a_j\leq a_i,b_j\leq b_i,c_j\leq c_i,i\neq j\) 的 \(j\) 的个数有 \(d\) 个的 \(i\) 的个数。
首先去重,cdq 一般解决不了有重复元素的问题,除非重复元素之间不算贡献。接着将所有点按 \(a_i,b_i,c_i\) 分别为第一、二、三关键字从小到大排序。
这样做,排除了 \(a_i\) 对答案的限制。因为右区间的任何一个点都不会对左区间中的任何一个点有贡献。这样一来,需要求的就变成了对右区间的每个点 \(i\),求左区间的所有点 \(j\) 中,满足 \(b_j\leq b_i,c_j\leq c_i\) 的 \(j\) 有多少个。
先将区间内部的点按照 \(b_i,c_i\) 分别为第一、二关键字从小到大排序,那么对于每个 \(i\),可能符合条件(\(b_j\leq b_i\))的 \(j\) 一定是左区间的一段随着 \(i\) 的增大单调不缩的前缀。对于一段前缀,求有多少个 \(j\) 满足 \(c_j\leq c_i\) 就是树状数组的拿手好戏了。
视值域与序列大小同阶(离散化一下即可),则时间复杂度为 \(\mathcal{O}(n\log^2 n)\)。
一些注意点:
- 树状数组在添加 / 删除时权值为点的个数而不是 \(1\)。
- 别忘了考虑重复元素之间的贡献,即最终答案还要加上该重复元素个数 \(-1\)。
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+5;
int n,m,k,f[N];
struct pt{
int a,b,c,ans,cnt;
}d[N],u[N];
bool cmp1(pt a,pt b){return a.a!=b.a?a.a<b.a:a.b!=b.b?a.b<b.b:a.c<b.c;}
bool cmp2(pt a,pt b){return a.b!=b.b?a.b<b.b:a.c<b.c;}
int c[N<<1];
void add(int x,int v){while(x<=k)c[x]+=v,x+=x&-x;}
int query(int x){int s=0; while(x)s+=c[x],x-=x&-x; return s;}
void solve(int l,int r){
if(l==r)return;
int m=l+r>>1,le=l;
solve(l,m),solve(m+1,r);
sort(u+l,u+m+1,cmp2),sort(u+m+1,u+r+1,cmp2);
for(int i=m+1;i<=r;i++){
while(le<=m&&u[le].b<=u[i].b)add(u[le].c,u[le].cnt),le++;
u[i].ans+=query(u[i].c);
} for(int i=l;i<le;i++)add(u[i].c,-u[i].cnt);
}
int main(){
cin>>n>>k;
for(int i=1;i<=n;i++)d[i].a=read(),d[i].b=read(),d[i].c=read();
sort(d+1,d+n+1,cmp1);
for(int i=1;i<=n;i++){
if(d[i].a!=d[i-1].a||d[i].b!=d[i-1].b||d[i].c!=d[i-1].c)u[++m]=d[i];
u[m].cnt++;
} solve(1,m);
for(int i=1;i<=m;i++)f[u[i].ans+u[i].cnt-1]+=u[i].cnt;
for(int i=1;i<=n;i++)print(f[i-1]),pc('\n');
return flush(),0;
}
II. P4755 Beautiful Pair
首先对其进行 cdq 分治,设当前区间为 \([l,r]\),\(m=\frac{l+r}{2}\)。
对于每个位置 \(i\),若 \(i\in[l,m]\) 则记 \(suf_i=\max_{j=i}^m a_j\),若 \(i\in[m+1,r]\) 则记 \(pre_i=\max_{j=m+1}^i a_j\)。
分别考虑最大值在 \([l,m]\) 之间与在 \([m+1,r]\) 之间的情况:若最大值在左侧,则枚举 \(i\in[l,m]\),找到右侧的分界点 \(p\) 使得对于 \(j\in[m+1,p]\) 都有 \(pre_j\leq suf_i\),那么查询 \([m+1,p]\) 有多少个 \(j\) 使得 $a_j\leq \frac{suf_i}{a_i} $(不等号右边是定值),这个可以用主席树或者 BIT 做到。反之同理。
别忘了离散化。注意最大值在右边时要找分界点 \(p\) 使得对于 \(j\in[p,m]\) 都有 \(suf_j<pre_i\),而不是 \(\leq\),因为后者会多加上最大值在两边都出现的情况,而这种情况在考虑最大值在左边时已经计算过。
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=1e5+5;
ll n,ans,c,a[N],b[N],d[N],suf[N],pre[N];
ll node,R[N],ls[N<<5],rs[N<<5],val[N<<5];
void modify(int pre,ll &x,int l,int r,int p){
val[x=++node]=val[pre]+1,ls[x]=ls[pre],rs[x]=rs[pre];
if(l==r)return;
int m=l+r>>1;
if(p<=m)modify(ls[pre],ls[x],l,m,p);
else modify(rs[pre],rs[x],m+1,r,p);
} ll query(int l,int r,int p,int x,int y){
if(l==r)return val[y]-val[x];
int m=l+r>>1;
if(p<=m)return query(l,m,p,ls[x],ls[y]);
return val[ls[y]]-val[ls[x]]+query(m+1,r,p,rs[x],rs[y]);
}
ll solve(int l,int r){
if(l==r)return b[a[l]]==1;
ll m=l+r>>1,ans=solve(l,m)+solve(m+1,r);
for(int i=m+1;i<=r;i++)pre[i]=max(pre[i-1],a[i]);
for(int i=m;i>=l;i--)d[m-i+1]=suf[i]=max(suf[i+1],a[i]);
for(int i=l;i<=m;i++){
int p=upper_bound(pre+m+1,pre+r+1,suf[i])-pre-1;
if(p>m){
int nd=upper_bound(b+1,b+c+1,b[suf[i]]/b[a[i]])-b-1;
if(nd)ans+=query(1,c,nd,R[m],R[p]);
}
} for(int i=m+1;i<=r;i++){
int p=m+1-(lower_bound(d+1,d+m-l+2,pre[i])-d-1);
if(p<=m){
int nd=upper_bound(b+1,b+c+1,b[pre[i]]/b[a[i]])-b-1;
if(nd)ans+=query(1,c,nd,R[p-1],R[m]);
}
}
for(int i=l;i<=r;i++)pre[i]=suf[i]=0;
return ans;
}
int main(){
cin>>n;
for(int i=1;i<=n;i++)cin>>a[i],b[i]=a[i];
sort(b+1,b+n+1),c=unique(b+1,b+n+1)-b-1;
for(int i=1;i<=n;i++){
a[i]=lower_bound(b+1,b+c+1,a[i])-b;
modify(R[i-1],R[i],1,c,a[i]);
}
cout<<solve(1,n)<<endl;
return 0;
}