【luogu P5494】【模板】线段树分裂(线段树合并)(线段树分裂)
【模板】线段树分裂
题目链接:luogu P5494
题目大意
一开始给你第一个可重集,然后它在 1~n 的范围内,然后初始每个数有一定的个数。
然后给你一些操作:
给某个可重集放几个某个数字;询问某个可重集有多少个数在 l~r 之间;询问某个可重集第 k 小的数(如果没有就是 -1);将某个可重集的数全部放入另一个可重集中,并清空它(保证之后都不会用到它);将一个可重集的 l~r 之间的数字全部放入一个新的可重集中,这个可重集的编号是当前可重集的个数。
请维护。
思路
很明显,我们只需要知道如何线段树合并和线段树分裂。
首先是线段树合并,它其实就是你只要把两个线段树同时从根节点开始跑,然后如果两个的某个位置都有值就合并,然后继续分别跑儿子。
有其中一个没有值或两个都没有就返回剩下有的那个或者空。
接着是线段树分裂,线段树分裂我们其实可以参考无旋 Treap 的分裂方式。
其实就是看分割的位置。
如果当前用作分割的第 k 大在左边(根据左儿子大小判断),那说明右边都是给第二个线段树的,然后左边继续分割。
如果在右边,说明左边都是给第一个线段树的,然后右边继续分割。
然后就可以啦。
代码
#include<cstdio>
#include<algorithm>
#define ll long long
using namespace std;
struct node {
int l, r;
ll x;
}t[200001 << 6];
int n, m, rt[200001];
int bin[200001 << 6], tot, tmp;
int op, x, y, z;
int get_new() {
if (!bin[0]) return ++tot;
return bin[bin[0]--];
}
void clear(int now) {//回收空间
bin[++bin[0]] = now;
t[now].l = t[now].r = 0;
t[now].x = 0;
}
void add(int &now, int l, int r, int pl, ll val) {
if (!now) now = get_new();
t[now].x += val;
if (l == r) return ;
int mid = (l + r) >> 1;
if (pl <= mid) add(t[now].l, l, mid, pl, val);
else add(t[now].r, mid + 1, r, pl, val);
}
int merge(int x, int y) {
if (!x || !y) return x + y;
t[x].x += t[y].x;
t[x].l = merge(t[x].l, t[y].l);
t[x].r = merge(t[x].r, t[y].r);
clear(y);
return x;
}
void split(int x, int &y, ll k) {
if (!x) return ;
y = get_new();
ll lnum = t[t[x].l].x;//记得开 long long
if (k <= lnum) {
swap(t[x].r, t[y].r);//记得右边整块都要挪过去
if (k < lnum) split(t[x].l, t[y].l, k);
}
else split(t[x].r, t[y].r, k - lnum);
t[y].x = t[x].x - k;
t[x].x = k;
}
ll query(int now, int l, int r, int L, int R) {
if (L <= l && r <= R) return t[now].x;
int mid = (l + r) >> 1;
ll re = 0;//记得开 long long
if (L <= mid) re += query(t[now].l, l, mid, L, R);
if (mid < R) re += query(t[now].r, mid + 1, r, L, R);
return re;
}
int find_k(int now, int l, int r, int k) {
if (l == r) return l;
int mid = (l + r) >> 1;
if (t[t[now].l].x >= k) return find_k(t[now].l, l, mid, k);
else return find_k(t[now].r, mid + 1, r, k - t[t[now].l].x);
}
int main() {
// freopen("read.txt", "r", stdin);
rt[0] = 1;
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; i++) {
scanf("%d", &x);
add(rt[1], 1, n, i, x);
}
while (m--) {
scanf("%d", &op);
if (op == 0) {
scanf("%d %d %d", &x, &y, &z);
ll num1 = query(rt[x], 1, n, 1, z);
ll num2 = query(rt[x], 1, n, y, z);
split(rt[x], rt[++rt[0]], num1 - num2);
split(rt[rt[0]], tmp, num2);
rt[x] = merge(rt[x], tmp);
}
if (op == 1) {
scanf("%d %d", &x, &y);
rt[x] = merge(rt[x], rt[y]);
}
if (op == 2) {
scanf("%d %d %d", &x, &y, &z);
add(rt[x], 1, n, z, y);
}
if (op == 3) {
scanf("%d %d %d", &x, &y, &z);
printf("%lld\n", query(rt[x], 1, n, y, z));
}
if (op == 4) {
scanf("%d %d", &x, &y);
if (t[rt[x]].x < y) printf("-1\n");
else printf("%d\n", find_k(rt[x], 1, n, y));
}
}
return 0;
}