树状数组入门
简介
在许多情况下,我们都要维护一个数组的前缀和 \(s[i]=a[1]+a[2]+\cdots+a[i]\) ,但如果我们修改了其中一个 \(a[i]\) ,那么 \(s[i],s[i+1],\cdots,s[n]\) 都会发生变化。所以每次修改后我们都要对前缀和数组进行维护,在最坏的情况下复杂度为 \(O(n)\) 。而树状数组能很好的解决这一问题,它的修改与求和操作复杂度都是 \(O(\log n)\)
原理
我们设原数组为 \(a\) ,树状数组为 \(c\) ,对于一个长度为 \(8\) 的原数组,下面这张图(from oi-wiki)展示了树状数组的原理:
- \(c[1]=a[1]\)
- \(c[2]=a[1]+a[2]\)
- \(c[3]=a[3]\)
- \(c[4]=a[1]+a[2]+a[3]+a[4]\)
- \(c[5]=a[5]\)
- \(c[6]=a[5]+a[6]\)
- \(c[7]=a[7]\)
- \(c[8]=a[1]+a[2]+\cdots +a[8]\)
我们用 \(f(i)\) 表示 \(c[i]\) 存储了多少个连续的数组 \(a\) 中的元素,即:
为了求出 \(f(i)\) ,我们列出一张表并观察规律:
\(i\) 的十进制 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
---|---|---|---|---|---|---|---|---|
\(i\) 的二进制 | 0001 | 0010 | 0011 | 0100 | 0101 | 0110 | 0111 | 1000 |
\(f(i)\) | 0001 | 0010 | 0001 | 0100 | 0001 | 0010 | 0001 | 1000 |
通过观察,可以得出 \(f(i)\) 的值就是 \(i\) 的二进制表示中从最低位到高位出现的第一个 \(1\) 和它之前所有的 \(0\) 组成的二进制数的值,换一种说法,如果 \(i\) 的二进制表示中从最低位到高位有 \(k\) 个连续的 \(0\) ,第 \(k+1\) 位是 \(1\),那么 \(f(i)=2^k\) ,在树状数组中, \(f\) 有一个专门的名称 —— \(lowbit\)
那么如何求出 \(lowbit(i)\) 呢?设 \(i\) 的二进制表示从最低位到高位有 \(k\) 个连续的 \(0\) ,第 \(k+1\) 位是 \(1\) 。对 \(i\) 取反,这些 \(0\) 就都变成了 \(1\) ,第 \(k+1\) 位就变成了 \(0\) ,再加 \(1\) ,可以发现前 \(k+1\) 位还原到了最初的值,而更高位的二进制位则是之前的反码,此时对 \(i\) 和取反后加 \(1\) 的 \(i\) 进行与运算,就能得到答案了。 根据补码的原理, \(\sim i+1\) 等价于 \(-i\) ,所以 \(lowbit(i)=i\&(-i)\)
int lowbit(int x)
{
return x & (-x);
}
接下来,我们考虑,如果 \(a[i]\) 的值改变了,如何更新 \(c\) ,我们可以发现, 当 \(i\) 的二进制表示从最低位到高位有 \(k\) 个连续的 \(0\) ,第 \(k+1\) 位是 \(1\) 时, \(i+lowbit(i)\) 会使第 \(k+1\) 位变成 \(0\) 且前 \(k\) 位依然为 \(0\) ,所以 \(lowbit(i+lowbit(i))\) 一定大于 \(lowbit(i)\) ,那么 \(c[i+lowbit(i)]\) 一定包含了 \(a[i]\) 的值,需要更新
void update(int x, int k)
{
while(x <= n) {
c[x] += k;
x += lowbit(x);
}
}
最后考虑如何求和,根据 \(c\) 的定义可知, \(sum[i]=c[i]+c[i-lowbit[i]]+\cdots+c[1]\) ,求和的方法就很显然了
long long get_sum(int x)
{
long long res = 0;
while(x >= 1) {
res += c[x];
x -= lowbit(x);
}
return res;
}
一个最基本的树状数组模板就实现了
变式
上述树状数组模板是最基本的前缀和模板,支持单点更新(更新原数组某个元素的值),区间查询(查询原数组某个区间的和
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int MAX_N = 1000000 + 5;
int n, q;
ll c[MAX_N];
int lowbit(int x)
{
return x & (-x);
}
void update(int x, int k)
{
while(x <= n) {
c[x] += k;
x += lowbit(x);
}
}
ll get_sum(int x)
{
ll res = 0;
while(x >= 1) {
res += c[x];
x -= lowbit(x);
}
return res;
}
int main()
{
scanf("%d%d", &n, &q);
for(int i = 1; i <= n; i++) {
int a;
scanf("%d", &a);
update(i, a);
}
for(int i = 1; i <= q; i++) {
int opt, x, y;
scanf("%d%d%d", &opt, &x, &y);
if(opt == 1)
update(x, y);
else
printf("%lld\n", get_sum(y) - get_sum(x - 1));
}
return 0;
}
在此基础上稍加变通,我们就可以让树状数组实现区间更新(对原数组某个区间的所有元素都加上一个值)\单点查询(查询原数组某元素的值)和区间更新\区间查询
区间更新\单点查询
考虑如何将求前缀和的问题转化为求某个元素的值的问题,可以发现,差分数组的前缀和就是某个元素的值,且在进行区间更新时,只需要修改区间两端的差分数组的值,所以我们用树状数组维护原数组的差分数组就能实现区间更新\单点查询的功能了
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int MAX_N = 1000000 + 5;
int a[MAX_N];
ll c[MAX_N];
int n, q;
int lowbit(int x)
{
return x & (-x);
}
void update(int x, int k)
{
while(x <= n) {
c[x] += k;
x += lowbit(x);
}
}
ll get_sum(int x)
{
ll res = 0;
while(x >= 1) {
res += c[x];
x -= lowbit(x);
}
return res;
}
int main()
{
scanf("%d%d", &n, &q);
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
update(i, a[i] - a[i - 1]);
}
for(int i = 1; i <= q; i++) {
int opt, x, y, z;
scanf("%d", &opt);
if(opt == 1) {
scanf("%d%d%d", &x, &y, &z);
update(x, z);
update(y + 1, -z);
}else {
scanf("%d", &x);
printf("%lld\n", get_sum(x));
}
}
return 0;
}
区间更新\区间查询
我们仍然采用差分的思路:
\(sum[i]=a[i]+a[i-1]+\cdots+a[1]=(d[1]+d[2]+\cdots+d[i])+(d[1]+\cdots+d[i-1])+\cdots+d[1]\)
进一步分解:
\(sum[i]=i(d[1]+d[2]+\cdots+d[i])-(1-1)d[1]-(2-1)d[2]-\cdots-(i-1)d[i]\)
所以只要用两个树状数组,一个维护 \(d[i]\) ,另一个维护 \((i-1)d[i]\) 即可
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int MAX_N = 1000000 + 5;
int a[MAX_N];
ll sum1[MAX_N], sum2[MAX_N];
int n, q;
int lowbit(int x)
{
return x & (-x);
}
void update(int x, int k)
{
int t = x;
while(t <= n) {
sum1[t] += k;
sum2[t] += 1ll * (x - 1) * k;
t += lowbit(t);
}
}
ll get_sum(int x)
{
ll res = 0;
int t = x;
while(t >= 1) {
res = res + x * sum1[t] - sum2[t];
t -= lowbit(t);
}
return res;
}
int main()
{
scanf("%d%d", &n, &q);
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
update(i, a[i] - a[i - 1]);
}
for(int i = 1; i <= q; i++) {
int opt, l, r, x;
scanf("%d", &opt);
if(opt == 1) {
scanf("%d%d%d", &l, &r, &x);
update(l, x);
update(r + 1, -x);
}else {
scanf("%d%d", &l, &r);
printf("%lld\n", get_sum(r) - get_sum(l - 1));
}
}
return 0;
}