HNOI2012 永无乡
题意
给定\(n\)个连通块,有两种操作:
- 合并两个连通块
- 查询某个元素所在连通块内第\(k\)大的值
解法
合并连通块\(\to\)启发式合并,查询第\(k\)大\(\to\)平衡树,权值线段树
当然这道题可以用线段树合并写,但是用FHQ_Treap来写实在是太爽了
由于FHQ_Treap本身就可以维护连通块(一颗树就是一个连通块),还能顺带维护连通块的size与其中的元素,简直是为这道题量身定制的
对于合并两个连通块的操作,我们遍历较小的那个连通块,将其中的元素一个个加进大连通块中,并把合并以后得到的根在并查集中设为原来连通块根的祖先
代码
#include <ctime>
#include <cstdio>
#include <cstdlib>
#include <iostream>
using namespace std;
const int N = 1e5 + 10;
int read();
int n, m;
int fa[N], mp[N];
char op[10];
struct FHQ_Treap {
#define ls(x) t[x].ch[0]
#define rs(x) t[x].ch[1]
int cnt;
int a, b;
struct node {
int val, rnd, siz;
int ch[2];
node() { ch[0] = ch[1] = 0; }
} t[N];
FHQ_Treap() : cnt(0) {}
void update(int x) {
t[x].siz = t[ls(x)].siz + t[rs(x)].siz + 1;
}
int newnode(int v) {
++cnt;
t[cnt].val = v, t[cnt].rnd = rand() << 15 | rand(), t[cnt].siz = 1;
return cnt;
}
void split(int x, int k, int& lt, int& rt) {
if (!x)
return lt = rt = 0, void();
if (t[x].val <= k)
lt = x, split(rs(x), k, rs(x), rt);
else
rt = x, split(ls(x), k, lt, ls(x));
update(x);
}
int merge(int x, int y) {
if (!x || !y)
return x | y;
if (t[x].rnd < t[y].rnd) {
rs(x) = merge(rs(x), y);
return update(x), x;
} else {
ls(y) = merge(x, ls(y));
return update(y), y;
}
}
int kth(int x, int k) {
if (k > t[x].siz) return 0;
int p = x;
while (true) {
if (k <= t[ls(p)].siz) p = ls(p);
else if (k == t[ls(p)].siz + 1) return t[p].val;
else k -= t[ls(p)].siz + 1, p = rs(p);
}
}
int insert(int x, int y) {
split(y, t[x].val, a, b);
return merge(merge(a, x), b);
}
void DFS(int x, int& y) {
if (ls(x)) DFS(ls(x), y);
if (rs(x)) DFS(rs(x), y);
y = insert(x, y);
}
int combine(int x, int y) {
DFS(x, y);
return y;
}
#undef ls
#undef rs
} tr;
inline int get(int x) {
return x == fa[x] ? x : fa[x] = get(fa[x]);
}
void modify(int u, int v) {
if (get(u) ^ get(v)) {
u = fa[u], v = fa[v];
if (tr.t[u].siz > tr.t[v].siz) swap(u, v);
int nr = tr.combine(u, v);
fa[u] = fa[v] = fa[nr] = nr;
}
}
int main() {
srand(time(0));
n = read(), m = read();
mp[0] = -1;
for (int i = 1; i <= n; ++i) {
int v = read();
fa[i] = mp[v] = i;
tr.newnode(v);
}
for (int i = 1; i <= m; ++i) modify(read(), read());
int q = read();
while (q--) {
scanf("%s", op + 1);
if (op[1] == 'B') {
modify(read(), read());
} else {
int u = read(), k = read();
printf("%d\n", mp[tr.kth(get(u), k)]);
}
}
return 0;
}
int read() {
int x = 0, c = getchar();
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
return x;
}