树状数组
1. 什么是树状数组?
树状数组是一个查询和修改复杂度都为 \(\operatorname{O}(\log n)\) 的数据结构。
看到这句话是不是想到了线段树?
是的!
但是,凡是可以使用树状数组解决的问题, 使用线段树一定可以解决, 但是线段树能够解决的问题树状数组未必能够解决。
哦,那还是用线段树吧……
然鹅:
-
线段树的数组需要开 \(4\) 倍,树状数组只用 \(1\) 倍。
-
树状数组代码量少,线段树写 \(1\) 题用树状数组可以写 \(2\) 题。
-
真不错!
2. \(lowbit\)
\(\operatorname{lowbit(x)}\) 就是求 \(x\) 最低位的 \(1\)
用 \(c\) 数组记录原数组的和:
\(c[1]=a[1]\to c[(1)_2]=a[(1)_2]\)
\(c[2]=a[2]+a[1]\to c[(10)_2]=a[(10)_2]+a[(1)_2]\)
\(c[3]=a[3]\to c[(11)_2]=a[(11)_2]\)
\(c[4]=a[4]+a[3]+a[2]+a[1]\to c[(100)_2]=a[(100)_2]+a[(11)_2]+a[(10)_2]+a[(1)_2]\)
\(c[5]=a[5]\to c[(101)_2]=a[(101)_2]\)
\(c[6]=a[6]+a[5]\to c[(110)_2]=a[(110)_2]+a[(101)_2]\)
\(c[7]=a[7]\to c[(111)_2]=a[(111)_2]\)
更新:
以 \(a[1]\) 为例,要更新 \(a[1],a[2],a[4]\)。
是每次加了二进制的低位 \(1\)
\(1\to1+1=10\to10+10=100\to100+100=1000>111\)(停止)
查询:
查询 \(\sum\limits_{i=1}^7{a[i]}\to\) 查询 \(c[7]+c[6]+c[4]\)
是每次去掉了二进制中的低位 \(1\)
\(111\to111-1=110\to110-10=100\to100-100=0\)(停止)
\(lowbit\) 的实现 :
int lowbit(int x)
{
return x & -x;
}
我们知道,一个数的负数就等于对这个数取反 \(+1\)
以二进制数 \(11010\) 为例:\(11010\) 取反为 \(00101\),加 \(1\) 后为 \(00110\),两者相与便是最低位的 \(1\)。
3. 一维树状数组
1. 单点修改,区间查询
- 单点修改
由刚才所说可知:
void update(int x, int k) //x 为要更新的数,k 为要加上的值,n 为数组大小
{
for (int i = x; i <= n; i += lowbit(i))
{
c[i] += k;
}
}
- 区间查询
其实也就是单点更新的逆操作。
要想求出 \(x\sim y\) 相当于 \(1\sim y\) 减去 \(1\sim x-1\)
int query(int x)
{
int res = 0;
for (int i = x; i; i -= lowbit(i))
{
res += c[i];
}
return res;
}
int range_query(int x, int y)
{
return query(y) - query(x - 1);
}
- \(AC\;Code\)
#include <iostream>
#include <cstdio>
using namespace std;
const int MAXN = 5e5 + 5;
int n, m, op, x, y;
int c[MAXN];
int lowbit(int x)
{
return x & -x;
}
void update(int x, int k)
{
for (int i = x; i <= n; i += lowbit(i))
{
c[i] += k;
}
}
int query(int x)
{
int res = 0;
for (int i = x; i; i -= lowbit(i))
{
res += c[i];
}
return res;
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++)
{
int x;
scanf("%d", &x);
update(i, x);
}
while (m--)
{
scanf("%d%d%d", &op, &x, &y);
if (op == 1)
{
update(x, y);
}
else
{
printf("%d\n", query(y) - query(x - 1));
}
}
return 0;
}
2. 区间修改,单点查询
- 区间修改
这里用到了差分的思想
设原数组为 \(a\),差分数组 \(d[i] = a[i] - a[i-1]\)
当给区间 \([l, r]\) 加上 \(k\) 时,\(a[l]\) 与 \(a[l-1]\) 的差增加了 \(k\),\(a[r+1]\) 与 \(a[r]\) 减少了 \(k\),故只需将 \(d[l]\) 加上 \(k\),\(a[r+1]\) 减去 \(k\) 即可。
void update(int x, int k)
{
for (int i = x; i <= n; i += lowbit(i))
{
c[i] += k;
}
}
void range_update(int x, int y, int k)
{
update(x, k);
update(y + 1, -k);
}
- \(AC\;Code\)
#include <iostream>
#include <cstdio>
using namespace std;
const int MAXN = 5e5 + 5;
int n, m, now, last, op, x, y, k;
int c[MAXN];
int lowbit(int x)
{
return x & -x;
}
void update(int x, int k)
{
for (int i = x; i <= n; i += lowbit(i))
{
c[i] += k;
}
}
void range_update(int x, int y, int k)
{
update(x, k);
update(y + 1, -k);
}
int query(int x)
{
int res = 0;
for (int i = x; i; i -= lowbit(i))
{
res += c[i];
}
return res;
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++)
{
scanf("%d", &now);
update(i, now - last);
last = now;
}
while (m--)
{
scanf("%d%d", &op, &x);
if (op == 1)
{
scanf("%d%d", &y, &k);
range_update(x, y, k);
}
else
{
printf("%d\n", query(x));
}
}
return 0;
}
3. 区间修改,区间查询
位置 \(x\) 的前缀和
\(\sum\limits_{i=1}^{x}a[i]=\sum\limits_{i=1}^{x}\sum\limits_{j=1}^{i}d[j]\)
其中,\(d[1]\) 被用了 \(x\) 次,\(d[2]\) 被用了 \(x-1\) 次 \(……\) 那么可以写出:
\(\sum\limits_{i=1}^{x}\sum\limits_{j=1}^{i}d[j]=\sum\limits_{i=1}^{x}d[i]\cdot(x-i+1)\)
\(\qquad\qquad\ \ =(x+1)\cdot\sum\limits_{i=1}^{x}d[i]-\sum\limits_{i=1}^{x}d[i]\cdot i\)
那么可以维护两个数组的前缀和:
一个数组是 \(c1[i]=d[i]\)
另一个数组是 \(c2[i]=d[i]\cdot i\)
- 区间修改
对于 \(c1\) 数组的修改同上对 \(d\) 数组的修改。
对于 \(c2\) 数组的修改也类似,我们给 \(c2[l]\) 加上 \(l \cdot x\),给 \(c2[r + 1]\) 减去 \((r + 1) \cdot x\)。
void update(int x, int k)
{
for (int i = x; i <= n; i += lowbit(i))
{
c1[i] += k;
c2[i] += k * x;
}
}
void range_update(int x, int y, int k)
{
update(x, k);
update(y + 1, -k);
}
- 区间查询
位置 \(x\) 的前缀和即:\((x+1)\cdot t1\) 数组中 \(x\) 的前缀和 \(- t2\) 数组中 \(x\) 的前缀和。
int query(int x)
{
int res = 0;
for (int i = x; i; i -= lowbit(i))
{
res += (x + 1) * c1[i] - c2[i];
}
return res;
}
int range_query(int x, int y)
{
return query(y) - query(x - 1);
}
- \(AC\;Code\)
#include <iostream>
#include <cstdio>
#define int long long
using namespace std;
const int MAXN = 1e5 + 5;
int n, m, a, last, op, x, y, k;
int c1[MAXN], c2[MAXN];
int lowbit(int x)
{
return x & -x;
}
void update(int x, int k)
{
for (int i = x; i <= n; i += lowbit(i))
{
c1[i] += k;
c2[i] += k * x;
}
}
void range_update(int x, int y, int k)
{
update(x, k);
update(y + 1, -k);
}
int query(int x)
{
int res = 0;
for (int i = x; i; i -= lowbit(i))
{
res += (x + 1) * c1[i] - c2[i];
}
return res;
}
int range_query(int x, int y)
{
return query(y) - query(x - 1);
}
signed main()
{
scanf("%lld%lld", &n, &m);
for (int i = 1; i <= n; i++)
{
scanf("%lld", &a);
update(i, a - last);
last = a;
}
while (m--)
{
scanf("%lld%lld%lld", &op, &x, &y);
if (op == 1)
{
scanf("%lld", &k);
range_update(x, y, k);
}
else
{
printf("%lld\n", range_query(x, y));
}
}
return 0;
}
4. 二维树状数组
1.单点修改,区间查询
思路跟一维类似,只不过一层循环变两层
- 区间查询
查询时要用到容斥原理
int query(int x, int y)
{
int res = 0;
for (int i = x; i; i -= lowbit(i))
{
for (int j = y; j; j -= lowbit(j))
{
res += c[i][j];
}
}
return res;
}
int range_query(int a, int b, int c, int d) //左上角 (a,b) 右下角 (c,d)
{
return query(c, d) - query(c, b - 1) - query(a - 1, d) + query(a - 1, b - 1);
}
2. 区间修改,单点查询
需要用到二维差分。
- 区间修改
令差分数组
\(d[i][j]=a[i][j] - a[i-1][j] - a[i][j-1]+a[i-1][j-1]\)
当我们给一个 \(5\times5\) 的矩阵最中间的 \(3\times3\) 的矩阵加上 \(k\) 时,差分数组变成了这样:
0 0 0 0 0
0 +k 0 0 -k
0 0 0 0 0
0 0 0 0 0
0 -k 0 0 +k
这样修改差分数组,原数组就变成了:
0 0 0 0 0
0 k k k 0
0 k k k 0
0 k k k 0
0 0 0 0 0
void update(int x, int y, int k)
{
for (int i = x; i <= n; i += lowbit(i))
{
for (int j = y; j <= m; j += lowbit(j))
{
c[i][j] += k;
}
}
}
void range_update(int a, int b, int c, int d, int k)
{
update(a, b, k);
update(a, d + 1, -k);
update(c + 1, b, -k);
update(c + 1, d + 1, k);
}
3. 区间修改,区间查询
最最最最\(\color{White}{恶心}\)重要的部分终于来啦!
下面这个式子表示的是点 \((x, y)\) 的二维前缀和:
\(\sum\limits_{i=1}^{x}\sum\limits_{j=1}^{y}\sum\limits_{k=1}^{i}\sum\limits_{h=1}^{j}d[h][k]\)
\(d[1][1]\) 出现了 \(x\cdot y\) 次,\(d[1][2]\) 出现了 \(x \cdot (y-1)\) 次 \(…… d[h][k]\) 出现了 \((x-h+1)\cdot(y-k+1)\) 次。
原式 \(=\sum\limits_{i=1}^{x}\sum\limits_{j=1}^{y}d[i][j]\cdot(x+1-i)\cdot(y+1-j)\)
\(\quad\ \ \ \,=(x+1)\cdot(y+1)\cdot\sum\limits_{i=1}^{x}\sum\limits_{j=1}^{y}d[i][j]-(y+1)\cdot\sum\limits_{i=1}^{x}\sum\limits_{j=1}^{y}d[i][j]\cdot i-(x+1)\cdot\sum\limits_{i=1}^{x}\sum\limits_{j=1}^{y}d[i][j]\cdot j+\sum\limits_{i=1}^{x}\sum\limits_{j=1}^{y}d[i][j]\cdot i\cdot j\)
我们要开 \(4\) 个 \(c\) 数组,分别维护:
\(d[i][j],d[i][j]\cdot i,d[i][j]\cdot j,d[i][j]\cdot i\cdot j\)
void update(int x, int y, int k)
{
for (int i = x; i <= n; i += lowbit(i))
{
for (int j = y; j <= m; j += lowbit(j))
{
c1[i][j] += k;
c2[i][j] += k * x;
c3[i][j] += k * y;
c4[i][j] += k * x * y;
}
}
}
void range_update(int a, int b, int c, int d, int k)
{
update(a, b, k);
update(a, d + 1, -k);
update(c + 1, b, -k);
update(c + 1, d + 1, k);
}
int query(int x, int y)
{
int res = 0;
for (int i = x; i; i -= lowbit(i))
{
for (int j = y; j; j -= lowbit(j))
{
res += (x + 1) * (y + 1) * c1[i][j] - (y + 1) * c2[i][j] - (x + 1) * c3[i][j] + c4[i][j];
}
}
return res;
}
int range_query(int a, int b, int c, int d)
{
return query(c, d) - query(a - 1, d) - query(c, b - 1) + query(a - 1, b - 1);
}
\(AC\;Code\)
#include <iostream>
#include <cstdio>
using namespace std;
const int MAXN = 2050;
int n, m;
int c1[MAXN][MAXN], c2[MAXN][MAXN], c3[MAXN][MAXN], c4[MAXN][MAXN];
int lowbit(int x)
{
return x & -x;
}
void update(int x, int y, int k)
{
for (int i = x; i <= n; i += lowbit(i))
{
for (int j = y; j <= m; j += lowbit(j))
{
c1[i][j] += k;
c2[i][j] += y * k;
c3[i][j] += x * k;
c4[i][j] += x * y * k;
}
}
}
void range_update(int a, int b, int c, int d, int k)
{
update(a, b, k);
update(a, d + 1, -k);
update(c + 1, b, -k);
update(c + 1, d + 1, k);
}
int query(int x, int y)
{
int res = 0;
for (int i = x; i; i -= lowbit(i))
{
for (int j = y; j; j -= lowbit(j))
{
res += (x + 1) * (y + 1) * c1[i][j] - (x + 1) * c2[i][j] - (y + 1) * c3[i][j] + c4[i][j];
}
}
return res;
}
int range_query(int a, int b, int c, int d)
{
return query(c, d) - query(a - 1, d) - query(c, b - 1) + query(a - 1, b - 1);
}
int main()
{
getchar();
scanf("%d%d", &n, &m);
char op;
int a, b, c, d, k;
while (~scanf("\n%c%d%d%d%d", &op, &a, &b, &c, &d))
{
if (op == 'L')
{
scanf("%d", &k);
range_update(a, b, c, d, k);
}
else
{
printf("%d\n", range_query(a, b, c, d));
}
}
return 0;
}