「CDQ分治」P3810 【模板】三维偏序(陌上花开)
知识点: CDQ 分治
原题面 Luogu
重学 CDQ /fad
假期里光想着摸鱼摸鱼了,抄完题解就跑路了= =
简述
给定 \(n\) 个元素,第 \(i\) 个元素有 \(a_i,b_i,c_i\) 三个属性。
对于第 \(i\) 个元素,设 \(f(i)\) 表示满足 \(a_j\le a_i, b_j\le b_i, c_j\le c_i\) 且 \(i\not= j\) 的 \(j\) 的数量。
对于 \(d\in [0,n)\),求 \(f(i)=d\) 的数量。
\(1\le n\le 10^5, 1\le a_i,b_i,c_i\le 2\times 10^5\)。
分析
Cdq 分治是一种牛逼思想,一般用于解决区间点对问题。设当前处理的区间为 \([l, r]\),Cdq 分治的一般过程:
- 若 \(l = r\),返回。
- 设区间中点为 \(mid\),递归处理 \([l,mid]\) 和 \([mid + 1, r]\)。
- 不横跨 \(mid\) 的点对都会在递归中被解决,仅考虑横跨 \(mid\) 的点对的贡献。
个人理解:
将 \(O(n^2)\) 级别的点对统计,利用单调性确定合法性,变为 \(O(n\log^k n)\) 级别。
太神了!
回到此题。
三个属性的元素的无序集合比较烦,将其转为有序序列,以 \(a_i\) 的相对大小为下标。考虑套一个 Cdq 上去。定义函数 Cdq(l, r)
表示统计 \([l,r]\) 内的点对。
在 Cdq(l, r)
中,仅考虑横跨 \(mid\) 的点对的贡献。发现对于 \(i\in [l,mid]\) 中的元素,一定有 \(a_i\le a_j, j\in[mid + 1, r]\)。
考虑剩下两个属性 \(b_i, c_i\) 的影响。变成了一个二维数点问题,按照套路,先固定一个元素,再考虑另一个的影响。将 \([l,mid]\) 和 \([mid + 1, r]\) 中元素分别按 \(b_i\) 升序排序。
有了单调性,可用双指针确定 对于每一个 \(j\in [mid + 1, r]\),有多少 \(i\in [l, mid]\) 满足 \(b_i\le b_j\)。
考虑对于每一个 \(j\),求得满足 \(b_i\le b_j\) 的元素 \(i\) 中,还满足 \(c_i\le c_j\) 的元素数量。值域区间统计问题,用树状数组维护,在 \([l, mid]\) 侧的指针右移时插入即可。
上述过程的复杂度为 \(O(n\log n)\)。根据主定理,总复杂度为 \(T(n) = T\left(\dfrac{n}{2}\right) + T\left(\dfrac{n}{2}\right) + O(n\log n) = O(n\log^2 n)\)。
本题是在 \(\le\) 时有贡献,注意去重,将多个元素合并。
代码
//知识点:CDQ分治
/*
By:Luckyblock
*/
#include <algorithm>
#include <cstdio>
#include <ctype.h>
#include <cstring>
#define ll long long
#define lowbit(x) (x&-x)
#define mid ((l_+r_)>>1)
const int kMaxn = 1e5 + 10;
//=============================================================
struct Data {
int x, y, z, cnt, ans;
} a[kMaxn];
int n, k, ans[kMaxn], t[kMaxn << 1];
//=============================================================
inline int read() {
int f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
void GetMax(int &fir_, int sec_) {
if (sec_ > fir_) fir_ = sec_;
}
void GetMin(int &fir_, int sec_) {
if (sec_ < fir_) fir_ = sec_;
}
bool Compare1(Data fir_, Data sec_) {
if (fir_.x != sec_.x) return fir_.x < sec_.x;
if (fir_.y != sec_.y) return fir_.y < sec_.y;
return fir_.z < sec_.z;
}
bool Compare2(Data fir_, Data sec_) {
if (fir_.y != sec_.y) return fir_.y < sec_.y;
return fir_.z < sec_.z;
}
void Add(int pos_, int val_) {
for (; pos_ <= k; pos_ += lowbit(pos_)) {
t[pos_] += val_;
}
}
int Sum(int pos_) {
int ret = 0;
for (; pos_; pos_ -= lowbit(pos_)) ret += t[pos_];
return ret;
}
void Cdq(int l_, int r_) {
if (l_ == r_) return ;
Cdq(l_, mid), Cdq(mid + 1, r_);
std :: sort(a + l_, a + mid + 1, Compare2);
std :: sort(a + mid + 1, a + r_ + 1, Compare2);
int i = l_, j = mid + 1;
for (; j <= r_; ++ j) {
for (; a[i].y <= a[j].y && i <= mid; ++ i) {
Add(a[i].z, a[i].cnt);
}
a[j].ans += Sum(a[j].z);
}
for (j = l_; j < i; ++ j) Add(a[j].z, - a[j].cnt);
}
bool JudgeEqual(Data fir_, Data sec_) {
if (fir_.x != sec_.x) return false;
if (fir_.y != sec_.y) return false;
return fir_.z == sec_.z;
}
//=============================================================
int main() {
int tmpn = n = read();
k = read();
for (int i = 1; i <= n; ++ i) {
a[i] = (Data) {read(), read(), read()};
}
std :: sort(a + 1, a + n + 1, Compare1);
n = 0;
for (int i = 1, cnt = 0; i <= tmpn; ++ i) {
cnt ++;
if (! JudgeEqual(a[i], a[i + 1])) {
a[++ n] = a[i],
a[n].cnt = cnt,
cnt = 0;
}
}
Cdq(1, n);
for (int i = 1; i <= n; ++ i) {
ans[a[i].ans + a[i].cnt - 1] += a[i].cnt;
}
for (int i = 0; i < tmpn; ++ i) {
printf("%d\n", ans[i]);
}
return 0;
}