快速求区间和的有趣算法——树状数组
好久没写东西,感觉有写些什么的必要了。(高仿鲁迅)
树状数组虽然听起来名字高大上,但是不是很难(前缀和是名字高大上,却水得像海洋)
树状数组在单纯的查询一个区间的和和修改某一个数的效率要超过线段树哦!树状数组最差时间复杂度为O(logn),而线段树的时间复杂度一直保持O(logn),且线段树的空间复杂度是树状数组的4倍。
But:树状数组只是线段树的一个辣鸡版本(虽然在某些方面比线段数快一点点),使用树状数组很大的一个原因是树状数组十分好写,且非常好维护。但是它只能处理可以用前缀和或差分来解决的题目,像是求(l,r)之间的最大值,树状数组就会Game Over。
为什么叫树状数组呢?因为它长得像树一样(废话),就像这个样子:
表示我的画图技术和画图软件都烂炸了(逃)
现在假如有n个数,存在A数组里,用C数组当树状数组,从A[1]开始存入,一直存到A[n],然后顺便把C数组初始化。(一会儿解释为什么不从A[0]开始存)
通过看图,可以得到这么一个结论:
C1 = A1
C2 = A1+A2
C3 = A3
C4 = A1+A2+A3+A4
C5 = A5
C6 = A5+A6
C7 = A7
C8 = A1+A2+A3+A4+A5+A6+A7+A8
现在找找规律!
好吧,是不是有感觉了但是表达不出来?
再处理一下,把C数组的下标用二进制表示出来
C(1) = A1
C(10) = A1+A2
C(11) = A3
C(100) = A1+A2+A3+A4
C(101) = A5
C(110) = A5+A6
C(111) = A7
C(1000) = A1+A2+A3+A4+A5+A6+A7+A8
我们把这些下标的二进制从后面往前面看,看到出现一个1为止:
C(1) = A1
C(10) = A1+A2
C(1) = A3
C(100) = A1+A2+A3+A4
C(1) = A5
C(10) = A5+A6
C(1) = A7
C(1000) = A1+A2+A3+A4+A5+A6+A7+A8
最后一个处理,把读后的二进制下标再转换成十进制:
C(1) = A1
C(2) = A1+A2
C(1) = A3
C(4) = A1+A2+A3+A4
C(1) = A5
C(2) = A5+A6
C(1) = A7
C(8) = A1+A2+A3+A4+A5+A6+A7+A8
现在绝对能看懂了(自信满满)
好了,“显然”可以看出,假如原来C数组的下标为a,现在的下标为b,那么这个C[a]就对应着从A[b]和它前面总共b个数的和,或者可以说,对应着从A[a-b+1]到A[b]的数的和。(原本很简单的东西为啥讲出来这么麻烦?)
有人要说了:讲这么半天,你也没有告诉我怎么初始化(玩)树状数组。
好吧好吧,现在开始讲(che)解(dan)。
还是从一道模板题开始最好了QvQ!(然后从一道模板题结束......)
题目描述
如题,已知一个数列,你需要进行下面两种操作:
1.将某一个数加上x
2.求出某区间每一个数的和
输入输出格式
输入格式:
第一行包含两个整数N、M,分别表示该数列数字的个数和操作的总个数。
第二行包含N个用空格分隔的整数,其中第i个数字表示数列第i项的初始值。
接下来M行每行包含3个整数,表示一个操作,具体如下:
操作1: 格式:1 x k 含义:将第x个数加上k
操作2: 格式:2 x y 含义:输出区间[x,y]内每个数的和
输出格式:
输出包含若干行整数,即为所有操作2的结果。
输入输出样例
输入样例:
5 5
1 5 4 2 3
1 1 3
2 2 5
1 3 -1
1 4 2
2 1 4
输出样例:
14
16
说明
时空限制:1000ms,128M
数据规模:
对于30%的数据:N<=8,M<=10
对于70%的数据:N<=10000,M<=10000
对于100%的数据:N<=500000,M<=500000
这道题明显是要用树状数组做嘛!(笑)
题目的意思非常好懂,就是n个数,m个操作。操作分两种,一种是查询区间和,一种是修改(增加)第几个数的值。
那么开始码代码吧,先从主函数(main)开始:
因为A数组里的所有值在C数组里都能查询到,所以并不需要建一个A数组,只需要读一个数,然后把C数组更新一下便好了。
代码如下:
cin>>n>>m;//分别是n个数m个操作
for(int i=1;i<=n;i++){
int v;
cin>>v;
update(i,v);//这个函数是在序列(假想的A数组)第i个位置加上v,因为初始都是零,所以相当于初始化,这个函数的实现后面讲
}
读入之后,就是m个操作了:
for(int i=1;i<=m;i++){
int k,a,b;
cin>>k>>a>>b;//k是模式(题中有),a和b下面要用到
if(k==1)
update(a,b);//在序列的第a个位置加上b
else//如果不是在某个位置上加一个数,就是求区间和啦
cout<<sum(b)-sum(a-1)<<endl;//输出区间和,这个看不懂不要紧,后面讲(怎么什么都后面讲poq)
}
好吧,现在主函数除了return 0;以外都写完了,现在就到了讲(che)最难的update和sum函数了(其实还有一个lowbit函数)
先回到那个(有趣的)图:
唉,丑陋不堪!
再把最开始得到的结论搬出来:
C1 = A1
C2 = A1+A2
C3 = A3
C4 = A1+A2+A3+A4
C5 = A5
C6 = A5+A6
C7 = A7
C8 = A1+A2+A3+A4+A5+A6+A7+A8
还有这个:
C(1) = A1
C(10) = A1+A2
C(11) = A3
C(100) = A1+A2+A3+A4
C(101) = A5
C(110) = A5+A6
C(111) = A7
C(1000) = A1+A2+A3+A4+A5+A6+A7+A8
举个栗子,假如我们修改了A[3]的值,C数组中的哪些元素需要修改呢?
通过看图和看图后得到的结论,“显然“就是包含A[3]的C数组的元素,或者说是C[3]和它的”祖先”(反正人们说树都喜欢用祖先这个词),也就是C[3],C[4]和C[8]。
因为你不能给计算机一个xxx.jpg然后让它自动修改需要修改的C数组的元素是不是?所以现在,得到一个递推式来自动处理显得很有必要了。
现在你不用自己找了,因为已经有人帮你找好了。
我们设x下标的二进制从后面往前面看,看到出现一个1时,我们看过的二进制为lowbit(x),如,3的二进制是11,那么lowbit(3)便是1了,又如4的二进制是100,那么lowbit(4)就是100了。
如果我们把3加上lowbit(3),得到4,再把4加上lowbit(4),就得到我们要的8了,这样,就愉快地把要修改的C数组的元素全部找到了。
先把lowbit函数给写了吧:
int lowbit(int x){
return x& (-x);
}
至于这个lowbit里面是怎么回事,因为涉及到补码什么的,就不讲了,反正也很好记๑乛◡乛๑
不过不能一直加下去吧?边界条件很好找,就是x不会超过n(显然易见的)。
现在把update函数也放出来:
void update(int x,int v){
while(x<=n){//边界条件
c[x]+=v;//将要更新的C数组的元素加上v
x+=lowbit(x);//下一个元素
}
}
之前有一个问题,就是为什么A数组不从0开始,因为lowbit(0)等于0,那么就会永远达不到边界条件,也就是x永远也不会达到n,总之会无限循环下去,就炸了,炸了!
好啦,现在唯一没有讲(che)的是主函数中的这句话了:
cout<<sum(b)-sum(a-1)<<endl;
很简单,sum(x)函数是计算序列中第1个数到第x个数的和的函数(绕晕),和前缀和的思想相同,若想求第a个数到第b个数的和,只需要求第1个数到第b个数的和减第1个数到第a-1个数的和即可
那么是时候讲(che)sum函数的构造了!
假如要求序列中第1个数到第7个数的和该怎么弄?看看表就明白了——>C[7]+C[6]+C[4],再拆成二进制C(111)+C(110)+C(100)。那么假如要求序列中第1个数到第6个数的和呢?再看一下表C[6]+C[4],再拆成二进制C(110)+C(100)。
可以看出来,要求第1个数到第x个数的和,只需要从x开始向下递推,然后用一个变量将一堆C[x]加起来,就可以得到第1个数到第x个数的和了,边界条件也是“显然易见”的,那就是x>0或x>=1。
话不多说,上代码:
int sum(int x){
int res=0;//保存一堆C[x]的和的变量
while(x>0){//边界条件
res+=c[x];//加上......
x-=lowbit(x);//下一个
}
return res;
}
这样,这道题就可以AC了!
附上完整代码:
#pragma GCC optimize(3)
#include<bits/stdc++.h>
using namespace std;
static int n,m;
static int c[500005];
inline int lowbit(int x){
return x& (-x);
}
void update(int x,int v){
while(x<=n){
c[x]+=v;
x+=lowbit(x);
}
}
int sum(int x){
int res=0;
while(x>0){
res+=c[x];
x-=lowbit(x);
}
return res;
}
int main(){
cin>>n>>m;
for(int i=1;i<=n;i++){
int v;
cin>>v;
update(i,v);
}
for(int i=1;i<=m;i++){
int k,a,b;
cin>>k>>a>>b;
if(k==1)
update(a,b);
else
cout<<sum(b)-sum(a-1)<<endl;
}
return 0;
}
请无视我手动开的O3和C++17中全局变量必须加的static......(逃)