树状数组

1. 什么是树状数组?

树状数组是一个查询和修改复杂度都为 \(\operatorname{O}(\log n)\) 的数据结构。

看到这句话是不是想到了线段树?

是的!

但是,凡是可以使用树状数组解决的问题, 使用线段树一定可以解决, 但是线段树能够解决的问题树状数组未必能够解决。

哦,那还是用线段树吧……

然鹅:

  • 线段树的数组需要开 \(4\) 倍,树状数组只用 \(1\) 倍。

  • 树状数组代码量少,线段树写 \(1\) 题用树状数组可以写 \(2\) 题。

  • 真不错!

2. \(lowbit\)

\(\operatorname{lowbit(x)}\) 就是求 \(x\) 最低位的 \(1\)

\(c\) 数组记录原数组的和:

\(c[1]=a[1]\to c[(1)_2]=a[(1)_2]\)

\(c[2]=a[2]+a[1]\to c[(10)_2]=a[(10)_2]+a[(1)_2]\)

\(c[3]=a[3]\to c[(11)_2]=a[(11)_2]\)

\(c[4]=a[4]+a[3]+a[2]+a[1]\to c[(100)_2]=a[(100)_2]+a[(11)_2]+a[(10)_2]+a[(1)_2]\)

\(c[5]=a[5]\to c[(101)_2]=a[(101)_2]\)

\(c[6]=a[6]+a[5]\to c[(110)_2]=a[(110)_2]+a[(101)_2]\)

\(c[7]=a[7]\to c[(111)_2]=a[(111)_2]\)

更新:

\(a[1]\) 为例,要更新 \(a[1],a[2],a[4]\)

是每次加了二进制的低位 \(1\)

\(1\to1+1=10\to10+10=100\to100+100=1000>111\)(停止)

查询:

查询 \(\sum\limits_{i=1}^7{a[i]}\to\) 查询 \(c[7]+c[6]+c[4]\)

是每次去掉了二进制中的低位 \(1\)

\(111\to111-1=110\to110-10=100\to100-100=0\)(停止)

\(lowbit\) 的实现 :

int lowbit(int x)
{
	return x & -x;
}

我们知道,一个数的负数就等于对这个数取反 \(+1\)

以二进制数 \(11010\) 为例:\(11010\) 取反为 \(00101\),加 \(1\) 后为 \(00110\),两者相与便是最低位的 \(1\)

3. 一维树状数组

1. 单点修改,区间查询

P3374 【模板】树状数组 1

  • 单点修改

由刚才所说可知:

void update(int x, int k) //x 为要更新的数,k 为要加上的值,n 为数组大小
{
	for (int i = x; i <= n; i += lowbit(i))
	{
		c[i] += k;
	}
}
  • 区间查询

其实也就是单点更新的逆操作。

要想求出 \(x\sim y\) 相当于 \(1\sim y\) 减去 \(1\sim x-1\)

int query(int x)
{
	int res = 0;
	for (int i = x; i; i -= lowbit(i))
	{
		res += c[i];
	}
	return res;
}

int range_query(int x, int y)
{
	return query(y) - query(x - 1);
}

  • \(AC\;Code\)
#include <iostream>
#include <cstdio>
using namespace std;

const int MAXN = 5e5 + 5;

int n, m, op, x, y;
int c[MAXN];

int lowbit(int x)
{
	return x & -x;
}

void update(int x, int k)
{
	for (int i = x; i <= n; i += lowbit(i))
	{
		c[i] += k;
	}
}

int query(int x)
{
	int res = 0;
	for (int i = x; i; i -= lowbit(i))
	{
		res += c[i];
	}
	return res;
}

int main()
{
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; i++)
	{
		int x;
		scanf("%d", &x);
		update(i, x);
	}
	while (m--)
	{
		scanf("%d%d%d", &op, &x, &y);
		if (op == 1)
		{
			update(x, y);
		}
		else
		{
			printf("%d\n", query(y) - query(x - 1));
		}
	}
	return 0;
}

2. 区间修改,单点查询

P3368 【模板】树状数组 2

  • 区间修改

这里用到了差分的思想

设原数组为 \(a\),差分数组 \(d[i] = a[i] - a[i-1]\)

当给区间 \([l, r]\) 加上 \(k\) 时,\(a[l]\)\(a[l-1]\) 的差增加了 \(k\)\(a[r+1]\)\(a[r]\) 减少了 \(k\),故只需将 \(d[l]\) 加上 \(k\)\(a[r+1]\) 减去 \(k\) 即可。

void update(int x, int k)
{
	for (int i = x; i <= n; i += lowbit(i))
	{
		c[i] += k;
	}
}

void range_update(int x, int y, int k)
{
	update(x, k);
	update(y + 1, -k);
}
  • \(AC\;Code\)
#include <iostream>
#include <cstdio>
using namespace std;

const int MAXN = 5e5 + 5;

int n, m, now, last, op, x, y, k;
int c[MAXN];

int lowbit(int x)
{
	return x & -x;
}

void update(int x, int k)
{
	for (int i = x; i <= n; i += lowbit(i))
	{
		c[i] += k;
	}
}

void range_update(int x, int y, int k)
{
	update(x, k);
	update(y + 1, -k);
}

int query(int x)
{
	int res = 0;
	for (int i = x; i; i -= lowbit(i))
	{
		res += c[i];
	}
	return res;
}

int main()
{
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; i++)
	{
		scanf("%d", &now);
		update(i, now - last);
		last = now;
	}
	while (m--)
	{
		scanf("%d%d", &op, &x);
		if (op == 1)
		{
			scanf("%d%d", &y, &k);
			range_update(x, y, k);
		}
		else
		{
			printf("%d\n", query(x));
		}
	}
	return 0;
}

3. 区间修改,区间查询

P3372 【模板】线段树 1

位置 \(x\) 的前缀和

\(\sum\limits_{i=1}^{x}a[i]=\sum\limits_{i=1}^{x}\sum\limits_{j=1}^{i}d[j]\)

其中,\(d[1]\) 被用了 \(x\) 次,\(d[2]\) 被用了 \(x-1\)\(……\) 那么可以写出:

\(\sum\limits_{i=1}^{x}\sum\limits_{j=1}^{i}d[j]=\sum\limits_{i=1}^{x}d[i]\cdot(x-i+1)\)

\(\qquad\qquad\ \ =(x+1)\cdot\sum\limits_{i=1}^{x}d[i]-\sum\limits_{i=1}^{x}d[i]\cdot i\)

那么可以维护两个数组的前缀和:

一个数组是 \(c1[i]=d[i]\)

另一个数组是 \(c2[i]=d[i]\cdot i\)

  • 区间修改

对于 \(c1\) 数组的修改同上对 \(d\) 数组的修改。

对于 \(c2\) 数组的修改也类似,我们给 \(c2[l]\) 加上 \(l \cdot x\),给 \(c2[r + 1]\) 减去 \((r + 1) \cdot x\)

void update(int x, int k)
{
	for (int i = x; i <= n; i += lowbit(i))
	{
		c1[i] += k;
		c2[i] += k * x;
	}
}

void range_update(int x, int y, int k)
{
	update(x, k);
	update(y + 1, -k);
}
  • 区间查询

位置 \(x\) 的前缀和即:\((x+1)\cdot t1\) 数组中 \(x\) 的前缀和 \(- t2\) 数组中 \(x\) 的前缀和。

int query(int x)
{
	int res = 0;
	for (int i = x; i; i -= lowbit(i))
	{
		res += (x + 1) * c1[i] - c2[i];
	}
	return res;
}

int range_query(int x, int y)
{
	return query(y) - query(x - 1);
}
  • \(AC\;Code\)
#include <iostream>
#include <cstdio>
#define int long long
using namespace std;

const int MAXN = 1e5 + 5;

int n, m, a, last, op, x, y, k;
int c1[MAXN], c2[MAXN];

int lowbit(int x)
{
	return x & -x;
}

void update(int x, int k)
{
	for (int i = x; i <= n; i += lowbit(i))
	{
		c1[i] += k;
		c2[i] += k * x;
	}
}

void range_update(int x, int y, int k)
{
	update(x, k);
	update(y + 1, -k);
}

int query(int x)
{
	int res = 0;
	for (int i = x; i; i -= lowbit(i))
	{
		res += (x + 1) * c1[i] - c2[i];
	}
	return res;
}

int range_query(int x, int y)
{
	return query(y) - query(x - 1);
}

signed main()
{
	scanf("%lld%lld", &n, &m);
	for (int i = 1; i <= n; i++)
	{
		scanf("%lld", &a);
		update(i, a - last);
		last = a;
	}
	while (m--)
	{
		scanf("%lld%lld%lld", &op, &x, &y);
		if (op == 1)
		{
			scanf("%lld", &k);
			range_update(x, y, k);
		}
		else
		{
			printf("%lld\n", range_query(x, y));
		}
	}
	return 0;
}

4. 二维树状数组

1.单点修改,区间查询

思路跟一维类似,只不过一层循环变两层

  • 区间查询

查询时要用到容斥原理

int query(int x, int y)
{
    int res = 0;
    for (int i = x; i; i -= lowbit(i))
    {
        for (int j = y; j; j -= lowbit(j))
        {
            res += c[i][j];
        }
    }
    return res;
}

int range_query(int a, int b, int c, int d) //左上角 (a,b) 右下角 (c,d)
{
	return query(c, d) - query(c, b - 1) - query(a - 1, d) + query(a - 1, b - 1);
}

2. 区间修改,单点查询

需要用到二维差分。

  • 区间修改

令差分数组
\(d[i][j]=a[i][j] - a[i-1][j] - a[i][j-1]+a[i-1][j-1]\)

当我们给一个 \(5\times5\) 的矩阵最中间的 \(3\times3\) 的矩阵加上 \(k\) 时,差分数组变成了这样:

0  0  0  0  0
0 +k  0  0 -k
0  0  0  0  0
0  0  0  0  0
0 -k  0  0 +k

这样修改差分数组,原数组就变成了:

0  0  0  0  0
0  k  k  k  0
0  k  k  k  0
0  k  k  k  0
0  0  0  0  0
void update(int x, int y, int k)
{
    for (int i = x; i <= n; i += lowbit(i))
    {
        for (int j = y; j <= m; j += lowbit(j))
        {
            c[i][j] += k;
        }
    }
}

void range_update(int a, int b, int c, int d, int k)
{
	update(a, b, k);
	update(a, d + 1, -k);
	update(c + 1, b, -k);
	update(c + 1, d + 1, k);
}

3. 区间修改,区间查询

最最最最\(\color{White}{恶心}\)重要的部分终于来啦!

下面这个式子表示的是点 \((x, y)\) 的二维前缀和:

\(\sum\limits_{i=1}^{x}\sum\limits_{j=1}^{y}\sum\limits_{k=1}^{i}\sum\limits_{h=1}^{j}d[h][k]\)

\(d[1][1]\) 出现了 \(x\cdot y\) 次,\(d[1][2]\) 出现了 \(x \cdot (y-1)\)\(…… d[h][k]\) 出现了 \((x-h+1)\cdot(y-k+1)\) 次。

原式 \(=\sum\limits_{i=1}^{x}\sum\limits_{j=1}^{y}d[i][j]\cdot(x+1-i)\cdot(y+1-j)\)

\(\quad\ \ \ \,=(x+1)\cdot(y+1)\cdot\sum\limits_{i=1}^{x}\sum\limits_{j=1}^{y}d[i][j]-(y+1)\cdot\sum\limits_{i=1}^{x}\sum\limits_{j=1}^{y}d[i][j]\cdot i-(x+1)\cdot\sum\limits_{i=1}^{x}\sum\limits_{j=1}^{y}d[i][j]\cdot j+\sum\limits_{i=1}^{x}\sum\limits_{j=1}^{y}d[i][j]\cdot i\cdot j\)

我们要开 \(4\)\(c\) 数组,分别维护:

\(d[i][j],d[i][j]\cdot i,d[i][j]\cdot j,d[i][j]\cdot i\cdot j\)

void update(int x, int y, int k)
{
    for (int i = x; i <= n; i += lowbit(i))
    {
        for (int j = y; j <= m; j += lowbit(j))
        {
            c1[i][j] += k;
            c2[i][j] += k * x;
            c3[i][j] += k * y;
            c4[i][j] += k * x * y;
        }
    }
}
 
void range_update(int a, int b, int c, int d, int k)
{
    update(a, b, k);
    update(a, d + 1, -k);
    update(c + 1, b, -k);
    update(c + 1, d + 1, k);
}
 
int query(int x, int y)
{
    int res = 0;
    for (int i = x; i; i -= lowbit(i))
    {
        for (int j = y; j; j -= lowbit(j))
        {
            res += (x + 1) * (y + 1) * c1[i][j] - (y + 1) * c2[i][j] - (x + 1) * c3[i][j] + c4[i][j];
        }
    }
    return res;
}
 
int range_query(int a, int b, int c, int d)
{
    return query(c, d) - query(a - 1, d) - query(c, b - 1) + query(a - 1, b - 1);
}

P4514 上帝造题的七分钟

\(AC\;Code\)

#include <iostream>
#include <cstdio>
using namespace std;

const int MAXN = 2050;

int n, m;
int c1[MAXN][MAXN], c2[MAXN][MAXN], c3[MAXN][MAXN], c4[MAXN][MAXN];

int lowbit(int x)
{
	return x & -x;
}

void update(int x, int y, int k)
{
	for (int i = x; i <= n; i += lowbit(i))
	{
		for (int j = y; j <= m; j += lowbit(j))
		{
			c1[i][j] += k;
			c2[i][j] += y * k;
			c3[i][j] += x * k;
			c4[i][j] += x * y * k;
		}
	}
}

void range_update(int a, int b, int c, int d, int k)
{
	update(a, b, k);
	update(a, d + 1, -k);
	update(c + 1, b, -k);
	update(c + 1, d + 1, k);
}

int query(int x, int y)
{
	int res = 0;
	for (int i = x; i; i -= lowbit(i))
	{
		for (int j = y; j; j -= lowbit(j))
		{
			res += (x + 1) * (y + 1) * c1[i][j] - (x + 1) * c2[i][j] - (y + 1) * c3[i][j] + c4[i][j];
		}
	}
	return res;
}

int range_query(int a, int b, int c, int d)
{
	return query(c, d) - query(a - 1, d) - query(c, b - 1) + query(a - 1, b - 1);
}

int main()
{
	getchar();
	scanf("%d%d", &n, &m);
	char op;
	int a, b, c, d, k;
	while (~scanf("\n%c%d%d%d%d", &op, &a, &b, &c, &d))
	{
		if (op == 'L')
		{
			scanf("%d", &k);
			range_update(a, b, c, d, k);
		}
		else
		{
			printf("%d\n", range_query(a, b, c, d));
		}
	}
	return 0;
}

\(\mathfrak{THE\;END}\)

posted @ 2021-08-07 17:46  mango09  阅读(70)  评论(0编辑  收藏  举报
-->