二逼平衡树
Problem
您需要写一种数据结构,来维护一个有序数列,其中需要提供以下操作:
1、查询 \(k\) 在区间内的排名(排名定义为区间内严格小于 \(k\) 的数的个数加一)
2、查询区间内排名为 \(k\) 的值
3、修改某一位置上的数值
4、查询 \(k\) 在区间内的前驱(前驱定义为严格小于 \(k\),且最大的数,若不存在输出 \(-2147483647\))
5、查询 \(k\) 在区间内的后继(后继定义为严格大于 \(k\),且最小的数,若不存在输出 \(2147483647\))
Sol1
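采用线段树套平衡树(FHQ Treap):线段树维护下标区间,每个线段树结点上的平衡树维护该区间内所有元素的大小关系。操作 2 需要在值域 \([0, 10^8]\) 上二分答案,每次二分调用一次操作 1,复杂度为 \(\mathcal O(\log V\log^2n)\)(\(V\) 为值域大小,通常记作 \(\mathcal O(\log^3n)\));其余操作均为 \(\mathcal O(\log^2n)\)。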
#include <bits/stdc++.h>
const int N = 50005;
int n, m;

// Split/merge ("FHQ") treap over int keys. Equal keys are allowed and are kept
// as separate nodes. Every segment-tree node of this solution owns one treap,
// identified by its root index (0 = empty tree).
namespace BST {
// Node storage; index 0 is the null node (its ch/sz entries are never written).
// pri: random heap priority, key: stored value, sz: subtree size.
int ch[N*30][2], pri[N*30], key[N*30], sz[N*30], tot = 0;
// Indices released by del() and handed out again by newnode(). Without
// recycling, every modification leaks one node per segment-tree level, and the
// worst case (initial build + one fresh node per level per update) exceeds N*30.
int pool[N*30], top = 0;

// Allocates a detached single-node tree holding v, reusing a freed index if any.
int newnode(int v) {
    int o = top ? pool[--top] : ++tot;
    ch[o][0] = ch[o][1] = 0;   // a recycled node may still carry stale child links
    pri[o] = rand(), key[o] = v, sz[o] = 1;
    return o;
}

// Recomputes the subtree size of o from its children.
void pushup(int o) {
    sz[o] = sz[ch[o][0]] + sz[ch[o][1]] + 1;
}

// Splits tree o into l (keys < v) and r (keys >= v).
void split(int o, int v, int &l, int &r) {
    if (!o) { l = r = 0; return; }
    if (key[o] < v) l = o, split(ch[o][1], v, ch[l][1], r), pushup(l);
    else r = o, split(ch[o][0], v, l, ch[r][0]), pushup(r);
}

// Merges x and y (every key in x <= every key in y); returns the new root.
int merge(int x, int y) {
    if (!x || !y) return x | y;
    if (pri[x] > pri[y]) {
        ch[x][1] = merge(ch[x][1], y);
        pushup(x);
        return x;
    }
    ch[y][0] = merge(x, ch[y][0]);
    pushup(y);
    return y;
}

// Returns the k-th smallest key of tree o (1-based); requires 1 <= k <= sz[o].
int kth(int o, int k) {
    for (;;) {
        int t = sz[ch[o][0]] + 1;
        if (k == t) return key[o];
        if (k > t) k -= t, o = ch[o][1];
        else o = ch[o][0];
    }
}

// Inserts one copy of v into the tree rooted at rt.
void ins(int &rt, int v) {
    int x, y;
    split(rt, v, x, y);
    rt = merge(merge(x, newnode(v)), y);
}

// Removes one copy of v from the tree rooted at rt (no-op if v is absent).
void del(int &rt, int v) {
    int x, y, z;
    split(rt, v, x, y);      // x: keys < v,  y: keys >= v
    split(y, v+1, y, z);     // y: keys == v, z: keys > v
    if (y) pool[top++] = y;  // drop the root of y and recycle its index
    rt = merge(merge(x, merge(ch[y][0], ch[y][1])), z);
}

// Number of keys strictly less than v.
int rank(int &rt, int v) {
    int x, y;
    split(rt, v, x, y);
    int ans = sz[x];
    rt = merge(x, y);
    return ans;
}

// Largest key strictly less than v, or -2147483647 if there is none.
int pre(int &rt, int v) {
    int x, y;
    split(rt, v, x, y);
    int ans = x ? kth(x, sz[x]) : -2147483647;
    rt = merge(x, y);
    return ans;
}

// Smallest key strictly greater than v, or 2147483647 if there is none.
int suc(int &rt, int v) {
    int x, y;
    split(rt, v+1, x, y);
    int ans = y ? kth(y, 1) : 2147483647;
    rt = merge(x, y);
    return ans;
}
}  // namespace BST
// Children of segment-tree node o in the implicit heap layout (root = 1).
// Both macros refer to a variable named `o` at the point of use.
#define lc (2 * o)
#define rc (2 * o + 1)
int a[N];      // current value at each position (1-indexed)
int rt[N*4];   // treap root owned by each segment-tree node (0 = empty)
// Builds the segment tree over positions [l, r] rooted at node o: the treap of
// every node receives each value a[l..r] that the node covers.
void build(int o, int l, int r) {
    for (int pos = l; pos <= r; ++pos) BST::ins(rt[o], a[pos]);
    if (l < r) {
        int mid = (l + r) / 2;
        build(lc, l, mid);
        build(rc, mid + 1, r);
    }
}
// Sets position p to value v. Every node on the root-to-leaf path removes the old
// value a[p] and inserts v; a[p] is overwritten only after the whole path is
// updated, so each node still deletes the old value.
void modify(int o, int l, int r, int p, int v) {
    for (;;) {
        BST::del(rt[o], a[p]);
        BST::ins(rt[o], v);
        if (l == r) break;
        int mid = (l + r) / 2;
        if (p <= mid) o = lc, r = mid;
        else o = rc, l = mid + 1;
    }
    a[p] = v;
}
// Counts values strictly less than v among positions [x, y]; node o covers [l, r].
int qry_rank(int o, int l, int r, int x, int y, int v) {
    if (x > r || l > y) return 0;                       // disjoint
    if (l >= x && r <= y) return BST::rank(rt[o], v);   // fully covered
    int mid = (l + r) / 2;
    int left = qry_rank(lc, l, mid, x, y, v);
    return left + qry_rank(rc, mid + 1, r, x, y, v);
}
// Returns the k-th smallest value among positions [x, y] by binary searching the
// answer over [0, 1e8]: the result is the largest t for which fewer than k
// elements in the range are strictly smaller than t (-1 if no such t is probed).
int qry_kth(int x, int y, int k) {
    int lo = 0, hi = 1e8, best = -1;
    while (lo <= hi) {
        int mid = (lo + hi) / 2;
        if (qry_rank(1, 1, n, x, y, mid) < k) best = mid, lo = mid + 1;
        else hi = mid - 1;
    }
    return best;
}
// Largest value strictly less than v among positions [x, y]; -2147483647 if none.
int qry_pre(int o, int l, int r, int x, int y, int v) {
    if (x > r || l > y) return -2147483647;
    if (l >= x && r <= y) return BST::pre(rt[o], v);
    int mid = (l + r) / 2;
    int from_left = qry_pre(lc, l, mid, x, y, v);
    int from_right = qry_pre(rc, mid + 1, r, x, y, v);
    return from_left < from_right ? from_right : from_left;
}
// Smallest value strictly greater than v among positions [x, y]; 2147483647 if none.
int qry_suc(int o, int l, int r, int x, int y, int v) {
    if (x > r || l > y) return 2147483647;
    if (l >= x && r <= y) return BST::suc(rt[o], v);
    int mid = (l + r) / 2;
    int from_left = qry_suc(lc, l, mid, x, y, v);
    int from_right = qry_suc(rc, mid + 1, r, x, y, v);
    return from_right < from_left ? from_right : from_left;
}
// Reads n values and m operations and answers them online.
int main() {
    srand(time(0));
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i) scanf("%d", &a[i]);
    build(1, 1, n);
    for (int q = 0; q < m; ++q) {
        int op, l, r, k;
        scanf("%d", &op);
        switch (op) {
        case 1:
            scanf("%d%d%d", &l, &r, &k);
            printf("%d\n", qry_rank(1, 1, n, l, r, k) + 1);
            break;
        case 2:
            scanf("%d%d%d", &l, &r, &k);
            printf("%d\n", qry_kth(l, r, k));
            break;
        case 3:
            scanf("%d%d", &l, &r);   // l: position, r: new value
            modify(1, 1, n, l, r);
            break;
        case 4:
            scanf("%d%d%d", &l, &r, &k);
            printf("%d\n", qry_pre(1, 1, n, l, r, k));
            break;
        case 5:
            scanf("%d%d%d", &l, &r, &k);
            printf("%d\n", qry_suc(1, 1, n, l, r, k));
            break;
        }
    }
    return 0;
}
Sol2
采用离线离散化 + 树状数组套动态开点线段树:外层树状数组建在离散化后的权值上,内层线段树维护下标区间(即先按权值划分,再按区间统计)。所有操作复杂度均为 \(\mathcal O(\log^2n)\)。
#include <bits/stdc++.h>
const int N = 50005;          // array bound for the n-indexed and m-indexed arrays
const int INF = 2147483647;   // printed (as INF or -INF) when no successor/predecessor exists
// Dynamically created segment trees over positions. L/R hold child indices,
// sz holds the number of stored elements in a subtree; index 0 is the null node.
namespace SEG {
int L[N*600], R[N*600], sz[N*600], tot = 0;

// Adds v to the count at position p in the tree rooted at o (which covers
// [l, r]), creating missing nodes along the path.
void modify(int &o, int l, int r, int p, int v) {
    int *cur = &o;
    for (;;) {
        if (*cur == 0) *cur = ++tot;
        sz[*cur] += v;
        if (l == r) return;
        int mid = (l + r) / 2;
        if (p <= mid) cur = &L[*cur], r = mid;
        else cur = &R[*cur], l = mid + 1;
    }
}

// Sum of counts over positions [x, y] in the tree rooted at o (which covers [l, r]).
int query(int o, int l, int r, int x, int y) {
    if (o == 0 || y < l || x > r) return 0;
    if (l >= x && r <= y) return sz[o];
    int mid = (l + r) / 2;
    int left = query(L[o], l, mid, x, y);
    return left + query(R[o], mid + 1, r, x, y);
}
}  // namespace SEG
int n, m, M;                  // sequence length, number of operations, number of distinct values
int a[N];                     // value at each position (read raw, then replaced by its index in disc)
int disc[N*2];                // sorted distinct values, 1-indexed, used for discretization
int op[N], l[N], r[N], k[N];  // all operations, stored for offline processing
// Lowest set bit of x. The argument is parenthesized so that compound
// arguments such as lowbit(i + j) expand correctly.
#define lowbit(x) ((x) & -(x))
int rt[2 * N];  // rt[i]: SEG root of Fenwick node i, which covers value indices (i - lowbit(i), i]
// Adds v to the count of discretized value i at position x, updating every
// Fenwick node whose value range contains i.
void modify(int i, int x, int v) {
    while (i <= M) {
        SEG::modify(rt[i], 1, n, x, v);
        i += lowbit(i);
    }
}
// Number of elements at positions [l, r] whose discretized value index is <= x.
int query1(int l, int r, int x) {
    int total = 0;
    while (x != 0) {
        total += SEG::query(rt[x], 1, n, l, r);
        x -= lowbit(x);
    }
    return total;
}
// Returns the smallest discretized value index t such that positions [l, r]
// hold at least k elements with value index <= t, i.e. the index of the k-th
// smallest element. Uses Fenwick-tree binary lifting. Returns M + 1 if the
// range holds fewer than k elements.
int query2(int l, int r, int k) {
    // Start from the highest power of two not exceeding M instead of a fixed
    // 1 << 17, so the search stays correct if N (and therefore M) is enlarged.
    int step = 1;
    while (step * 2 <= M) step <<= 1;
    int pos = 0, cnt = 0;  // invariant: cnt = #elements with value index <= pos, cnt < k
    for (; step; step >>= 1) {
        int nxt = pos + step;
        if (nxt > M) continue;
        // lowbit(nxt) == step, so rt[nxt] covers exactly the value indices (pos, nxt].
        int t = SEG::query(rt[nxt], 1, n, l, r);
        if (cnt + t < k) pos = nxt, cnt += t;
    }
    return pos + 1;
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]), disc[i] = a[i];
for (int i = 1; i <= m; i++) {
scanf("%d", &op[i]);
if (op[i] == 3) scanf("%d%d", &l[i], &k[i]);
else scanf("%d%d%d", &l[i], &r[i], &k[i]);
disc[n+i] = k[i];
}
std::sort(disc+1, disc+n+m+1); M = std::unique(disc+1, disc+n+m+1) - disc - 1;
for (int i = 1; i <= n; i++) a[i] = std::lower_bound(disc+1, disc+M+1, a[i]) - disc, modify(a[i], i, 1);
for (int i = 1; i <= m; i++) {
if (op[i] != 2) k[i] = std::lower_bound(disc+1, disc+M+1, k[i]) - disc;
if (op[i] == 1) printf("%d\n", query1(l[i], r[i], k[i]-1) + 1);
if (op[i] == 2) printf("%d\n", disc[query2(l[i], r[i], k[i])]);
if (op[i] == 3) modify(a[l[i]], l[i], -1), modify(a[l[i]] = k[i], l[i], 1);
if (op[i] == 4) {
int t = query1(l[i], r[i], k[i]-1);
printf("%d\n", t ? disc[query2(l[i], r[i], t)] : -INF);
}
if (op[i] == 5) {
int t = query1(l[i], r[i], k[i]);
printf("%d\n", t < r[i]-l[i]+1 ? disc[query2(l[i], r[i], t+1)] : INF);
}
}
return 0;
}