[luoguP3810] 三维偏序
题意
有 $ n $ 个元素,第 $ i $ 个元素有 $ a_i,b_i,c_i $ 三个属性,设 $ f(i) $ 表示满足 $ a_j \leq a_i $ 且 $ b_j \leq b_i $ 且 $ c_j \leq c_i $ 且 $ j \ne i $ 的 \(j\) 的数量。
对于 $ d \in [0, n) $,求 $ f(i) = d $ 的数量。
sol
先来考虑类似的二维偏序,即删去 \(c_i\) 这一维属性
二维偏序
二维偏序与逆序对做法基本相同。
先将第一维排序,这样就只需要处理第二维的相对位置即可。类似于逆序对,我们使用归并排序对第二维进行排序,记作 \(\operatorname{merge\_sort}(l,r)\),那么 \(\operatorname{merge\_sort}(l,mid)\) 和 \(\operatorname{merge\_sort}(mid+1,r)\) 就计算出了 \([l,mid]\) 和 \((mid,r]\) 的答案,只需要在归并的过程中计算这两个区间之间的偏序即可。
计算时,记左侧区间指针为 \(i\),右侧区间指针为 \(j\),当 \(i\) 移动时,说明右侧区间中 \([j,r]\) 部分都需要累加上第 \(i\) 个元素的答案,即 \(1\),因此需要使用 \(res\) 变量记录下右侧区间中每个元素的答案需要累加多少。
由于元素可能有重复,因此需要去重,记第 \(i\) 种元素有 \(cnt_i\) 个,则累加答案时需要累加 \(cnt_i\) 个而非 \(1\) 个。
三维偏序
三维偏序做法无明显区别,但在归并时,由于有第三维的存在,因此无法使用 \(res\) 变量来记录累加多少。我们需要一个快速计算出左侧区间中有多少个已被枚举到的值第三维小于等于右侧枚举到的值的第三维,可以使用权值树状数组或权值线段树解决,综合时间复杂度 \(O(n\log^2 n)\)
在这里我们使用了一种使用一个子问题计算另一个子问题的离线分治思想,这种思想就是 CDQ 分治
代码
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
const int N = 100005, M = 200005;
int n, k;
int ans[N], rk[N];
int tr[N];
struct Node{
int a, b, c, cnt, ans;
bool operator< (const Node &W) const {
if (a != W.a) return a < W.a;
if (b != W.b) return b < W.b;
return c < W.c;
}
bool operator!= (const Node &W) const {
return a != W.a || b != W.b || c != W.c;
}
} g[N], temp[N];
int lowbit(int x){
return x & -x;
}
void insert(int p, int x){
for (int i = p; i <= k; i += lowbit(i)) tr[i] += x;
}
int query(int p){
int res = 0;
for (int i = p; i; i -= lowbit(i)) res += tr[i];
return res;
}
void merge_sort(int l, int r){
if (l >= r) return ;
int mid = l + r >> 1;
merge_sort(l, mid), merge_sort(mid + 1, r);
int i = l, j = mid + 1, cnt = 0;
while (i <= mid && j <= r){
while (i <= mid && g[i].b <= g[j].b) {
temp[ ++ cnt] = g[i];
insert(g[i].c, g[i].cnt);
i ++ ;
}
while (j <= r && g[j].b < g[i].b) {
g[j].ans += query(g[j].c);
temp[ ++ cnt] = g[j];
j ++ ;
}
}
while (i <= mid) {
insert(g[i].c, g[i].cnt);
temp[ ++ cnt] = g[i];
i ++ ;
}
while (j <= r) {
g[j].ans += query(g[j].c);
temp[ ++ cnt] = g[j];
j ++ ;
}
for (int i = l; i <= mid; i ++ ) insert(g[i].c, -g[i].cnt);
for (int i = l; i <= r; i ++ ) g[i] = temp[i - l + 1];
}
int main(){
scanf("%d%d", &n, &k);
for (int i = 1; i <= n; i ++ ) scanf("%d%d%d", &g[i].a, &g[i].b, &g[i].c);
sort(g + 1, g + n + 1);
int r = n, c = 0;
n = 0;
for (int i = 1; i <= r; i ++ ){
c ++ ;
if (g[i] != g[i + 1]) {
g[ ++ n] = g[i], g[n].cnt = c;
c = 0;
}
}
merge_sort(1, n);
for (int i = 1; i <= n; i ++ ) ans[g[i].ans + g[i].cnt - 1] += g[i].cnt;
for (int i = 0; i < r; i ++ ) printf("%d\n", ans[i]);
return 0;
}
蒟蒻犯的若至错误
- 树状数组写错了 awa