树状数组入门

简介

在许多情况下,我们都要维护一个数组的前缀和 \(s[i]=a[1]+a[2]+\cdots+a[i]\) ,但如果我们修改了其中一个 \(a[i]\) ,那么 \(s[i],s[i+1],\cdots,s[n]\) 都会发生变化。所以每次修改后我们都要对前缀和数组进行维护,在最坏的情况下复杂度为 \(O(n)\) 。而树状数组能很好的解决这一问题,它的修改与求和操作复杂度都是 \(O(\log n)\)

原理

我们设原数组为 \(a\) ,树状数组为 \(c\) ,对于一个长度为 \(8\) 的原数组,下面这张图(from oi-wiki)展示了树状数组的原理:

  • \(c[1]=a[1]\)
  • \(c[2]=a[1]+a[2]\)
  • \(c[3]=a[3]\)
  • \(c[4]=a[1]+a[2]+a[3]+a[4]\)
  • \(c[5]=a[5]\)
  • \(c[6]=a[5]+a[6]\)
  • \(c[7]=a[7]\)
  • \(c[8]=a[1]+a[2]+\cdots +a[8]\)

我们用 \(f(i)\) 表示 \(c[i]\) 存储了多少个连续的数组 \(a\) 中的元素,即:

\[c[i]=a[i]+a[i-1]+\cdots+a[i-f(i)+1] \]

为了求出 \(f(i)\) ,我们列出一张表并观察规律:

\(i\) 的十进制 1 2 3 4 5 6 7 8
\(i\) 的二进制 0001 0010 0011 0100 0101 0110 0111 1000
\(f(i)\) 0001 0010 0001 0100 0001 0010 0001 1000

通过观察,可以得出 \(f(i)\) 的值就是 \(i\) 的二进制表示中从最低位到高位出现的第一个 \(1\) 和它之前所有的 \(0\) 组成的二进制数的值,换一种说法,如果 \(i\) 的二进制表示中从最低位到高位有 \(k\) 个连续的 \(0\) ,第 \(k+1\) 位是 \(1\),那么 \(f(i)=2^k\) ,在树状数组中, \(f\) 有一个专门的名称 —— \(lowbit\)

那么如何求出 \(lowbit(i)\) 呢?设 \(i\) 的二进制表示从最低位到高位有 \(k\) 个连续的 \(0\) ,第 \(k+1\) 位是 \(1\) 。对 \(i\) 取反,这些 \(0\) 就都变成了 \(1\) ,第 \(k+1\) 位就变成了 \(0\) ,再加 \(1\) ,可以发现前 \(k+1\) 位还原到了最初的值,而更高位的二进制位则是之前的反码,此时对 \(i\) 和取反后加 \(1\)\(i\) 进行与运算,就能得到答案了。 根据补码的原理, \(\sim i+1\) 等价于 \(-i\) ,所以 \(lowbit(i)=i\&(-i)\)

int lowbit(int x)
{
    return x & (-x);
}

接下来,我们考虑,如果 \(a[i]\) 的值改变了,如何更新 \(c\) ,我们可以发现, 当 \(i\) 的二进制表示从最低位到高位有 \(k\) 个连续的 \(0\) ,第 \(k+1\) 位是 \(1\) 时, \(i+lowbit(i)\) 会使第 \(k+1\) 位变成 \(0\) 且前 \(k\) 位依然为 \(0\) ,所以 \(lowbit(i+lowbit(i))\) 一定大于 \(lowbit(i)\) ,那么 \(c[i+lowbit(i)]\) 一定包含了 \(a[i]\) 的值,需要更新

void update(int x, int k)
{
    while(x <= n) {	
        c[x] += k;
        x += lowbit(x);
    }
}

最后考虑如何求和,根据 \(c\) 的定义可知, \(sum[i]=c[i]+c[i-lowbit[i]]+\cdots+c[1]\) ,求和的方法就很显然了

long long get_sum(int x)
{
    long long res = 0;
    while(x >= 1) {
        res += c[x];
	x -= lowbit(x);
    }
    return res;
}

一个最基本的树状数组模板就实现了

变式

上述树状数组模板是最基本的前缀和模板,支持单点更新(更新原数组某个元素的值),区间查询(查询原数组某个区间的和

LOJ 树状数组模板1

#include<bits/stdc++.h>
#define ll long long
using namespace std;

const int MAX_N = 1000000 + 5;
int n, q;
ll c[MAX_N];

int lowbit(int x)
{
    return x & (-x);
}

void update(int x, int k)
{
    while(x <= n) {
        c[x] += k;
        x += lowbit(x);
    }
}

ll get_sum(int x)
{
    ll res = 0;
    while(x >= 1) {
        res += c[x];
        x -= lowbit(x);
    }
    return res;
}

int main()
{
    scanf("%d%d", &n, &q);
    for(int i = 1; i <= n; i++) {
        int a;
        scanf("%d", &a);
        update(i, a);
    }
    for(int i = 1; i <= q; i++) {
        int opt, x, y;
        scanf("%d%d%d", &opt, &x, &y);
        if(opt == 1)
            update(x, y);
        else
            printf("%lld\n", get_sum(y) - get_sum(x - 1));
    }
    return 0;
}

在此基础上稍加变通,我们就可以让树状数组实现区间更新(对原数组某个区间的所有元素都加上一个值)\单点查询(查询原数组某元素的值)和区间更新\区间查询

区间更新\单点查询

考虑如何将求前缀和的问题转化为求某个元素的值的问题,可以发现,差分数组的前缀和就是某个元素的值,且在进行区间更新时,只需要修改区间两端的差分数组的值,所以我们用树状数组维护原数组的差分数组就能实现区间更新\单点查询的功能了

LOJ 树状数组模板2

#include<bits/stdc++.h>
#define ll long long
using namespace std;

const int MAX_N = 1000000 + 5;
int a[MAX_N];
ll c[MAX_N];
int n, q;

int lowbit(int x)
{
    return x & (-x);
}

void update(int x, int k)
{
    while(x <= n) {
        c[x] += k;
        x += lowbit(x);
    }
}

ll get_sum(int x)
{
    ll res = 0;
    while(x >= 1) {
        res += c[x];
        x -= lowbit(x);
    }
    return res;
}

int main()
{
    scanf("%d%d", &n, &q);
    for(int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
        update(i, a[i] - a[i - 1]);
    }
    for(int i = 1; i <= q; i++) {
        int opt, x, y, z;
        scanf("%d", &opt);
        if(opt == 1) {
            scanf("%d%d%d", &x, &y, &z);
            update(x, z);
            update(y + 1, -z);
        }else {
            scanf("%d", &x);
            printf("%lld\n", get_sum(x));
        }
    }
    return 0;
}
区间更新\区间查询

我们仍然采用差分的思路:

\(sum[i]=a[i]+a[i-1]+\cdots+a[1]=(d[1]+d[2]+\cdots+d[i])+(d[1]+\cdots+d[i-1])+\cdots+d[1]\)

进一步分解:

\(sum[i]=i(d[1]+d[2]+\cdots+d[i])-(1-1)d[1]-(2-1)d[2]-\cdots-(i-1)d[i]\)

所以只要用两个树状数组,一个维护 \(d[i]\) ,另一个维护 \((i-1)d[i]\) 即可

LOJ 树状数组模板3

#include<bits/stdc++.h>
#define ll long long
using namespace std;

const int MAX_N = 1000000 + 5;
int a[MAX_N];
ll sum1[MAX_N], sum2[MAX_N];
int n, q;

int lowbit(int x)
{
    return x & (-x);
}

void update(int x, int k)
{
    int t = x;
    while(t <= n) {
        sum1[t] += k;
        sum2[t] += 1ll * (x - 1) * k;
        t += lowbit(t);
    }
}

ll get_sum(int x)
{
    ll res = 0;
    int t = x;
    while(t >= 1) {
        res = res + x * sum1[t] - sum2[t];
        t -= lowbit(t);
    }
    return res;
}

int main()
{
    scanf("%d%d", &n, &q);
    for(int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
        update(i, a[i] - a[i - 1]);
    }
    for(int i = 1; i <= q; i++) {
        int opt, l, r, x;
        scanf("%d", &opt);
        if(opt == 1) {
            scanf("%d%d%d", &l, &r, &x);
            update(l, x);
            update(r + 1, -x);
        }else {
            scanf("%d%d", &l, &r);
            printf("%lld\n", get_sum(r) - get_sum(l - 1));
        }
    }
    return 0;
}
posted @ 2021-12-14 11:16  f(k(t))  阅读(77)  评论(1编辑  收藏  举报