cdq分治学习笔记
前言
感谢$__stdcall$的讲解,感谢伟大的导师$_tham$提供一系列练手题
cdq分治是什么?
国人(陈丹琦)引进的算法,不同于一般的分治,我们常说的分治是将问题分成互不影响的几个区间,递归进行处理,而所谓$cdq$分治,在处理一个区间时,还要计算它对其他区间的贡献。
二维偏序问题
给定$n$个二元组$[a,b]$,$m$次询问,每次给定其中的一个二元组$[c,d]$,求满足条件$c<a&d<b$的二元组的个数
不知道怎么做?逆序对你总会求吧?逆序对就是一种经典的二维偏序问题,我们不妨这样转换逆序对问题:
给定$n$个数,定义一个二元组为$[$元素下标,元素值$]$,则共有$n$个这样的二元组
我们只需将约束条件改为:$c<a&d>b$就行了。
那么,解决二维偏序的一般模式,也只需要改一下合并时的那一句话就好了。
PS:啊?你忘了怎么用归并排序求逆序对?戳我
相同的,我们也可以用树状数组来求解。复杂度同样为$O(nlogn)$
既然我们能用树状数组来解决用$cdq$分治的题,那我们能不能用$cdq$分治来解决树状数组的题目呢?当然可以,比如这道:Luogu3374 树状数组1
给定一个$n$个元素的序列$a$,初始值全部为$0$,对这个序列进行以下两种操作
操作$1$:格式为$1\ x\ k$,把所有位置$x$的元素加上$k$
操作$2$:格式为$2 x y$,求出区间$[x,y]$内所有元素的和。
这显然是一道树状数组模板题,考虑如何用$cdq$分治来解决它。
我们不妨以修改的时间为第一关键字,修改元素的位置为第二关键字。由于时间已经有序,我们定义结构体包含$3$个元素:$opt,ind,val$,其中$ind$表示操作的位置,$opt$为$1$表示修改,$val$表示“加上的值”。而对于查询,我们用前缀和的思想把他分解成两个操作:$sum[1,y]-sum[1,x-1]$,即分解成两次前缀和的查询。在合并的过程中,$opt$为$2$表示遇到了一个查询的左端点$x-1$,对结果作负贡献,$opt$为$3$表示遇到了一个查询的右端点$y$,对结果作正贡献,$val$表示“是第几个查询”。这样,我们就把每个操作转换成了带有附加信息的有序对(时间,位置),然后对整个序列进行$cdq$分治。
#include <cstdio>
#include <cstring>
#include <algorithm>
using std::min;
using std::max;
using std::swap;
using std::sort;
typedef long long ll;
const int N = 5e5 + 10, M = 5e5 + 10;
int n, m, aid, qid;
ll ans[M];
struct Query {
int ind, opt; ll val;
inline bool operator < (const Query a) const {
return ind == a.ind ? opt < a.opt : ind < a.ind;
}
}q[(M << 1) + N], tmp[(M << 1) + N];
inline void cdq (int l, int r) {
if (l == r) return ;
int mid = (l + r) >> 1;
cdq(l, mid), cdq(mid + 1, r);
int i = l, j = mid + 1, p = l; ll sum = 0;
while (i <= mid && j <= r)
if (q[i] < q[j]) {
if (q[i].opt == 1) sum += q[i].val;
tmp[p++] = q[i++];
} else {
if (q[j].opt == 2) ans[q[j].val] -= sum;
if (q[j].opt == 3) ans[q[j].val] += sum;
tmp[p++] = q[j++];
}
while (i <= mid) { if (q[i].opt == 1) sum += q[i].val; tmp[p++] = q[i++]; }
while (j <= r) {
if (q[j].opt == 2) ans[q[j].val] -= sum;
if (q[j].opt == 3) ans[q[j].val] += sum;
tmp[p++] = q[j++];
}
for (int k = l; k <= r; ++k) q[k] = tmp[k];
}
int main () {
scanf ("%d%d", &n, &m);
for (int i = 1; i <= n; ++i) {
q[++qid].ind = i, q[qid].opt = 1;
scanf("%lld", &q[qid].val);
}
int opt, ind, l, r; ll val;
for (int i = 1; i <= m; ++i) {
scanf("%d", &opt);
if (opt == 1) scanf("%d%lld", &ind, &val), q[++qid] = (Query){ind, 1, val};
else {
scanf ("%d%d", &l, &r);
q[++qid] = (Query){l - 1, 2, ++aid}, q[++qid] = (Query){r, 3, aid};
}
}
cdq(1, qid);
for (int i = 1; i <= aid; ++i)
printf("%lld\n", ans[i]);
return 0;
}
三维偏序问题
给定$n$个三元组$[a,b,c]$,$m$次询问,每次给定其中的一个二元组$[d,e,f]$,求满足条件$d<a&e<b&f<c$的二元组的个数
相同的,我们也可以采取用其他方法来解决三位偏序问题,如$bitset$、$KD\ Tree$、树套树等...比如我们可以以$a$为关键字排序,同时用$BIT$套平衡树来维护剩下的两个元素。
接着考虑如何用$cdq$分治来解决这个问题,我们可以考虑先以$a$为关键字对数组排序,这样我们的问题就成了维护后两个元素了。接下来,我们以一个经典的三维偏序题:陌上花开来做具体说明(由于这道题较为经典,在各大$OJ$都能找到,不给出链接)
题面
有n朵花,每朵花有三个属性:花形(s)、颜色(c)、气味(m),由三个整数表示。现要对每朵花评级,一朵花的级别是它拥有的美丽能超过的花的数量。定义一朵花A比花B要美丽,当且仅Sa>=Sb,Ca>=Cb,Ma>=Mb。显然,两朵花可能有同样的属性。需要统计出评出每个等级的花的数量。
题解
- 就如刚才所说的,以$a$为关键字进行排序
struct Node {
int a, b, c, mult, ans;
inline void Init() {
read(a), read(b), read(c);
}
} v[N], d[N];
inline bool cmpx (Node x, Node y) {
return (x.a < y.a) || (x.a == y.a && x.b < y.b) || (x.a == y.a && x.b == y.b && x.c < y.c);
}
int main () {
read(n), read(k);
for (int i = 1; i <= n; ++i) v[i].Init();
sort(&v[1], &v[n + 1], cmpx);
}
- 然后,我们会发现,普通的三位偏序只用处理小于,而不是小于等于,根据题意,完全相同属性的花是不计算在内的,所以我们得考虑将其去重。
for (int i = 1; i <= n; ++i) {
++mul;//相同元素的个数
//这里的异或你可以理解为不等于,由于之前已经排过序(见函数cmpx),可以线性比较,mult表示重复元素的个数
if ((v[i].a ^ v[i + 1].a) || (v[i].b ^ v[i + 1].b) || (v[i].c ^ v[i + 1].c))
d[++m] = v[i], d[m].mult = mul, mul = 0;
}
- 接着,我们考虑如何进行$cdq$分治,同样是在计算左区间时,处理右区间的询问,不妨采用$two-pointers$,两个指针$i,j$分别指向左右两个区间,这时候我们以$b$为关键字进行比较,如果$d[i].b<=d[j].b$,则将$d[i].c$插入权值$BIT$中,反之则在$BIT$中查询比$d[j].c$小的数的个数,作正贡献。在两个区间都扫完后,我们要考虑清空$BIT$,防止在接下来的递归回溯中被添加多次。
inline bool cmpy (Node x, Node y) {
return (x.b < y.b) || (x.b == y.b && x.c < y.c);
}
inline void cdq (int l, int r) {
if (l == r) return ;
int mid = (l + r) >> 1;
cdq(l, mid), cdq(mid + 1, r);
int i = l;
for (int j = mid + 1; j <= r; ++j) {
while (d[i].b <= d[j].b && i <= mid) update(d[i].c, d[i].mult), ++i;
d[j].ans += query(d[j].c);
//ans表示小于等于它的个数
}
//清空BIT
for (int k = l; k < i; ++k)
update(d[k].c, -d[k].mult);
inplace_merge(&d[l], &d[mid + 1], &d[r + 1], cmpy);
//这个函数表示将区间[l,mid+1)和[mid+1,r+1)按照cmpy方法合并
}
- 计算答案。
for (int i = 1; i <= m; ++i) ans[d[i].ans + d[i].mult - 1] += d[i].mult;
for (int i = 0; i < n; ++i) printf("%d\n", ans[i]);
代码
#include <cstdio>
#include <algorithm>
using std::sort;
using std::inplace_merge;
typedef long long ll;
template<typename T>
inline void read (T &x) {
char ch = getchar(); int flag = 1;
while(ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
if (ch == '-') flag = -flag, ch = getchar();
while(ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar();
x *= flag;
}
const int N = 1e5 + 10, K = 2e5 + 10;
int n, m, k, mul, ans[N], bit[K];
struct Node {
int a, b, c, mult, ans;
inline void Init() {
read(a), read(b), read(c);
}
} v[N], d[N];
inline bool cmpx (Node x, Node y) {
return (x.a < y.a) || (x.a == y.a && x.b < y.b) || (x.a == y.a && x.b == y.b && x.c < y.c);
}
inline bool cmpy (Node x, Node y) {
return (x.b < y.b) || (x.b == y.b && x.c < y.c);
}
inline int lowbit (int x) { return x & (-x); }
inline void update (int pos, int val) {
while (pos <= k) bit[pos] += val, pos += lowbit(pos);
}
inline int query (int pos) {
int val = 0;
while (pos) val += bit[pos], pos -= lowbit(pos);
return val;
}
inline void cdq (int l, int r) {
if (l == r) return ;
int mid = (l + r) >> 1;
cdq(l, mid), cdq(mid + 1, r);
int i = l;
for (int j = mid + 1; j <= r; ++j) {
while (d[i].b <= d[j].b && i <= mid) update(d[i].c, d[i].mult), ++i;
d[j].ans += query(d[j].c);
}
for (int k = l; k < i; ++k)
update(d[k].c, -d[k].mult);
inplace_merge(&d[l], &d[mid + 1], &d[r + 1], cmpy);
}
int main () {
read(n), read(k);
for (int i = 1; i <= n; ++i) v[i].Init();
sort(&v[1], &v[n + 1], cmpx);
for (int i = 1; i <= n; ++i) {
++mul;
if ((v[i].a ^ v[i + 1].a) || (v[i].b ^ v[i + 1].b) || (v[i].c ^ v[i + 1].c))
d[++m] = v[i], d[m].mult = mul, mul = 0;
}
cdq(1, m);
for (int i = 1; i <= m; ++i) ans[d[i].ans + d[i].mult - 1] += d[i].mult;
for (int i = 0; i < n; ++i) printf("%d\n", ans[i]);
return 0;
}