线段树也能是 Trie 树 题解
题意简述
给定一个长为 \(n = 2^k\) 的序列 \(\{a_0, \ldots, a_{n - 1}\}\),你需要使用数据结构维护它,支持 \(m\) 次以下操作:
- 单点加:\(a_x \gets a_x + y\);
- 区间查:\(\sum \limits _ {i = l} ^ r a_i\);
- 全局下标与:\(a'_{i \operatorname{and} x} \gets a_{i}\),即把 \(a_i\) 累加到新的 \(a\) 的第 \(i \operatorname{and} x\) 位上;
- 全局下标或:\(a'_{i \operatorname{or} x} \gets a_{i}\);
- 全局下标异或:\(a'_{i \operatorname{xor} x} \gets a_{i}\);
\(k \leq 19\),\(m \leq 2^{19}\)。
题目分析
数据结构题,发现 \(n = 2^k\),于是想到线段树,发现此时线段树也是一棵对于下标的 Trie 树,那么我们只需要思考对于下标的位运算怎么处理。
按位考虑,套路化地发现,\(\operatorname{xor}\) 的本质是将特定的某些位翻转,即如果 \(x\) 的第 \(i\) 位上为 \(1\),那么线段树上,从叶子往根数第 \(i\) 层(叶子为第 \(-1\) 层)的左右儿子交换;\(\operatorname{and}\) 的本质是将某些特定的位设为 \(0\),即若 \(x\) 第 \(i\) 位为 \(0\),那么线段树上第 \(i\) 层的右儿子合并到左儿子上,合并为线段树合并;类似 \(\operatorname{or}\) 的本质是将某些特定的位设为 \(1\),即若 \(x\) 第 \(i\) 位为 \(1\),那么线段树上第 \(i\) 层的左儿子合并到右儿子上,合并为线段树合并。
我们先来考虑线段树合并的正确性。我们知道,线段树合并的时间复杂度约等于两棵线段树重合的结点数,约等于较小的线段树结点数。由于合并后也会减小这么多点,所以整个过程下来,势能分析知,复杂度为整个过程线段树结点数。除了初始 \(n\) 和节点外,还有操作 \(1\) 带来的 \(\mathcal{O}(m \log n)\) 个结点,其他操作不会带来新的节点,所以总的时间复杂度为 \(\mathcal{O}(n + m \log n)\),是正确的。
但是显然不可以每次暴力对那么多结点依次进行合并操作、交换操作,所以需要打懒惰标记。标记在层与层之间是独立的。我们考虑某一层的操作时间轴,由三种操作构成。如果出现了一次合并操作,那么之前的全部操作都会无效。所以,总是可以将时间轴简化成等价的一次合并后跟着若干次交换子树。这样标记也就好打了。
时间复杂度:\(\mathcal{O}(n + m \log n)\)。
代码
#include <cstdio>
#include <iostream>
using namespace std;
const int MAX = 1 << 27;
char buf[MAX], *p = buf, obuf[MAX], *op = obuf;
#ifdef XuYueming
# define fread(_, __, ___, ____)
#else
# define getchar() *p++
#endif
#define putchar(x) *op++ = x
#define isdigit(x) ('0' <= x && x <= '9')
#define __yzh__(x) for (; x isdigit(ch); ch = getchar())
template <typename T>
inline void read(T &x) {
x = 0; char ch = getchar(); __yzh__(!);
__yzh__( ) x = (x << 3) + (x << 1) + (ch ^ 48);
}
template <typename T>
inline void write(T x) {
static short stack[20], top(0);
do stack[++top] = x % 10; while (x /= 10);
while (top) putchar(stack[top--] | 48);
putchar('\n');
}
const int K = 20, N = 1 << K | 10;
using lint = long long;
int n, m, k;
int tim[K], ctim[K]; // 总时间,最后一次左子树合并到右子树的时间
namespace Segment_Tree {
int root, tot;
struct node {
int ls, rs;
int t;
lint sum;
} tree[N * 2 + N * K];
#define Ls tree[idx].ls
#define Rs tree[idx].rs
int combine(int, int, int);
inline void pushdown(int idx, int dpt) {
if (tree[idx].t < ctim[dpt])
Rs = combine(Ls, Rs, dpt - 1), Ls = 0, tree[idx].t = ctim[dpt];
if ((tim[dpt] - tree[idx].t) & 1)
swap(Ls, Rs);
tree[idx].t = tim[dpt];
}
int combine(int u, int v, int dpt) {
if (!u || !v) return u | v;
tree[u].sum += tree[v].sum;
if (!~dpt) return u;
pushdown(u, dpt), pushdown(v, dpt);
tree[u].ls = combine(tree[u].ls, tree[v].ls, dpt - 1);
tree[u].rs = combine(tree[u].rs, tree[v].rs, dpt - 1);
return u;
}
void modify(int &idx, int trl, int trr, int dpt, int p, int val) {
if (p > trr || p < trl) return;
if (!idx) idx = ++tot, ~dpt && (tree[idx].t = tim[dpt]);
tree[idx].sum += val;
if (!~dpt) return;
int mid = (trl + trr) >> 1;
pushdown(idx, dpt);
modify(Ls, trl, mid, dpt - 1, p, val);
modify(Rs, mid + 1, trr, dpt - 1, p, val);
}
lint query(int idx, int trl, int trr, int dpt, int l, int r) {
if (trl > r || trr < l || !idx) return 0;
if (l <= trl && trr <= r) return tree[idx].sum;
pushdown(idx, dpt);
int mid = (trl + trr) >> 1;
return query(Ls, trl, mid, dpt - 1, l, r) + query(Rs, mid + 1, trr, dpt - 1, l, r);
}
}
using Segment_Tree::root;
using Segment_Tree::modify;
using Segment_Tree::query;
signed main() {
#ifndef XuYueming
freopen("jeden.in", "r", stdin);
freopen("jeden.out", "w", stdout);
#endif
fread(buf, 1, MAX, stdin);
read(n), read(m), k = __lg(n);
for (int i = 0, x; i < n; ++i) read(x), modify(root, 0, n - 1, k - 1, i, x);
for (int x, y; m--; ) {
char op; do op = getchar(); while (op < 'a' || op > 'z');
op = getchar();
if (op == 'd')
read(x), read(y), modify(root, 0, n - 1, k - 1, x, y);
else if (op == 'u')
read(x), read(y), write(query(root, 0, n - 1, k - 1, x, y));
else if (op == 'n') {
read(x);
for (int i = 0; i < k; ++i)
if (!(x & 1 << i))
ctim[i] = ++tim[i], ++tim[i];
} else if (op == 'r') {
read(x);
for (int i = 0; i < k; ++i)
if (x & 1 << i)
ctim[i] = ++tim[i];
} else {
read(x);
for (int i = 0; i < k; ++i)
if (x & 1 << i)
++tim[i];
}
}
fwrite(obuf, 1, op - obuf, stdout);
return 0;
}
本文作者:XuYueming,转载请注明原文链接:https://www.cnblogs.com/XuYueming/p/18521951。
若未作特殊说明,本作品采用 知识共享署名-非商业性使用 4.0 国际许可协议 进行许可。