Luogu P5494 Solution
闲话
在不看题解的情况下独立完成了整体架构,然而 split
部分出现问题,在看到题解后改为了正确的形式。
整体架构用时 20min,调试用时 30min。
题解
本题的 \(2,3,4\) 操作与普通动态开点线段树无异,故重点讲述 \(0,1\) 操作。
操作 \(0\)
区间解离,即线段树分裂的核心操作——分裂。
其思想非常简单,类似普通线段树的更新操作,区间分裂时,碰到完全处于目标范围内的节点直接交换即可。实际上本操作应被视为一个区间的交换而非分裂。
单次操作时间复杂度 \(O(\log n)\)。
void split(int &x, int &pre, int l, int r, int lb, int rb)
{
if (l >= lb and r <= rb)
{
swap(x, pre);
return;
}
if (!x)
x = ++idx;
int mid = (l + r) >> 1;
if (lb <= mid)
split(tr[x].ls, tr[pre].ls, l, mid, lb, rb);
if (rb > mid)
split(tr[x].rs, tr[pre].rs, mid + 1, r, lb, rb);
tr[x].siz = tr[tr[x].ls].siz + tr[tr[x].rs].siz;
tr[pre].siz = tr[tr[pre].ls].siz + tr[tr[pre].rs].siz;
}
操作 \(1\)
整体合并。实际上该操作可以在普通动态开点线段树上完成。不过在本题中,由于保证合并源数组在后续不会再被访问,故目标数组可以直接引用源数组。
均摊总时间复杂度 \(O(n\log n)\)。
void merge(int &x, int &pre)
{
if (!pre)
return;
if (!x)
{
x = pre;
return;
}
tr[x].siz += tr[pre].siz;
merge(tr[x].ls, tr[pre].ls);
merge(tr[x].rs, tr[pre].rs);
}
于是,这道题就做完了。
代码
#include <cctype>
#include <cstdio>
#include <utility>
using namespace std;
const int N = 2e5 + 10;
using ll = long long;
template <typename _Tp> inline void read(_Tp &x)
{
char ch;
while (ch = getchar(), !isdigit(ch))
;
x = (ch ^ 48);
while (ch = getchar(), isdigit(ch))
x = (x << 3) + (x << 1) + (ch ^ 48);
}
template <typename _Tp, typename... _Args> inline void read(_Tp &x, _Args &...args)
{
read(x);
read(args...);
}
template <typename _Tp> void print(_Tp x)
{
if (x < 0)
return putchar('-'), print(-x);
if (x > 9)
print(x / 10);
putchar((x % 10) ^ 48);
}
struct st
{
int ls, rs;
ll siz;
} tr[N << 5];
int n, m, idx, a[N], rt[N], tcnt = 1;
void build(int &x, int l, int r)
{
x = ++idx;
if (l == r)
{
tr[x].siz = a[l];
return;
}
int mid = (l + r) >> 1;
build(tr[x].ls, l, mid);
build(tr[x].rs, mid + 1, r);
tr[x].siz = tr[tr[x].ls].siz + tr[tr[x].rs].siz;
}
void split(int &x, int &pre, int l, int r, int lb, int rb)
{
if (l >= lb and r <= rb)
{
swap(x, pre);
return;
}
if (!x)
x = ++idx;
int mid = (l + r) >> 1;
if (lb <= mid)
split(tr[x].ls, tr[pre].ls, l, mid, lb, rb);
if (rb > mid)
split(tr[x].rs, tr[pre].rs, mid + 1, r, lb, rb);
tr[x].siz = tr[tr[x].ls].siz + tr[tr[x].rs].siz;
tr[pre].siz = tr[tr[pre].ls].siz + tr[tr[pre].rs].siz;
}
void merge(int &x, int &pre)
{
if (!pre)
return;
if (!x)
{
x = pre;
return;
}
tr[x].siz += tr[pre].siz;
merge(tr[x].ls, tr[pre].ls);
merge(tr[x].rs, tr[pre].rs);
}
void update(int &x, int l, int r, int tar, int v)
{
if (!x)
x = ++idx;
if (l == r)
{
tr[x].siz += v;
return;
}
int mid = (l + r) >> 1;
if (tar <= mid)
update(tr[x].ls, l, mid, tar, v);
else
update(tr[x].rs, mid + 1, r, tar, v);
tr[x].siz = tr[tr[x].ls].siz + tr[tr[x].rs].siz;
}
ll query(int &x, int l, int r, int lb, int rb)
{
if (!x)
return 0;
if (l >= lb and r <= rb)
return tr[x].siz;
int mid = (l + r) >> 1;
ll res = 0;
if (lb <= mid)
res += query(tr[x].ls, l, mid, lb, rb);
if (rb > mid)
res += query(tr[x].rs, mid + 1, r, lb, rb);
return res;
}
int kth(int &x, int l, int r, int v)
{
if (l == r)
return l;
int mid = (l + r) >> 1;
if (v > tr[tr[x].ls].siz)
return kth(tr[x].rs, mid + 1, r, v - tr[tr[x].ls].siz);
return kth(tr[x].ls, l, mid, v);
}
int main()
{
read(n, m);
for (int i = 1; i <= n; i++)
{
read(a[i]);
}
build(rt[tcnt], 1, n);
for (int i = 1, op, x, y, z; i <= m; i++)
{
read(op, x, y);
if (op == 1)
{
merge(rt[x], rt[y]);
continue;
}
if (op == 4)
{
print(tr[rt[x]].siz < y ? -1 : kth(rt[x], 1, n, y));
putchar('\n');
continue;
}
read(z);
if (!op)
{
split(rt[++tcnt], rt[x], 1, n, y, z);
continue;
}
if (op == 2)
{
update(rt[x], 1, n, z, y);
continue;
}
print(query(rt[x], 1, n, y, z));
putchar('\n');
}
}