珂朵莉树
概念
\(/bx\ lxl\)
珂朵莉树是一种“基于数据随机的颜色段均摊”,通过 set
维护区间。其复杂度依赖于 assign
操作和数据随机。
使用 set
实现的珂朵莉树时间复杂度为 \(\mathcal{O}(n \log \log n)\),使用链表实现的珂朵莉树时间复杂度为 \(\mathcal{O}(n \log n)\),详见 珂朵莉树的复杂度分析
思路
用 set
存储结构体代表区间。set
中所有元素对应的区间互不相交,取并得到的区间为 \([1, n]\)
若 \(l_1 < l_2\),则 \([l_1, r_1]\) 对应的结构体元素小于 \([l_2, r_2]\) 对应的结构体元素
建树
珂朵莉树无固定建树方式,通常将需要的区间插入 set
。
以 \(\tt CF896C\) 为例,只需 \(\forall 1 \leq i \leq n\),将区间 \([i, i]\) 插入 set
即可。
for (int i = 1; i <= n; i++) {
a[i] = (rnd() % vmax) + 1;
st.insert(node(i, i, a[i]));
}
操作
珂朵莉树的主要操作:assign
,split
。
split(pos)
操作将珂朵莉树中包含下标 \(pos\) 的区间 \([l, r]\) 分割成 \([l, pos)\) 和 \([pos, r]\) 两部分,并返回 set
中 \([pos, r]\) 对应元素的迭代器。具体实现可以在 set
中二分第一个左端点大于等于 \(pos\) 的区间 \([s, t]\),若 \(s = pos\) 则直接返回即可。反之在 set
中删去该区间与其前驱,表示删去包含 \(pos\) 的区间 \([l, r]\)。最后在 set
中插入区间 \([l, pos)\) 和 \([pos, r]\) 并返回迭代器。
set<node>::iterator split(int pos) {
set<node>::iterator it = st.lower_bound(node(pos));
if ((it != st.end()) && (it -> l == pos)) {
return it;
}
it--;
node temp = *it;
st.erase(it);
st.insert(node(temp.l, pos - 1, temp.val));
return st.insert(node(pos, temp.r, temp.val)).first;
}
assign
操作即为将区间 \([l, r]\) 赋值成 \(x\)。利用 split
操作,先从包含 \(r + 1\) 的区间中分裂出区间 \([r + 1, t_1]\),并令其迭代器为 itr
,再从包含 \(l\) 的区间中分裂出区间 \([l, t_2]\),令其迭代器为 itl
。之后在 set
中从 itl
删除到 itr
的前驱,表示将珂朵莉树中与 \([l, r]\) 存在交集的区间全部删除。最后插入新结构体表示区间 \([l, r]\),赋值为 \(x\) 即可。
注意此处分裂区间为 先右后左,反之可能会导致返回的迭代器被错误删除。例:在 \([1, n]\) 中删除 \([2, n - 1]\)
珂朵莉树的复杂度依赖于 assign
操作。
void assign(int l, int r, long long val) {
set<node>::iterator itr = split(r + 1), itl = split(l);
st.erase(itl, itr);
st.insert(node(l, r, val));
}
珂朵莉树的其他操作大致思路是先 split(r + 1)
,再 split(l)
,得到表示 \([l, r]\) 的所有区间,再暴力扫描这些区间进行维护。
例如区间加:
void update(int l, int r, long long x) {
set<node>::iterator itr = split(r + 1), itl = split(l);
for (set<node>::iterator it = itl; it != itr; it++) {
it -> val += x;
}
}
区间求第 \(k\) 小:
long long query_kth(int l, int r, int k) {
vector<pair<long long, int> > v; // v[i].first -> 值, v[i].second -> 值的个数
set<node>::iterator itr = split(r + 1), itl = split(l);
for (set<node>::iterator it = itl; it != itr; it++) {
v.push_back(make_pair(it -> val, it -> r - it -> l + 1));
}
sort(v.begin(), v.end()); // 从最小值开始暴力减去值的个数,得到第 k 小值
for (int i = 0; i < v.size(); i++) {
k -= v[i].second;
if (k <= 0) {
return v[i].first;
}
}
return -1;
}
代码
以 \(\tt CF896C\) 为例。
#include <cstdio>
#include <vector>
#include <set>
#include <iostream>
#include <algorithm>
using namespace std;
const int maxn = 1e5 + 5;
struct node {
int l, r;
mutable long long val;
node(int _l = -1, int _r = -1, long long _val = 0) {
l = _l, r = _r, val = _val;
}
bool operator < (const node& rhs) const {
return l < rhs.l;
}
};
int n, m;
long long a[maxn];
long long seed, vmax;
set<node> st;
long long rnd() {
long long ret = seed;
seed = (seed * 7 + 13) % 1000000007;
return ret;
}
long long fpow(long long base, long long power, long long mod) {
long long res = 1;
base %= mod;
while (power) {
if (power & 1) {
res = res * base % mod;
}
base = base * base % mod;
power >>= 1ll;
}
return res;
}
set<node>::iterator split(int pos) {
set<node>::iterator it = st.lower_bound(node(pos));
if ((it != st.end()) && (it -> l == pos)) {
return it;
}
it--;
node temp = *it;
st.erase(it);
st.insert(node(temp.l, pos - 1, temp.val));
return st.insert(node(pos, temp.r, temp.val)).first;
}
void assign(int l, int r, long long val) {
set<node>::iterator itr = split(r + 1), itl = split(l);
st.erase(itl, itr);
st.insert(node(l, r, val));
}
void update(int l, int r, long long x) {
set<node>::iterator itr = split(r + 1), itl = split(l);
for (set<node>::iterator it = itl; it != itr; it++) {
it -> val += x;
}
}
long long query_kth(int l, int r, int k) {
vector<pair<long long, int> > v;
set<node>::iterator itr = split(r + 1), itl = split(l);
for (set<node>::iterator it = itl; it != itr; it++) {
v.push_back(make_pair(it -> val, it -> r - it -> l + 1));
}
sort(v.begin(), v.end());
for (int i = 0; i < v.size(); i++) {
k -= v[i].second;
if (k <= 0) {
return v[i].first;
}
}
return -1;
}
long long query_psum(int l, int r, long long x, long long y) {
long long sum = 0;
set<node>::iterator itr = split(r + 1), itl = split(l);
for (set<node>::iterator it = itl; it != itr; it++) {
sum = (sum + (it -> r - it -> l + 1) * fpow(it -> val, x, y) % y) % y;
}
return sum;
}
int main() {
int opt, l, r;
long long x, y;
scanf("%d%d%lld%lld", &n, &m, &seed, &vmax);
for (int i = 1; i <= n; i++) {
a[i] = (rnd() % vmax) + 1;
st.insert(node(i, i, a[i]));
}
while (m--) {
opt = (rnd() % 4) + 1;
l = (rnd() % n) + 1;
r = (rnd() % n) + 1;
if (l > r) {
swap(l, r);
}
if (opt == 3) {
x = (rnd() % (r - l + 1)) + 1;
} else {
x = (rnd() % vmax) + 1;
}
if (opt == 4) {
y = (rnd() % vmax) + 1;
}
if (opt == 1) {
update(l, r, x);
} else if (opt == 2) {
assign(l, r, x);
} else if (opt == 3) {
printf("%lld\n", query_kth(l, r, x));
} else {
printf("%lld\n", query_psum(l, r, x, y));
}
}
return 0;
}