【主席树】洛谷 P3834 可持久化线段树 2
【主席树】洛谷 P3834 可持久化线段树2
题目链接:https://www.luogu.com.cn/problem/P3834
主席树是可持久化线段树的一种,也叫做可持久化权值线段树,主要可以用来O(logn)求静态区间的第k小数。
总所周知,普通线段树每次修改会遍历logn个点,那么我们在每次修改时都把这logn个点复制一份出来再修改,生成一个历史版本,就是可持久化线段树了,这里每一个点都是动态开点,而不是提前开好的,所以每个点内要存他的左右儿子节点的编号,不再是传统线段树的 2*p 和 2*p+1。
主席树有两种写法,一个是提前分配好所有空间,一个是使用指针。
需要注意的是,指针的写法时空复杂度一般要比普通写法大一倍,在这道洛谷板题的具体表现就是:普通写法(634ms,40.79MB),指针写法(1.18s,101.34MB)。
所以比赛时还是用普通写法稳妥一点,空间开到 N<<6 就保证够了。
代码(提前分配空间)
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
constexpr int N = 2e5;
struct node {
int l, r;
int sum;
} tr[N << 6];
int cnt;
int add(int p, int l, int r, int x) {
int u = ++cnt;
tr[u] = tr[p];
tr[u].sum++;
if (l == r) return u;
int m = (l + r) / 2;
if (x <= m) {
tr[u].l = add(tr[u].l, l, m, x);
} else {
tr[u].r = add(tr[u].r, m + 1, r, x);
}
return u;
}
int query(int p, int q, int l, int r, int k) {
if (l == r) return l;
int m = (l + r) / 2;
int x = tr[tr[q].l].sum - tr[tr[p].l].sum;
if (x >= k) {
return query(tr[p].l, tr[q].l, l, m, k);
} else {
return query(tr[p].r, tr[q].r, m + 1, r, k - x);
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n, m;
cin >> n >> m;
vector<int> a(n);
for (int i = 0; i < n; i++) {
cin >> a[i];
}
auto b = a;
sort(b.begin(), b.end());
b.erase(unique(b.begin(), b.end()), b.end());
int tot = b.size();
auto getid = [&](int x) {
return lower_bound(b.begin(), b.end(), x) - b.begin();
};
vector<int> rt(n + 1);
for (int i = 0; i < n; i++) {
rt[i + 1] = add(rt[i], 0, tot - 1, getid(a[i]));
}
for (int i = 0; i < m; i++) {
int l, r, k;
cin >> l >> r >> k;
l--, r--;
int id = query(rt[l], rt[r + 1], 0, tot - 1, k);
cout << b[id] << '\n';
}
return 0;
}
代码(指针)
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
struct node {
node *l;
node *r;
int sum;
node() : l{}, r{}, sum{} {}
};
node *add(node *p, int l, int r, int x) {
node *n = new node();
if (p) *n = *p;
n->sum++;
if (l == r) return n;
int m = (l + r) / 2;
if (x <= m) {
n->l = add(n->l, l, m, x);
} else {
n->r = add(n->r, m + 1, r, x);
}
return n;
}
int query(node *p, node *q, int l, int r, int k) {
if (l == r) return l;
int nq = (q && q->l ? q->l->sum : 0);
int np = (p && p->l ? p->l->sum : 0);
int num = nq - np;
int m = (l + r) / 2;
if (num >= k) {
return query(p ? p->l : nullptr, q ? q->l : nullptr, l, m, k);
} else {
return query(p ? p->r : nullptr, q ? q->r : nullptr, m + 1, r, k - num);
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n, m;
cin >> n >> m;
vector<int> a(n);
for (int i = 0; i < n; i++) {
cin >> a[i];
}
auto b = a;
sort(b.begin(), b.end());
b.erase(unique(b.begin(), b.end()), b.end());
int tot = b.size();
auto getid = [&](int x) {
return lower_bound(b.begin(), b.end(), x) - b.begin();
};
vector<node *> rt(n + 1);
for (int i = 0; i < n; i++) {
rt[i + 1] = add(rt[i], 0, tot - 1, getid(a[i]));
}
for (int i = 0; i < m; i++) {
int l, r, k;
cin >> l >> r >> k;
l--, r--;
int id = query(rt[l], rt[r + 1], 0, tot - 1, k);
cout << b[id] << '\n';
}
return 0;
}