AtCoder Arc030_4 グラフではない 题解 可持久化平衡树/可持久化FHQ-Treap
题目链接:https://atcoder.jp/contests/arc030/tasks/arc030_4
题目大意:vjudge链接
现有一个长度为 \(N\) 的数列 \(X={x_1,x_2,...,x_N}\)。需要对数列执行 \(Q\) 次查询操作。查询操作共有\(3\) 种类型,具体如下:
1 a b v
― 对数列 \(X\) 的区间 \([\ a,b\ ]\) 统一加上数值 \(v\)。2 a b c d
― 将数列 \(X\) 的区间 \([\ a,b\ ]\) 替换为当前查询时刻区间 \([\ c,d\ ]\) 的值。\(b-a=d-c\) 保证成立。更准确地说,设此查询得到的新数列为 \(X'\),则有 \(X'_{a}=X_c\),\(X'_{a+1}=X_{c+1}\),…,\(X'_{b}=X_{d}\)。\([\ a,b\ ]\) 范围外的 \(j\) 保持 \(X'_j=X_j\) 不变。3 a b
― 计算数列 \(X\) 区间 \([\ a,b\ ]\) 内所有数值的总和。
你需要编写程序按顺序处理这些查询操作。
解题思路:
可持久化平衡树(这里平衡树用的是 FHQ Treap)。
对于 1 操作:
将区间分裂成 \([1, a-1]\),\([a, b]\),\([b+1, n]\),然后再对 \([a, b]\) 加上 \(c\) 即可。
对于 2 操作:
要两次分裂,但是因为有历史版本,所以:
- 第一次分裂得到 \([c, d]\) 这部分;
- 第二次分裂得到 \([1, a-1]\),\([a, b]\),\([b+1, n]\)
再将 \([1, a-1]\),\([c, d]\),\([b+1, n]\) 三部分合并得到当前版本。
对于 3 操作:
将区间分裂成 \([1, a-1]\),\([a, b]\),\([b+1, n]\),求 \([a, b]\) 这一部分对应的区间和。
示例程序:
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5 + 5, maxm = 2e7 + 5;
mt19937 rng(time(0));
struct Node {
int ls, rs, sz, pri;
long long key, sum, lazy;
Node() {}
Node(long long _key) {
key = sum = _key;
ls = rs = 0;
lazy = 0;
sz = 1;
pri = rng();
}
} tr[maxm];
int rt, idx;
int n, q;
long long a[maxn];
int cpy(int u, int v) {
tr[u] = tr[v];
tr[u].pri = rng();
return u;
}
void push_up(int u) {
int ls = tr[u].ls, rs = tr[u].rs;
tr[u].sz = tr[ls].sz + tr[rs].sz + 1;
tr[u].sum = tr[ls].sum + tr[rs].sum + tr[u].key;
}
void add(int u, long long val) {
tr[u].key += val;
tr[u].sum += tr[u].sz * val;
tr[u].lazy += val;
}
void push_down(int u) {
if (tr[u].lazy) {
if (tr[u].ls) {
tr[u].ls = cpy(++idx, tr[u].ls);
}
if (tr[u].rs) {
tr[u].rs = cpy(++idx, tr[u].rs);
}
add(tr[u].ls, tr[u].lazy);
add(tr[u].rs, tr[u].lazy);
tr[u].lazy = 0;
}
}
void split(int u, int k, int &L, int &R) { // 按rank分割
if (!u) {
L = R = 0;
return;
}
u = cpy(++idx, u);
push_down(u);
int ls = tr[u].ls;
if (tr[ls].sz + 1 <= k) {
L = u;
split(tr[u].rs, k - tr[ls].sz - 1, tr[u].rs, R);
}
else {
R = u;
split(tr[u].ls, k, L, tr[u].ls);
}
push_up(u);
}
int merge(int L, int R) {
if (!L || !R) {
if (L + R == 0) return 0;
return cpy(++idx, L+R);
}
if (tr[L].pri > tr[R].pri) {
push_down(L);
tr[L].rs = merge(tr[L].rs, R);
push_up(L);
return L;
}
else {
push_down(R);
tr[R].ls = merge(L, tr[R].ls);
push_up(R);
return R;
}
}
int cc;
void dfs(int u) { // 重构
if (!u) return;
push_down(u);
dfs(tr[u].ls);
a[++cc] = tr[u].key;
dfs(tr[u].rs);
push_up(u);
}
int build(int l, int r) {
if (l > r) return 0;
int mid = (l + r) >> 1;
int u = ++idx;
tr[u] = Node(a[mid]);
tr[u].ls = build(l, mid-1);
tr[u].rs = build(mid+1, r);
push_up(u);
return u;
}
void init() {
idx = 0;
rt = build(1, n);
}
int main() {
scanf("%d%d", &n, &q);
for (int i = 1; i <= n; i++)
scanf("%lld", a+i);
init();
while (q--) {
int op, a, b;
scanf("%d%d%d", &op, &a, &b);
if (op == 1) {
int v, L, p, R;
scanf("%d", &v);
split(rt, a-1, L, R);
split(R, b-a+1, p, R);
p = cpy(++idx, p);
add(p, v);
rt = merge(merge(L, p), R);
}
else if (op == 2) {
int c, d, L, R, p1, p2;
scanf("%d%d", &c, &d);
split(rt, d, p1, R);
split(p1, c-1, L, p1); // p1: [c,d]
split(rt, b, p2, R); // R: [b+1, n]
split(rt, a-1, L, p2); // L: [1, a-1]
rt = merge(merge(L, p1), R);
}
else {
int L, p, R;
split(rt, b, p, R);
split(p, a-1, L, p);
printf("%lld\n", tr[p].sum);
rt = merge(merge(L, p), R);
}
if (idx > 1e7) {
cc = 0;
dfs(rt);
assert(cc == n);
init();
}
}
return 0;
}