浅谈cdq分治
寒假就听很多大佬说过cdq分治,最近正好学到,写个博客总结一下。
什么是分治?
所谓分治就是把一个大问题分解成两个小问题,解决完两个小问题之后再考虑两个小问题之间的影响(或者先考虑两个小问题之间的影响=-=)。
比如点分治求树上两点之间距离为k的点对有多少个,就是找出重心后先求出过重心的点对数,再去递归求解子树。又比如归并排序求逆序对时就是先递归算出左半区间和右半区间对答案的贡献再考虑左区间大于右区间的数的个数。
那什么是cdq分治呢?
cdq分治来源于09年陈丹琦大佬的国家队论文。
我们先来看一下一个cdq分治的经典问题:二维偏序
- 给你n个二元组(x,y),问有多少对二元组满足xi<xj并且yi<yj。(n<=50000)
我们考虑上面我们谈到过的逆序对问题,其实偏序问题就是正序对问题,我们按x坐标按升序排序,然后求y坐标的正序对即可。
一个经典应用:单点修改,区间查询。
这是一道树状数组经典题,不过我们今天要用cdq分治解决。
我们可以把操作的位置看做一维,操作的时间看做另一维。那么当(xi<xj)且(ti<tj)时,i操作就会对j询问造成影响。所以我们可以用分治解决。但是询问操作问的是区间,不是很好处理。其实不要紧,我们把对于区间[l,r]的询问看做[1,r]-[1,l-1]就可以了。初值也不太好处理,我们把初值当做修改来做。具体做法看代码体会吧qaq
1 // cdq分治 2 #include<iostream> 3 #include<cstdio> 4 #include<cstring> 5 #include<algorithm> 6 #define LL long long 7 #define RI register int 8 using namespace std; 9 const int INF = 0x7ffffff ; 10 const int N = 500000 + 10 ; 11 12 inline int read() { 13 int k = 0 , f = 1 ; char c = getchar() ; 14 for( ; !isdigit(c) ; c = getchar()) 15 if(c == '-') f = -1 ; 16 for( ; isdigit(c) ; c = getchar()) 17 k = k*10 + c-'0' ; 18 return k*f ; 19 } 20 struct Query { 21 int x, id, key, kind, bl ; // 数组下标;时间;键值(修改操作为修改值,询问操作记录是减还是加) 22 // 操作类型;只有询问操作记录bl,表示该询问是第几个询问。 23 }q[N*3], tmp[N*3] ; 24 int n, m ; int ans[N] ; 25 26 void solve(int l,int r) { 27 if(l == r) return ; 28 int ll = l, mid = l+r>>1, rr = mid+1, sum = 0 ; 29 for(int i=l;i<=r;i++) { 30 if(q[i].id <= mid && q[i].kind == 1) sum += q[i].key ; 31 if(q[i].id > mid && q[i].kind == 2) ans[q[i].bl] += q[i].key*sum ; 32 } 33 for(int i=l;i<=r;i++) { 34 if(q[i].id <= mid) tmp[ll++] = q[i] ; 35 else tmp[rr++] = q[i] ; 36 } 37 for(int i=l;i<=r;i++) q[i] = tmp[i] ; 38 solve(l,mid) ; solve(mid+1,r) ; 39 } 40 41 inline bool cmp1(Query s,Query t) { return s.x == t.x ? s.kind < t.kind : s.x < t.x ; } 42 // 优先按照数组下标排序,当数组下标相同时,修改置于询问之前。 43 int main() { 44 n = read(), m = read() ; int tot = 0, t = 0 ; 45 for(int i=1;i<=n;i++) { 46 q[++tot].x = i, q[tot].id = tot, q[tot].key = read(), q[tot].kind = 1 ; 47 } 48 int ii ; 49 for(int i=1;i<=m;i++) { 50 ii = read() ; 51 if(ii == 1) { 52 q[++tot].x = read(), q[tot].id = tot, q[tot].key = read(), q[tot].kind = 1 ; 53 } else { 54 q[++tot].x = read()-1, q[tot].id = tot, q[tot].key = -1, q[tot].kind = 2, q[tot].bl = ++t ; 55 q[++tot].x = read(), q[tot].id = tot, q[tot].key = 1, q[tot].kind = 2, q[tot].bl = t ; 56 } 57 } 58 sort(q+1,q+tot+1,cmp1) ; solve(1,tot) ; 59 for(int i=1;i<=t;i++) printf("%d\n",ans[i]) ; 60 return 0 ; 61 }
(代码题目) (常数较大,需开O2)
三维偏序:给定n个三元组(x,y,z),问满足xi<xj,yi<yj,zi<zj的点对有多少个。(n<=50000)
我们还是按照上面的套路,把他们按照x升序排序。然后按照y值二分。那z值怎么办呢?我们可以用权值树状数组记录z值。
还是借助题目和代码来理解:bzoj 2683 简单题
1 #include<iostream> 2 #include<cstdio> 3 #include<cstring> 4 #include<algorithm> 5 #define LL long long 6 #define RI register int 7 using namespace std; 8 const int INF = 0x7ffffff ; 9 const int N = 500000 + 10 ; 10 const int M = 200000 + 10 ; 11 12 inline int read() { 13 int k = 0 , f = 1 ; char c = getchar() ; 14 for( ; !isdigit(c) ; c = getchar()) 15 if(c == '-') f = -1 ; 16 for( ; isdigit(c) ; c = getchar()) 17 k = k*10 + c-'0' ; 18 return k*f ; 19 } 20 struct Query { 21 int x, y, key, id, kind, bl ; 22 }q[M<<2], tmp[M<<2] ; 23 int n ; int tr[N], ans[M] ; 24 inline int lowbit(int x) { return x&(-x) ; } 25 inline void add(int x,int k) { 26 while(x <= n) { 27 tr[x] += k ; x += lowbit(x) ; 28 } 29 } 30 inline int sum(int x) { 31 int res = 0 ; 32 while(x) { 33 res += tr[x] ; 34 x -= lowbit(x) ; 35 } 36 return res ; 37 } 38 39 void solve(int l,int r) { 40 if(l == r) return ; 41 int ll = l, mid = (l+r)>>1, rr = mid+1 ; 42 for(int i=l;i<=r;i++) { 43 if(q[i].id <= mid && q[i].kind == 1) add(q[i].y,q[i].key) ; 44 if(q[i].id > mid && q[i].kind == 2) ans[q[i].bl] += q[i].key*sum(q[i].y) ; 45 } 46 for(int i=l;i<=r;i++) { 47 if(q[i].id <= mid && q[i].kind == 1) add(q[i].y,-q[i].key) ; 48 } 49 for(int i=l;i<=r;i++) 50 if(q[i].id <= mid) tmp[ll++] = q[i] ; 51 else tmp[rr++] = q[i] ; 52 for(int i=l;i<=r;i++) q[i] = tmp[i] ; 53 solve(l,mid) ; solve(mid+1,r) ; 54 } 55 56 inline bool cmp1(Query s,Query t) { return s.x == t.x ? s.kind < t.kind : s.x < t.x ; } 57 int main() { 58 n = read() ; int tot = 0, t = 0 ; 59 while(1) { 60 int ii = read() ; 61 if(ii == 3) break ; 62 if(ii == 1) { 63 int x = read(), y = read(), z = read() ; 64 q[++tot] = (Query){x, y, z, tot, 1, 0 } ; 65 } else { 66 int x1 = read(), y1 = read(), x2 = read(), y2 = read() ; 67 q[++tot] = (Query){x2, y2, 1, tot, 2, ++t } ; 68 q[++tot] = (Query){x1-1, y1-1, 1, tot, 2, t } ; 69 q[++tot] = (Query){x2, y1-1, -1, tot, 2, t } ; 70 q[++tot] = (Query){x1-1, y2, -1, tot, 2, t } ; 71 } 72 } 73 sort(q+1,q+tot+1,cmp1) ; 74 solve(1,tot) ; 75 for(int i=1;i<=t;i++) printf("%d\n",ans[i]) ; 76 return 0 ; 77 }
(特别感谢TRTTG和mlystdcall大佬,我就是看着他们的博客学习的)