线段树也能是 Trie 树 题解

题意简述

给定一个长为 \(n = 2^k\) 的序列 \(\{a_0, \ldots, a_{n - 1}\}\),你需要使用数据结构维护它,支持 \(m\) 次以下操作:

  1. 单点加:\(a_x \gets a_x + y\)
  2. 区间查:\(\sum \limits _ {i = l} ^ r a_i\)
  3. 全局下标与:\(a'_{i \operatorname{and} x} \gets a_{i}\),即把 \(a_i\) 累加到新的 \(a\) 的第 \(i \operatorname{and} x\) 位上;
  4. 全局下标或:\(a'_{i \operatorname{or} x} \gets a_{i}\)
  5. 全局下标异或:\(a'_{i \operatorname{xor} x} \gets a_{i}\)

\(k \leq 19\)\(m \leq 2^{19}\)

题目分析

数据结构题,发现 \(n = 2^k\),于是想到线段树,发现此时线段树也是一棵对于下标的 Trie 树,那么我们只需要思考对于下标的位运算怎么处理。

按位考虑,套路化地发现,\(\operatorname{xor}\) 的本质是将特定的某些位翻转,即如果 \(x\) 的第 \(i\) 位上为 \(1\),那么线段树上,从叶子往根数第 \(i\) 层(叶子为第 \(-1\) 层)的左右儿子交换;\(\operatorname{and}\) 的本质是将某些特定的位设为 \(0\),即若 \(x\)\(i\) 位为 \(0\),那么线段树上第 \(i\) 层的右儿子合并到左儿子上,合并为线段树合并;类似 \(\operatorname{or}\) 的本质是将某些特定的位设为 \(1\),即若 \(x\)\(i\) 位为 \(1\),那么线段树上第 \(i\) 层的左儿子合并到右儿子上,合并为线段树合并。

我们先来考虑线段树合并的正确性。我们知道,线段树合并的时间复杂度约等于两棵线段树重合的结点数,约等于较小的线段树结点数。由于合并后也会减小这么多点,所以整个过程下来,势能分析知,复杂度为整个过程线段树结点数。除了初始 \(n\) 和节点外,还有操作 \(1\) 带来的 \(\mathcal{O}(m \log n)\) 个结点,其他操作不会带来新的节点,所以总的时间复杂度为 \(\mathcal{O}(n + m \log n)\),是正确的。

但是显然不可以每次暴力对那么多结点依次进行合并操作、交换操作,所以需要打懒惰标记。标记在层与层之间是独立的。我们考虑某一层的操作时间轴,由三种操作构成。如果出现了一次合并操作,那么之前的全部操作都会无效。所以,总是可以将时间轴简化成等价的一次合并后跟着若干次交换子树。这样标记也就好打了。

时间复杂度:\(\mathcal{O}(n + m \log n)\)

代码

#include <cstdio>
#include <iostream>
using namespace std;

const int MAX = 1 << 27;
char buf[MAX], *p = buf, obuf[MAX], *op = obuf;
#ifdef XuYueming
# define fread(_, __, ___, ____)
#else
# define getchar() *p++
#endif
#define putchar(x) *op++ = x
#define isdigit(x) ('0' <= x && x <= '9')
#define __yzh__(x) for (; x isdigit(ch); ch = getchar())
template <typename T>
inline void read(T &x) {
    x = 0; char ch = getchar(); __yzh__(!);
    __yzh__( ) x = (x << 3) + (x << 1) + (ch ^ 48);
}
template <typename T>
inline void write(T x) {
    static short stack[20], top(0);
    do stack[++top] = x % 10; while (x /= 10);
    while (top) putchar(stack[top--] | 48);
    putchar('\n');
}

const int K = 20, N = 1 << K | 10;

using lint = long long;

int n, m, k;
int tim[K], ctim[K];  // 总时间,最后一次左子树合并到右子树的时间

namespace Segment_Tree {
    int root, tot;
    
    struct node {
        int ls, rs;
        int t;
        lint sum;
    } tree[N * 2 + N * K];
    
    #define Ls tree[idx].ls
    #define Rs tree[idx].rs
    
    int combine(int, int, int);
    
    inline void pushdown(int idx, int dpt) {
        if (tree[idx].t < ctim[dpt])
            Rs = combine(Ls, Rs, dpt - 1), Ls = 0, tree[idx].t = ctim[dpt];
        if ((tim[dpt] - tree[idx].t) & 1)
            swap(Ls, Rs);
        tree[idx].t = tim[dpt];
    }
    
    int combine(int u, int v, int dpt) {
        if (!u || !v) return u | v;
        tree[u].sum += tree[v].sum;
        if (!~dpt) return u;
        pushdown(u, dpt), pushdown(v, dpt);
        tree[u].ls = combine(tree[u].ls, tree[v].ls, dpt - 1);
        tree[u].rs = combine(tree[u].rs, tree[v].rs, dpt - 1);
        return u;
    }
    
    void modify(int &idx, int trl, int trr, int dpt, int p, int val) {
        if (p > trr || p < trl) return;
        if (!idx) idx = ++tot, ~dpt && (tree[idx].t = tim[dpt]);
        tree[idx].sum += val;
        if (!~dpt) return;
        int mid = (trl + trr) >> 1;
        pushdown(idx, dpt);
        modify(Ls, trl, mid, dpt - 1, p, val);
        modify(Rs, mid + 1, trr, dpt - 1, p, val);
    }
    
    lint query(int idx, int trl, int trr, int dpt, int l, int r) {
        if (trl > r || trr < l || !idx) return 0;
        if (l <= trl && trr <= r) return tree[idx].sum;
        pushdown(idx, dpt);
        int mid = (trl + trr) >> 1;
        return query(Ls, trl, mid, dpt - 1, l, r) + query(Rs, mid + 1, trr, dpt - 1, l, r);
    }
}

using Segment_Tree::root;
using Segment_Tree::modify;
using Segment_Tree::query;

signed main() {
    #ifndef XuYueming
    freopen("jeden.in", "r", stdin);
    freopen("jeden.out", "w", stdout);
    #endif
    fread(buf, 1, MAX, stdin);
    read(n), read(m), k = __lg(n);
    for (int i = 0, x; i < n; ++i) read(x), modify(root, 0, n - 1, k - 1, i, x);
    for (int x, y; m--; ) {
        char op; do op = getchar(); while (op < 'a' || op > 'z');
        op = getchar();
        if (op == 'd')
            read(x), read(y), modify(root, 0, n - 1, k - 1, x, y);
        else if (op == 'u')
            read(x), read(y), write(query(root, 0, n - 1, k - 1, x, y));
        else if (op == 'n') {
            read(x);
            for (int i = 0; i < k; ++i)
                if (!(x & 1 << i))
                    ctim[i] = ++tim[i], ++tim[i];
        } else if (op == 'r') {
            read(x);
            for (int i = 0; i < k; ++i)
                if (x & 1 << i)
                    ctim[i] = ++tim[i];
        } else {
            read(x);
            for (int i = 0; i < k; ++i)
                if (x & 1 << i)
                    ++tim[i];
        }
    }
    fwrite(obuf, 1, op - obuf, stdout);
    return 0;
}
posted @ 2024-11-02 17:05  XuYueming  阅读(7)  评论(0编辑  收藏  举报