树状数组(待补)(生硬 公式 用法 证明)

理解之后挺简单的。开始比较抽象,理解后比较简单,不如直接看代码。

树状数组是一个兼顾修改和查询的数据结构。一般可以支持,\(O(\log n)\)区间查询单点修改。如果原数组为差分数组,可以实现区间修改单点查询。

因为常数小,在某些情况下比较好用,但一般它能做到的,线段树都能做到。

中心思想

关于树状数组的一切因为涉及 lowbit 比较抽象。

核心就是 \(tr\) 数组,\(tr[i]\) 数组代表原数组 \(a\)\(i\) 为结尾长度为 \(\operatorname {lowbit}(i)\) 的下标区间和,也就是
\(a[i - \operatorname {lowbit}(i) + 1, i]\)

我们实质上是把这个数组拆分成了一个树,但计算原理运行上和树没什么关系。只是呈树形关系。关于原理见这篇文章 让你顿悟树状数组原理与由来 - 知乎 (zhihu.com)

查询

我们知道一个数有二进制的形式,如 \(11 = 1011\)。如果我们求原数组下标 \([1, 11]\) 的和,可以把这个区间拆分,把每个拆分区间提前算出来,相当于预处理的方式,减少时间复杂度。

树状数组就是按照 tr 数组的形式来拆分了区间,像 \([1, 11]\) 就可以用(二进制下)\(tr[1011],tr[1010],tr[1000]\) 这三个给凑出来,比起原来相加了 11 次,这只加了 3 次,大大优化了时间。

对于一个区间可以不断先加上 \(tr[\operatorname {lowbit}(i)]\),再把 i 减去 lowbit(i),这样凑出。如 11 = 1011,有以下关系:

\[\begin{aligned} 1011 &=[1, 1011] \\ tr[1011] &= [1011,1011] \\ tr[1010] &= [1001, 1010] \\ tr[1000] &= [1, 1000] \\ \end{aligned}\]

\(tr[1011],tr[1010],tr[1000]\) 恰好为 \(tr[11],tr[11- \operatorname {lowbit}(11)] =tr[10],tr[10 -\operatorname {lowbit}(10)] = tr[8]\) 恰好符合上面的运行过程。

如果想求任意区间,方法和前缀和求区间方式一样。

int sum(int x) // 求出为第1 - x的和 sum/query
{
    int res = 0;
    for (int i = x; i; i -= lowbit(i)) res += tr[i];
    return res;
}

关系

\(tr\) 数组之间有关系,即一个数不断相加当前的 lowbit 就可以找到所有包含它的 tr 数组,即它可以影响的数组,即父节点;同理通过不断减当前的 lowbit 就可以所有找到影响它的 tr 数组,即子节点。这个关系也就是称为树状数组的原因。

证明待补。自己想想也简单。

修改

如果第 i 个数加上了 k,那么所有包含 i 的 tr 都应该加上 k,可写出。

void add(int x, int k) // 让第i个数加上k
{
    for (int i = x; i <= n; i += lowbit(i)) tr[i] += k;
}

求出 tr 数组

我们求出的方式以 \(tr[10100]\) 举例

tr[10100] = [10001,10100]

10011 = 10100 - 1
10010 = 10011 - lowbit(10011) = 10011 - 1
10010 - lowbit(10010) = 10010 - 10 = 10000 不属于 [10001,10100]

tr[10011] = [10011,10011]
tr[10010] = [10001,10010]

tr[10100] = 
a[10100] + tr[10011] + tr[10010]

可以发现就是在不超过 \([10100 - lowbit(10100) + 1, 10100]\) 范围内,原数组的本身加上它本身 - 1 不断减 lowbit 的 tr 之和。
代码如此

for (int i = 1; i <= n; i ++ ) tr[i] = a[i]; // 先赋值a[i]

for (int x = 1; x <= n; x ++ )
    for (int i = x - 1; i >= x - lowbit(x) + 1; i -= lowbit(i)) tr[x] += tr[i];

代码

初始化

// 初始化

int a[], tr[], sum[]; // a[]原数组 tr[]树状数组 sum[]前缀和

int lowbit(int x)
{
    return x & -x;
}

void add(int x, int k)
{
    for (int i = x; i <= n; i + lowbit(i)) tr[i] += k;
}

// (1) O(nlogn) 常用 使用修改的方式

for (int i = 1; i <= n; i ++ ) add(i, a[i]); 

// (2) O(n)  根据边

for (int i = 1; i <= n; i ++ ) tr[i] = a[i]; // 先赋值a[i]

for (int x = 1; x <= n; x ++ )
    for (int i = x - 1; i >= x - lowbit(x) + 1; i -= lowbit(i)) tr[x] += tr[i];

// (3) O(n) 根据原理

for (int i = 1; i <= n; i ++ ) sum[i] = sum[i - 1] + a[i]; 

for (int i = 1; i <= n; i ++ ) tr[i] = sum[i] - sum[i - lowbit(i)];

修改/查询

void add(int x, int k) // 让第x个数加上k
{
    for (int i = x; i <= n; i += lowbit(i)) tr[i] += k;
}

int sum(int x) // 求出为第1 - x的和
{
    int res = 0;
    for (int i = x; i; i -= lowbit(i)) res += tr[i];
    return res;
}

应用

特殊的应用,比如支持区间修改 + 区间查询。这个的常数是对应线段树的四分之一。

![[Pasted image 20241125081846.png|303]]

差分 + 树状数组可以把区间修改优化到O(logn)
区间和 \(= a[i] + a[i + 1] + … + a[x] = b[1 … i] + b[1 … i + 1] + … b[1 … x]\) 最后可以把区间求和优化到 \(O(logn)\)

如这题 P11217 【MX-S4-T1】「yyOI R2」youyou 的垃圾桶 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 如果使用线段树的话 \(O(q {\log ^2 n})\) 一定过不了,而用树状数组就可以。线段树需要 \(O(q\log n)\) 并且常数要小。

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

typedef long long LL ;

const int N = 100010;

int n, m;
int a[N];
LL tr1[N]; // b[i] 的前缀和
LL tr2[N]; // b[i] * i 的前缀和

int lowbit(int x)
{
    return x & -x;
}

void add(LL tr[], int x, LL k)
{
    for (int i = x; i <= n; i += lowbit(i)) tr[i] += k;
}

LL sum(LL tr[], int x)
{
    LL res = 0;
    for (int i = x; i; i -= lowbit(i)) res += tr[i];
    return res;
}

LL p_sum(int x)
{
    return sum(tr1, x) * (x + 1) - sum(tr2, x); 
}

int main()
{
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i ++ ) scanf("%d", &a[i]);

    for (int i = 1; i <= n; i ++ )
    {
        int b = a[i] - a[i - 1];
        add(tr1, i, b);
        add(tr2, i, (LL)b * i);
    }

    while (m -- )
    {
        int l, r, d;
        char op[2];
        scanf("%s%d%d", op, &l, &r);
        if (op[0] == 'C')
        {
            scanf("%d", &d);
            add(tr1, l, d), add(tr2, l, l * d);
            add(tr1, r + 1, -d), add(tr2, r + 1, (r + 1) * -d); 
        }
        else printf("%lld\n", p_sum(r) - p_sum(l - 1));   
    }


    return 0;
}
posted @ 2024-11-25 08:20  blind5883  阅读(5)  评论(0编辑  收藏  举报