[bzoj3196]Tyvj 1730 二逼平衡树——线段树套平衡树
题目
Description
您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:
1.查询k在区间内的排名
2.查询区间内排名为k的值
3.修改某一位值上的数值
4.查询k在区间内的前驱(前驱定义为小于x,且最大的数)
5.查询k在区间内的后继(后继定义为大于x,且最小的数)
Input
第一行两个数 n,m 表示长度为n的有序序列和m个操作
第二行有n个数,表示有序序列
下面有m行,opt表示操作标号
若opt=1 则为操作1,之后有三个数l,r,k 表示查询k在区间[l,r]的排名
若opt=2 则为操作2,之后有三个数l,r,k 表示查询区间[l,r]内排名为k的数
若opt=3 则为操作3,之后有两个数pos,k 表示将pos位置的数修改为k
若opt=4 则为操作4,之后有三个数l,r,k 表示查询区间[l,r]内k的前驱
若opt=5 则为操作5,之后有三个数l,r,k 表示查询区间[l,r]内k的后继
Output
对于操作1,2,4,5各输出一行,表示查询结果
题解
树套树模板题。
对于每一个线段树上的节点,我们都在上面建一个平衡树。
对于操作1,我们把一个长度为len的区间分解为\(\lfloor log_2(len) \rfloor\)个子区间,分别处理,把每个区间的排名加起来就好了。
对于操作2,我们二分答案,对于二分出来的一个答案执行操作1,check以下即可。
对于操作3,我们删除再插入。
对于操作4和操作5,我们把每个区间(前驱/后继)的(最大值/最小值)搞一搞就好啦。
PS.谁能告诉我bzoj上那些2000ms的是怎么做的。。。
代码
#include <algorithm>
#include <cctype>
#include <cstdio>
using namespace std;
const int maxn = 4e6 + 5;
const int inf = 1e9;
int ans, n, m, opt, l, r, k, pos, sz, Max;
int a[maxn], fa[maxn], ch[maxn][2], size[maxn], cnt[maxn], data[maxn], rt[maxn];
inline int read() {
int x = 0, f = 1;
char ch = getchar();
while (!isdigit(ch)) {
if (ch == '-')
f = -1;
ch = getchar();
}
while (isdigit(ch)) {
x = x * 10 + ch - '0';
ch = getchar();
}
return x * f;
}
struct Splay {
void clear(int x) {
fa[x] = ch[x][0] = ch[x][1] = size[x] = cnt[x] = data[x] = 0;
}
void update(int x) {
if (x) {
size[x] = cnt[x];
if (ch[x][0])
size[x] += size[ch[x][0]];
if (ch[x][1])
size[x] += size[ch[x][1]];
}
}
void zig(int x) {
int y = fa[x], z = fa[y], l = (ch[y][1] == x), r = l ^ 1;
fa[ch[y][l] = ch[x][r]] = y;
fa[ch[x][r] = y] = x;
fa[x] = z;
if (z)
ch[z][ch[z][1] == y] = x;
update(y);
update(x);
}
void splay(int i, int x, int aim = 0) {
for (int y; (y = fa[x]) != aim; zig(x))
if (fa[y] != aim)
zig((ch[fa[y]][0] == y) == (ch[y][0] == x) ? y : x);
if (aim == 0)
rt[i] = x;
}
void insert(int i, int v) {
int x = rt[i], y = 0;
if (x == 0) {
rt[i] = x = ++sz;
fa[x] = ch[x][0] = ch[x][1] = 0;
size[x] = cnt[x] = 1;
data[x] = v;
return;
}
while (1) {
if (data[x] == v) {
cnt[x]++;
update(y);
splay(i, x);
return;
}
y = x;
x = ch[x][v > data[x]];
if (x == 0) {
++sz;
fa[sz] = y;
ch[sz][0] = ch[sz][1] = 0;
size[sz] = cnt[sz] = 1;
data[sz] = v;
ch[y][v > data[y]] = sz;
update(y);
splay(i, sz);
rt[i] = sz;
return;
}
}
}
void find(int i, int v) {
int x = rt[i];
while (1) {
if (data[x] == v) {
splay(i, x);
return;
} else
x = ch[x][v > data[x]];
}
}
int pre(int i) {
int x = ch[rt[i]][0];
while (ch[x][1])
x = ch[x][1];
return x;
}
int succ(int i) {
int x = ch[rt[i]][1];
while (ch[x][0])
x = ch[x][0];
return x;
}
void del(int i) {
int x = rt[i];
if (cnt[x] > 1) {
cnt[x]--;
return;
}
if (!ch[x][0] && !ch[x][1]) {
clear(rt[i]);
rt[i] = 0;
return;
}
if (!ch[x][0]) {
int oldroot = x;
rt[i] = ch[x][1];
fa[rt[i]] = 0;
clear(oldroot);
return;
}
if (!ch[x][1]) {
int oldroot = x;
rt[i] = ch[x][0];
fa[rt[i]] = 0;
clear(oldroot);
return;
}
int y = pre(i), oldroot = x;
splay(i, y);
rt[i] = y;
ch[rt[i]][1] = ch[oldroot][1];
fa[ch[oldroot][1]] = rt[i];
clear(oldroot);
update(rt[i]);
return;
}
int rank(int i, int v) {
int x = rt[i], ans = 0;
while (1) {
if (!x)
return ans;
if (data[x] == v)
return ((ch[x][0]) ? size[ch[x][0]] : 0) + ans;
else if (data[x] < v) {
ans += ((ch[x][0]) ? size[ch[x][0]] : 0) + cnt[x];
x = ch[x][1];
} else if (data[x] > v) {
x = ch[x][0];
}
}
}
int find_pre(int i, int v) {
int x = rt[i];
while (x) {
if (data[x] < v) {
if (ans < data[x])
ans = data[x];
x = ch[x][1];
} else
x = ch[x][0];
}
return ans;
}
int find_succ(int i, int v) {
int x = rt[i];
while (x) {
if (v < data[x]) {
if (ans > data[x])
ans = data[x];
x = ch[x][0];
} else
x = ch[x][1];
}
return ans;
}
} sp;
void insert(int k, int l, int r, int x, int v) {
int mid = (l + r) >> 1;
sp.insert(k, v);
if (l == r)
return;
if (x <= mid)
insert(k << 1, l, mid, x, v);
else
insert(k << 1 | 1, mid + 1, r, x, v);
}
void askrank(int k, int l, int r, int x, int y, int val) {
int mid = (l + r) >> 1;
if (x <= l && r <= y) {
ans += sp.rank(k, val);
return;
}
if (x <= mid)
askrank(k << 1, l, mid, x, y, val);
if (mid + 1 <= y)
askrank(k << 1 | 1, mid + 1, r, x, y, val);
}
void change(int k, int l, int r, int pos, int val) {
int mid = (l + r) >> 1;
sp.find(k, a[pos]);
sp.del(k);
sp.insert(k, val);
if (l == r)
return;
if (pos <= mid)
change(k << 1, l, mid, pos, val);
else
change(k << 1 | 1, mid + 1, r, pos, val);
}
void askpre(int k, int l, int r, int x, int y, int val) {
int mid = (l + r) >> 1;
if (x <= l && r <= y) {
ans = max(ans, sp.find_pre(k, val));
return;
}
if (x <= mid)
askpre(k << 1, l, mid, x, y, val);
if (mid + 1 <= y)
askpre(k << 1 | 1, mid + 1, r, x, y, val);
}
void asksucc(int k, int l, int r, int x, int y, int val) {
int mid = (l + r) >> 1;
if (x <= l && r <= y) {
ans = min(ans, sp.find_succ(k, val));
return;
}
if (x <= mid)
asksucc(k << 1, l, mid, x, y, val);
if (mid + 1 <= y)
asksucc(k << 1 | 1, mid + 1, r, x, y, val);
}
int main() {
#ifdef D
freopen("input", "r", stdin);
#endif
n = read(), m = read();
for (int i = 1; i <= n; i++)
a[i] = read(), Max = max(Max, a[i]), insert(1, 1, n, i, a[i]);
for (int i = 1; i <= m; i++) {
opt = read();
if (opt == 1) {
l = read(), r = read(), k = read();
ans = 0;
askrank(1, 1, n, l, r, k);
printf("%d\n", ans + 1);
} else if (opt == 2) {
l = read(), r = read(), k = read();
int head = 0, tail = Max + 1;
while (head != tail) {
int mid = (head + tail) >> 1;
ans = 0;
askrank(1, 1, n, l, r, mid);
if (ans < k)
head = mid + 1;
else
tail = mid;
}
printf("%d\n", head - 1);
} else if (opt == 3) {
pos = read();
k = read();
change(1, 1, n, pos, k);
a[pos] = k;
Max = std::max(Max, k);
} else if (opt == 4) {
l = read();
r = read();
k = read();
ans = 0;
askpre(1, 1, n, l, r, k);
printf("%d\n", ans);
} else if (opt == 5) {
l = read();
r = read();
k = read();
ans = inf;
asksucc(1, 1, n, l, r, k);
printf("%d\n", ans);
}
}
}