CDQ分治(三维偏序)——学习笔记
@
前不久学长讲了\(CDQ\)分治,当时只听懂了思想,没去管算法,下来自己看了下题解,发现很眼熟,原来自己过去就写过类似的算法,于是很快就学会了。
模板
模板题
经典的三维偏序:某个物体有三个属性\(x,y,z\),求对于每个\(i∈[1,n]\),存在的\(j∈[1,n],j!=i\)使得\(x_i>=x_j且y_i>=y_j且z_i>=z_j\)成立的个数。
偏序问题的常用解法
一维偏序:直接排序。
二维偏序:排序\(+CDQ\)分治(应该都求过逆序对吧,那就是二维偏序)or 排序+数据结构
三维偏序:排序\(+CDQ\)分治+数据结构
\(k\)维偏序:排序\(+CDQ\)分治嵌套
三维偏序的实现方法:把三个属性存在一个结构体内,以其中一个为关键字对结构体排序,然后对第二维其进行CDQ分治,把第三维用数据结构进行维护,每次计算前一半对后一半的贡献。
\(CDQ\)分治的思想类似于归并排序,写法也很相似,先将区间二分至最小,然后再合并,保证前后两半内部都是有序的,每次计算前一半对后一半的贡献。
想想归并排序求逆序对的时候,我们把原始的下标看做第一位,值看做第二维,我们在归并时,每次放入后一半区间内的数时,都计算一次前面已经放入前一半的数的个数,因为前一半的第一维一定小于后一半的第一维,且先放入的数的第二维也一定小于后放入的数的第二维,因此已经放入前一半的数一定保证第一维和第二维都满足偏序的条件,这就是\(CDQ\)分治的思想。
那么既然二维已经可以求,多了一维该怎么办呢?很简单,只需要多写个树状数组(线段树随意)来维护第三维好了。
这里的树状数组是维护值域,在求二维偏序的基础上,每将一个前一半的数加入时,我们在其第三维的值的地方用树状数组标记\(+1\),每将一个后一半的数加入时,在树状数组中查询小于等于第三维的值的标记的和即可。
如果文字没看懂可以看看代码。
#include<bits/stdc++.h>
using namespace std;
#define ri register int
#define ll long long
const int N=300007;
int n,k,tot,sum[N],to[N],Ans[N],tong[N];
class Node{
public:
int a,b,c,t,cnt,ans;
}node[N],po[N],P[N];
inline int lowbit(int x){
return x&(-x);
}
inline void update(int sta,int x){
while(sta<=k){
sum[sta]+=x;
sta+=lowbit(sta);
}
}
inline int query(int sta){
int res=0;
while(sta>0){
res+=sum[sta];
sta-=lowbit(sta);
}
return res;
}
inline bool cmp(Node x,Node y){
if(x.a==y.a&&x.b==y.b) return x.c<y.c;
if(x.a==y.a) return x.b<y.b;
return x.a<y.a;
}
inline bool check(Node x,Node y){
return (x.a==y.a&&x.b==y.b&&x.c==y.c);
}
inline bool cmp1(Node x,Node y){
return x.t<y.t;
}
inline void cdq(int l,int r){
if(l==r) return;
int mid=(l+r)>>1;
cdq(l,mid);
cdq(mid+1,r);
int i=l,j=mid+1,k=l;
while(i<=mid&&j<=r){
if(po[i].b<=po[j].b){
update(po[i].c,po[i].cnt);
P[k++]=po[i++];
}
else{
po[j].ans+=query(po[j].c);
P[k++]=po[j++];
}
}
while(i<=mid){
update(po[i].c,po[i].cnt);
P[k++]=po[i++];
}
while(j<=r){
po[j].ans+=query(po[j].c);
P[k++]=po[j++];
}
for(ri t=l;t<=mid;++t) update(po[t].c,-po[t].cnt);
for(ri t=l;t<=r;++t) po[t]=P[t];
}
int main()
{
read(n);read(k);
for(ri i=1;i<=n;++i){
read(node[i].a);
read(node[i].b);
read(node[i].c);
node[i].t=i;
}
sort(node+1,node+n+1,cmp);
po[++tot]=node[1];
po[tot].t=tot;
po[tot].cnt=1;
to[1]=1;
for(ri i=2;i<=n;++i){
if(check(po[tot],node[i])) ++po[tot].cnt;
else{
po[++tot]=node[i];
po[tot].t=tot;
po[tot].cnt=1;
}
to[i]=tot;
}
cdq(1,tot);
sort(po+1,po+tot+1,cmp1);
for(ri i=1;i<=n;++i) Ans[node[i].t]=po[to[i]].ans+po[to[i]].cnt-1;
for(ri i=1;i<=n;++i) ++tong[Ans[i]];
for(ri i=0;i<n;++i) printf("%d\n",tong[i]);
return 0;
}
例题
T1
可以发现,对于每个询问的回答有3个限制:时间,横坐标,纵坐标。因此可以用3维偏序的方式处理。
时间作为第一维可以免去排序的步骤,然后根据二维前缀和类似的思想,把一个矩形询问拆分成4个类似前缀和的加减。
如我们要求的矩形左下角的坐标是\((x_1,y_1)\),右上角是\((x_2,y_2)\),那么这个矩形内的人数就是$$f[x_2][y_2]-f[x_2][y_1-1]-f[x_1-1][y_2]+f[x_1-1][y_1-1]$$
\(f[x][y]\)表示左下角坐标为\((0,0)\),右上角坐标为\((x,y)\)的矩形内的人数个数
很简单,对吧
参考代码:
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=2000007,M=200007;
int n,tot,qnum,Que[M][5],sum[N];
struct Node{
int a,b,c,cnt,ans;
}node[M],po[M];
template<class T>inline void read(T &res){
static char ch;T flag=1;
while((ch=getchar())<'0'||ch>'9') if(ch=='-') flag=-1;
res=ch-48;
while((ch=getchar())>='0'&&ch<='9') res=(res<<1)+(res<<3)+ch-48;
res*=flag;
}
inline int lowbit(int x){
return x&(-x);
}
inline void update(int sta,int x){
while(sta<=n){
sum[sta]+=x;
sta+=lowbit(sta);
}
}
inline int query(int sta){
int res=0;
while(sta){
res+=sum[sta];
sta-=lowbit(sta);
}
return res;
}
void cdq(int l,int r){
if(l==r) return;
int mid=(l+r)>>1;
cdq(l,mid);
cdq(mid+1,r);
int i=l,j=mid+1,k=l;
while(i<=mid&&j<=r){
if(node[i].b<=node[j].b){
update(node[i].c,node[i].cnt);
po[k++]=node[i++];
}
else{
node[j].ans+=query(node[j].c);
po[k++]=node[j++];
}
}
while(i<=mid){
update(node[i].c,node[i].cnt);
po[k++]=node[i++];
}
while(j<=r){
node[j].ans+=query(node[j].c);
po[k++]=node[j++];
}
for(int t=l;t<=mid;++t) update(node[t].c,-node[t].cnt);
for(int t=l;t<=r;++t) node[t]=po[t];
}
inline bool cmp(Node x,Node y){
return x.a<y.a;
}
signed main()
{
int opt,x,y,x1,y1,z;
while(1){
read(opt);
if(opt==0) read(n),++n;
else if(opt==1){
read(x);read(y);read(z);
++x;++y;
++tot;
node[tot].a=tot;
node[tot].b=x;
node[tot].c=y;
node[tot].cnt=z;
}
else if(opt==2){
read(x);read(y);
read(x1);read(y1);
++x;++y;++x1;++y1;
Que[++qnum][1]=++tot;
node[tot].a=tot;
node[tot].b=x-1;
node[tot].c=y-1;
node[tot].cnt=0;
Que[qnum][2]=++tot;
node[tot].a=tot;
node[tot].b=x1;
node[tot].c=y1;
node[tot].cnt=0;
Que[qnum][3]=++tot;
node[tot].a=tot;
node[tot].b=x1;
node[tot].c=y-1;
node[tot].cnt=0;
Que[qnum][4]=++tot;
node[tot].a=tot;
node[tot].b=x-1;
node[tot].c=y1;
node[tot].cnt=0;
}
else if(opt==3) break;
}
cdq(1,tot);
sort(node+1,node+tot+1,cmp);
for(int i=1;i<=qnum;++i) printf("%lld\n",node[Que[i][1]].ans+node[Que[i][2]].ans-node[Que[i][3]].ans-node[Que[i][4]].ans);
return 0;
}
T2
还是同样的道理,但这道题显然更灵活,细节也更多。
要求\(dist(A,B)=|A_x-B_x|+|A_y-B_y|\)的最小值,因为绝对值不好处理,因此我们把它拆分成4个方向上的4次询问,单个方向上只用考虑对于一个询问\((A_x,A_y)\),任意满足\(B_x<=A_x,B_y<=A_y的点B中,(B_x+B_y)\)的最大值,再用\((A_x+A_y)-(B_x-B_y)\)得到这个方向上的最小距离。
考虑怎么转化方向?
1.从左下角看:不需要转化。
2.从右上角看:用一个极大的点坐标\((x_{max},y_{max})\)减去每个点的横纵坐标。
3.从左上角看:用\(y_{max}\)减去每个点的纵坐标。
4.从右下角看:用\(x_{max}\)减去每个点的横坐标。
本题数据较大,注意常数和减少不必要的计算。
参考代码:
#include<bits/stdc++.h>
#define ri register int
using namespace std;
const int INF=0x3f3f3f3f;
const int N=1000002;
int n,m,up,ndnum,tot=0,Maxn[N+70],root,Ans[N];
struct Node{
int a,b,c,cnt,ans;
Node(){
ans=-INF;
}
}nd[N],node[N],P[N];
inline int lowbit(int x){
return x&(-x);
}
inline void update(int sta,int x){
while(sta<=N+1){
Maxn[sta]=max(Maxn[sta],x);
sta+=lowbit(sta);
}
}
inline int query(int sta){
int res=-INF;
while(sta){
res=max(res,Maxn[sta]);
sta-=lowbit(sta);
}
return res;
}
inline void clear(int sta){
while(sta<=N+1){
Maxn[sta]=-INF;
sta+=lowbit(sta);
}
}
void cdq(int l,int r){
if(l==r) return;
int mid=(l+r)>>1;
cdq(l,mid);
cdq(mid+1,r);
int i=l,j=mid+1,k=l;
while(i<=mid&&j<=r){
if(node[i].b<=node[j].b){
if(node[i].cnt) update(node[i].c,node[i].b+node[i].c);
P[k++]=node[i++];
}
else{
if(!node[j].cnt) node[j].ans=max(node[j].ans,query(node[j].c));
P[k++]=node[j++];
}
}
while(i<=mid) P[k++]=node[i++];
while(j<=r){
if(!node[j].cnt) node[j].ans=max(node[j].ans,query(node[j].c));
P[k++]=node[j++];
}
for(ri t=l;t<=mid;++t) if(node[t].cnt) clear(node[t].c);
for(ri t=l;t<=r;++t) node[t]=P[t];
}
template<class T>inline void read(T &res){
static char ch;T flag=1;
while((ch=getchar())<'0'||ch>'9') if(ch=='-') flag=-1;
res=ch-48;
while((ch=getchar())>='0'&&ch<='9') res=(res<<1)+(res<<3)+ch-48;
res*=flag;
}
struct ios{
inline char read(){
static const int IN_LEN=1<<18|1;
static char buf[IN_LEN],*s,*t;
return (s==t)&&(t=(s=buf)+fread(buf,1,IN_LEN,stdin)),s==t?-1:*s++;
}
template <typename _Tp> inline ios & operator >> (_Tp&x){
static char c11,boo;
for(c11=read(),boo=0;!isdigit(c11);c11=read()){
if(c11==-1)return *this;
boo|=c11=='-';
}
for(x=0;isdigit(c11);c11=read())x=x*10+(c11^'0');
boo&&(x=-x);
return *this;
}
}io;
template<class T>void Write(T x){
if(x<0)
x=-x,putchar('-');
if(x>9)
Write(x/10);
putchar(x%10+48);
}
int main()
{
io>>n>>m;
for(ri i=1;i<=n;++i){
io>>nd[i].b>>nd[i].c;
++nd[i].b;++nd[i].c;
nd[i].a=++tot;
nd[i].cnt=1;
}
for(ri i=1;i<=m;++i){
int opt;
++tot;
io>>opt>>nd[tot].b>>nd[tot].c;
++nd[tot].b;++nd[tot].c;
nd[tot].a=tot;
nd[tot].cnt=(opt==1);
}
for(ri i=1;i<=N+1;++i) Maxn[i]=-INF;
memset(Ans,INF,sizeof(Ans));
for(ri i=1;i<=tot;++i) node[i]=nd[i];
cdq(1,tot);
for(ri i=1;i<=tot;++i) if(!node[i].cnt) Ans[node[i].a]=min(Ans[node[i].a],node[i].b+node[i].c-node[i].ans);
for(ri i=1;i<=tot;++i){
node[i]=nd[i];
node[i].b=N-node[i].b;
node[i].c=N-node[i].c;
}
cdq(1,tot);
for(ri i=1;i<=tot;++i) if(!node[i].cnt) Ans[node[i].a]=min(Ans[node[i].a],node[i].b+node[i].c-node[i].ans);
for(ri i=1;i<=tot;++i){
node[i]=nd[i];
node[i].c=N-nd[i].c;
}
cdq(1,tot);
for(ri i=1;i<=tot;++i) if(!node[i].cnt) Ans[node[i].a]=min(Ans[node[i].a],node[i].b+node[i].c-node[i].ans);
for(ri i=1;i<=tot;++i){
node[i]=nd[i];
node[i].b=N-node[i].b;
}
cdq(1,tot);
for(ri i=1;i<=tot;++i) if(!node[i].cnt) Ans[node[i].a]=min(Ans[node[i].a],node[i].b+node[i].c-node[i].ans);
for(ri i=1;i<=tot;++i) if(Ans[i]!=INF) Write(Ans[i]),putchar('\n');
return 0;
}