偏序问题小结
三维偏序 陌上花开
- 二维偏序: 先双关键字排序,用双指针算法求出横跨两个区间的个数(另外两种可以递归),然后按b归并排序
- 三维偏序: 先三关键字排序,用双指针算法求出横跨两个区间的左端点j的区间:用二维的做法(树状数组)
- 注意相同元素的影响
点击查看代码
#include <stdio.h>
#include <string.h>
#include <algorithm>
const int N = 1e5 + 5, M = 2e5 + 5;
int n, m;
struct Pt {
int x, y, z, c, res; // c表示出现的个数, res表示答案
bool operator < (const Pt &a) const {
return x < a.x || (x == a.x && (y < a.y || (y == a.y && z < a.z)));
}
bool operator == (const Pt &a) const {
return x == a.x && y == a.y && z == a.z;
}
bool operator <= (const Pt &a) const {
return x < a.x || (x == a.x && (y < a.y || (y == a.y && z <= a.z)));
}
} a[N], tmp[N];
int tr[M]; // 树状数组
void add(int x, int y) {
for(; x <= m; x += x & -x) tr[x] += y;
}
int query(int x) {
int res = 0;
for(; x; x -= x & -x) res += tr[x];
return res;
}
void merge_sort(int l, int r) {
if(l >= r) return;
int mid = (l + r) >> 1;
merge_sort(l, mid), merge_sort(mid + 1, r);
int i = l, j = mid + 1, k = l;
while(k <= r)
if((i <= mid && a[i].y <= a[j].y) || j > r)
add(a[i].z, a[i].c), tmp[k ++] = a[i ++];
else // 统计答案时这段区间已经按x和y排好序了,使用树状数组统计答案
a[j].res += query(a[j].z), tmp[k ++] = a[j ++];
for(i = l; i <= mid; i ++) add(a[i].z, -a[i].c); // 清空树状数组
for(k = l; k <= r; k ++) a[k] = tmp[k];
}
int ans[N];
int main() {
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i ++)
scanf("%d%d%d", &a[i].x, &a[i].y, &a[i].z), a[i].c = 1;
std::sort(a + 1, a + n + 1);
int k = 1;
for(int i = 2; i <= n; i ++)
if(a[i] == a[k]) a[k].c ++;
else a[++ k] = a[i];
merge_sort(1, k);
for(int i = 1; i <= k; i ++) ans[a[i].res + a[i].c - 1] += a[i].c;
for(int i = 0; i < n; i ++) printf("%d\n", ans[i]);
return 0;
}
四维偏序(时空),分治套分治套树状数组
点击查看代码
#include <stdio.h> //
#include <string.h>
#include <algorithm>
const int N = 4e5 + 5;
int T, n, tot, qtot;
struct Pt {
int x, y, z, s:4, t:4, lft:4, id; // (x,y,z):坐标(已经按时间排好序),s:符号,t:是否为询问,lft:是否在左区间
} a[N], tmp1[N], tmp2[N];
int num[N], cnt, ans[N], tr[N]; // z离散化的结果;答案;z的树状数组
inline void inc(int x) { for(; x <= cnt; x += x & -x) ++ tr[x]; } // ++
inline void dec(int x) { for(; x <= cnt; x += x & -x) -- tr[x]; } // --
inline int query(int x) {
int res = 0;
for(; x; x -= x & -x) res += tr[x];
return res;
}
void solve2(int l, int r) {
if(l >= r) return;
int mid = (l + r) >> 1;
solve2(l, mid), solve2(mid + 1, r);
int i = l, j = mid + 1, k = l;
while(k <= r)
if((i <= mid && tmp1[i].y <= tmp1[j].y) || j > r) {
if(tmp1[i].lft && !tmp1[i].t) inc(tmp1[i].z);
tmp2[k ++] = tmp1[i ++];
} else {
if(!tmp1[j].lft && tmp1[j].t) ans[tmp1[j].id] += tmp1[j].s * query(tmp1[j].z);
tmp2[k ++] = tmp1[j ++];
}
for(k = l; k <= r; k ++) {
if(k <= mid && tmp1[k].lft && !tmp1[k].t) dec(tmp1[k].z);
tmp1[k] = tmp2[k];
}
}
void solve1(int l, int r) {
if(l >= r) return;
int mid = (l + r) >> 1;
solve1(l, mid), solve1(mid + 1, r);
int i = l, j = mid + 1, k = l;
while(k <= r)
if((i <= mid && a[i].x <= a[j].x) || j > r) (tmp1[k ++] = a[i ++]).lft = true;
else (tmp1[k ++] = a[j ++]).lft = false;
for(k = l; k <= r; k ++) a[k] = tmp1[k];
solve2(l, r);
}
int main() {
scanf("%d", &T);
while(T --) {
scanf("%d", &n), tot = qtot = cnt = 0;
for(int i = 1, op, x1, y1, z1, x2, y2, z2; i <= n; i ++) {
scanf("%d%d%d%d", &op, &x1, &y1, &z1);
if(op == 1) a[++ tot] = {x1, y1, z1, 0, 0, 0, 0};
else {
scanf("%d%d%d", &x2, &y2, &z2), qtot ++, x1 --, y1 --, z1 --;
a[++ tot] = {x2, y2, z2, 1, 1, 0, qtot}, a[++ tot] = {x1, y1, z1, -1, 1, 0, qtot};
a[++ tot] = {x1, y1, z2, 1, 1, 0, qtot}, a[++ tot] = {x2, y2, z1, -1, 1, 0, qtot};
a[++ tot] = {x1, y2, z1, 1, 1, 0, qtot}, a[++ tot] = {x2, y1, z2, -1, 1, 0, qtot};
a[++ tot] = {x2, y1, z1, 1, 1, 0, qtot}, a[++ tot] = {x1, y2, z2, -1, 1, 0, qtot};
}
}
for(int i = 1; i <= tot; i ++) num[i] = a[i].z;
std::sort(num + 1, num + tot + 1), cnt = std::unique(num + 1, num + tot + 1) - num - 1;
for(int i = 1; i <= tot; i ++) a[i].z = std::lower_bound(num + 1, num + cnt + 1, a[i].z) - num;
solve1(1, tot);
for(int i = 1; i <= qtot; i ++) printf("%d\n", ans[i]), ans[i] = 0;
}
return 0;
}
输入n个m维点。对于每个点,输出有多少个其他点,每一维都不超过它。
点击查看代码
#include <bitset>
#include <math.h>
#include <stdio.h>
#include <string.h>
#include <algorithm>
const int N = 4e4 + 4, K = 2e2 + 2, M = 12;
int n, m, k;
struct Data {
int a[M];
bool operator < (const Data &x) {
for(int i = 1; i <= m; i ++)
if(a[i] != x.a[i]) return a[i] < x.a[i];
return false;
}
} a[N];
int rnk[M][N];
std::bitset<N> set[M][K];
int main() {
scanf("%d%d", &n, &m), k = sqrt(n);
for(int i = 1; i <= n; i ++)
for(int j = 1; j <= m; j ++) scanf("%d", &a[i].a[j]);
for(int j = 1; j <= m; j ++) {
for(int i = 1; i <= n; i ++) rnk[j][i] = i;
auto cmp = [&](int x, int y) { return a[x].a[j] < a[y].a[j]; };
std::sort(rnk[j] + 1, rnk[j] + n + 1, cmp);
std::bitset<N> now;
for(int i = 1; i <= n; i ++) {
now.set(rnk[j][i]);
if(!(i % k)) set[j][i / k] = now;
}
}
for(int i = 1; i <= n; i ++) {
std::bitset<N> now;
now.set();
for(int j = 1; j <= m; j ++) {
int l = 1, r = n;
while(l < r) {
int mid = (l + r + 1) >> 1;
if(a[rnk[j][mid]].a[j] <= a[i].a[j]) l = mid;
else r = mid - 1;
}
std::bitset<N> tmp = set[j][l / k];
for(int t = l / k * k + 1; t <= l; t ++) tmp.set(rnk[j][t]);
now &= tmp;
}
printf("%d\n", int(now.count()) - 1);
}
return 0;
}