树状数组(待补)(生硬 公式 用法 证明)
理解之后挺简单的。开始比较抽象,理解后比较简单,不如直接看代码。
树状数组是一个兼顾修改和查询的数据结构。一般可以支持,\(O(\log n)\) 的区间查询和单点修改。如果原数组为差分数组,可以实现区间修改单点查询。
因为常数小,在某些情况下比较好用,但一般它能做到的,线段树都能做到。
中心思想
关于树状数组的一切因为涉及 lowbit 比较抽象。
核心就是 \(tr\) 数组,\(tr[i]\) 数组代表原数组 \(a\) 以 \(i\) 为结尾长度为 \(\operatorname {lowbit}(i)\) 的下标区间和,也就是
\(a[i - \operatorname {lowbit}(i) + 1, i]\)。
我们实质上是把这个数组拆分成了一个树,但计算原理运行上和树没什么关系。只是呈树形关系。关于原理见这篇文章 让你顿悟树状数组原理与由来 - 知乎 (zhihu.com)
查询
我们知道一个数有二进制的形式,如 \(11 = 1011\)。如果我们求原数组下标 \([1, 11]\) 的和,可以把这个区间拆分,把每个拆分区间提前算出来,相当于预处理的方式,减少时间复杂度。
树状数组就是按照 tr 数组的形式来拆分了区间,像 \([1, 11]\) 就可以用(二进制下)\(tr[1011],tr[1010],tr[1000]\) 这三个给凑出来,比起原来相加了 11 次,这只加了 3 次,大大优化了时间。
对于一个区间可以不断先加上 \(tr[\operatorname {lowbit}(i)]\),再把 i 减去 lowbit(i),这样凑出。如 11 = 1011,有以下关系:
而 \(tr[1011],tr[1010],tr[1000]\) 恰好为 \(tr[11],tr[11- \operatorname {lowbit}(11)] =tr[10],tr[10 -\operatorname {lowbit}(10)] = tr[8]\) 恰好符合上面的运行过程。
如果想求任意区间,方法和前缀和求区间方式一样。
int sum(int x) // 求出为第1 - x的和 sum/query
{
int res = 0;
for (int i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}
关系
\(tr\) 数组之间有关系,即一个数不断相加当前的 lowbit 就可以找到所有包含它的 tr 数组,即它可以影响的数组,即父节点;同理通过不断减当前的 lowbit 就可以所有找到影响它的 tr 数组,即子节点。这个关系也就是称为树状数组的原因。
证明待补。自己想想也简单。
修改
如果第 i 个数加上了 k,那么所有包含 i 的 tr 都应该加上 k,可写出。
void add(int x, int k) // 让第i个数加上k
{
for (int i = x; i <= n; i += lowbit(i)) tr[i] += k;
}
求出 tr 数组
我们求出的方式以 \(tr[10100]\) 举例
tr[10100] = [10001,10100]
10011 = 10100 - 1
10010 = 10011 - lowbit(10011) = 10011 - 1
10010 - lowbit(10010) = 10010 - 10 = 10000 不属于 [10001,10100]
tr[10011] = [10011,10011]
tr[10010] = [10001,10010]
tr[10100] =
a[10100] + tr[10011] + tr[10010]
可以发现就是在不超过 \([10100 - lowbit(10100) + 1, 10100]\) 范围内,原数组的本身加上它本身 - 1 不断减 lowbit 的 tr 之和。
代码如此
for (int i = 1; i <= n; i ++ ) tr[i] = a[i]; // 先赋值a[i]
for (int x = 1; x <= n; x ++ )
for (int i = x - 1; i >= x - lowbit(x) + 1; i -= lowbit(i)) tr[x] += tr[i];
代码
初始化
// 初始化
int a[], tr[], sum[]; // a[]原数组 tr[]树状数组 sum[]前缀和
int lowbit(int x)
{
return x & -x;
}
void add(int x, int k)
{
for (int i = x; i <= n; i + lowbit(i)) tr[i] += k;
}
// (1) O(nlogn) 常用 使用修改的方式
for (int i = 1; i <= n; i ++ ) add(i, a[i]);
// (2) O(n) 根据边
for (int i = 1; i <= n; i ++ ) tr[i] = a[i]; // 先赋值a[i]
for (int x = 1; x <= n; x ++ )
for (int i = x - 1; i >= x - lowbit(x) + 1; i -= lowbit(i)) tr[x] += tr[i];
// (3) O(n) 根据原理
for (int i = 1; i <= n; i ++ ) sum[i] = sum[i - 1] + a[i];
for (int i = 1; i <= n; i ++ ) tr[i] = sum[i] - sum[i - lowbit(i)];
修改/查询
void add(int x, int k) // 让第x个数加上k
{
for (int i = x; i <= n; i += lowbit(i)) tr[i] += k;
}
int sum(int x) // 求出为第1 - x的和
{
int res = 0;
for (int i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}
应用
特殊的应用,比如支持区间修改 + 区间查询。这个的常数是对应线段树的四分之一。
![[Pasted image 20241125081846.png|303]]
差分 + 树状数组可以把区间修改优化到O(logn)
区间和 \(= a[i] + a[i + 1] + … + a[x] = b[1 … i] + b[1 … i + 1] + … b[1 … x]\) 最后可以把区间求和优化到 \(O(logn)\)。
如这题 P11217 【MX-S4-T1】「yyOI R2」youyou 的垃圾桶 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 如果使用线段树的话 \(O(q {\log ^2 n})\) 一定过不了,而用树状数组就可以。线段树需要 \(O(q\log n)\) 并且常数要小。
代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL ;
const int N = 100010;
int n, m;
int a[N];
LL tr1[N]; // b[i] 的前缀和
LL tr2[N]; // b[i] * i 的前缀和
int lowbit(int x)
{
return x & -x;
}
void add(LL tr[], int x, LL k)
{
for (int i = x; i <= n; i += lowbit(i)) tr[i] += k;
}
LL sum(LL tr[], int x)
{
LL res = 0;
for (int i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}
LL p_sum(int x)
{
return sum(tr1, x) * (x + 1) - sum(tr2, x);
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i ++ ) scanf("%d", &a[i]);
for (int i = 1; i <= n; i ++ )
{
int b = a[i] - a[i - 1];
add(tr1, i, b);
add(tr2, i, (LL)b * i);
}
while (m -- )
{
int l, r, d;
char op[2];
scanf("%s%d%d", op, &l, &r);
if (op[0] == 'C')
{
scanf("%d", &d);
add(tr1, l, d), add(tr2, l, l * d);
add(tr1, r + 1, -d), add(tr2, r + 1, (r + 1) * -d);
}
else printf("%lld\n", p_sum(r) - p_sum(l - 1));
}
return 0;
}