高级数据结构之树状数组
————————————————————————这些是转的,出处不明———————————————————————————————
树状数组比较适合单个元素改变,反复求部分和,或者区间更新,单点求值。
先看的是一维的树状数组。
树状数组是一个很天才的想法,考虑这样的一种情景,对于一组数据,你经常要求他们某个区间的和,而却这组数据里的元素会经常的改变,最朴素的想法就是暴力,O(1)的修改,O(n)的查询,或O(n)的修改O(1)的查询(就是记录)。第二种想法就是线段树,查询和修改的复杂度都是O(logn),线段树的编程复杂度比较高,常数因子也较大。有一种时间复杂度也是O(logn)的而且编程复杂度很简单的方法,就是用树状数组。树状数组的灵感是来源于二进制、线段树和O(1)查询O(n)修改算法(其实是我自己的灵感啦哈哈),二进制有01组成,每一个数字都有自己对应的一个二进制,既然线段数是把数据按二分的思想,把区间分成两个一样大小的区间,把大问题分解成两个小的子问题,那么在一组规模大小是10010110的数据,同样的我们也可以把区间分成一个个子区间,把大问题分解成一个个小问题,那么要怎么分解呢?看这二进制就明白了,我们要把区间分解成一个个大小不一的子区间,使它们加起来刚好就是原来的区间,很明显这个二进制可以分成:
如果我们要求1到10010100的和,即可用四个子区间组成原问题
[1,10]
[11,100]
[1001,10000]
[10001,10010100]
把它们加起来就是了,这就是说我们每次都把10010110最右边的“1”拿出来,作为子区间的右边界,左边界就是前一个子区间的右边界加1,第一个子区间的左边界是1.用位运算,把最右边的1分解,i&(i^(i-1))即i&(-i),这是自低向上的把区间分解出来,接下来就是定义tree[i]来递推了,按照上面的分法,很难定义,所以我们自顶向下的分区间试试看能不能容易的定义,分的四个区间:
[10010101,10010110]
[10010001,10010100]
[10000001,10010000]
[00000001,10000000]
这样定义起来就很方便了,我們可以看到右區間就是每往下都去掉一個最右邊的1,tree[i]就是区间[i-(i&-i)+1,i]的和,由这里的定义知道,树状数组要从1开始计数,而不是c/c++一向的0开始。求和函数就这样写:
1 int get_sum(int i) 2 { 3 int sum = 0; 4 while(i) { 5 sum += tree[i]; 6 i -= i & -i; 7 } 8 return sum; 9 }
这样的话,求区间[i,j]的值就是get_sum(j)-get_sum(i-1)了。如果经常求一个单个元素的指,这样写就重复计算了。我们可以在tree[i]所覆盖的范围中,减去除a[i]以外的(一般树状数组是不保存原始数据a[i]的,一般而已),因为每次都是减去最右边的1然后加1,就是减的范围不能小于等于i-(i&-i),很明显,i-1>=i-(i&-i),也就是说只要这个数字i'在范围内(不包括等于),那么它所覆盖的范围不会超过i-(i&-i),因为它怎么减去最右边的1,至少等于i-(i&-i),而覆盖范围的下界是i'-(i'&-i')+1,所以我们可以这样写:
1 int get_single(int i) 2 { 3 int s = tree[i],z=i-(i&-i); 4 --i; 5 while(i>z) { 6 s -= tree[i]; 7 i -= i&-i; 8 } 9 return s; 10 }
查询已经搞定了,现在就看看修改的时候要怎么做了。先看看树状数组的图是怎样的。
很明显,当我修改了a[i]的值,那么最小的,收到影响的首当其冲是tree[i],接着就是一级级往上影响它的父亲节点。所以这里的重点就是在于怎么找一个节点的父亲节点。我们知道,一个节点的父亲节点序号肯定比它大,那么我们就是要找一个范围能覆盖i的最小的j,那个j就是i的父节点,那么就很明显i的父节点不会是由i最右边1的右边的0变成1变成的,因为那样的j把最右1去掉再加1后,覆盖范围刚好覆盖不到i,所以就只能由最右边1的左边的0变成1,而现在要找的是最小的,因此我们只要把加上一个最右1就可以了,代码这样写:
1 void modify(int i,int c) 2 { 3 while(i<=MIX) { //MIX代表树状数组最大的编号 4 tree[i] += c; 5 i += i&-i; 6 } 7 }
查询和修改的时间复杂度都是O(logn)。
树状数组的主要操作函数就是修改&&求和,那么就是说还可以用在统计计数方面的情景。通常对于这些统计计数的情景,遍历的顺序挺重要的,有时候前到后遍历简单,有时候后到前遍历简单。
一维的树状数组就是这样,下面就先来看两题水题,来看看怎么用树状数组,poj2352。
题目的大意就是说给你一堆星星的坐标(x,y),然后要你输出每一层的星星的数量,层的意思是有多少颗星的x和y不大于这颗星星。由于他输入的时候已经是按y从小到大的输入,所以层数的计算,我们只需要判断当前输入的这颗星星的x坐标,大过他前面输入的多少颗星星就可以了。所以我们令tree[i]是x坐标是i的星星个数,要求层数的时候,就get_sum就可以了。具体的代码如下:
1 #include<stdio.h> 2 #include<memory.h> 3 #define MIX 32001 4 int tree[MIX],n,result[15000]; 5 int lowbit(int x) {return x&(-x);} 6 int get_sum(int k) 7 { 8 int s=0; 9 while(k>0) { 10 s+=tree[k]; 11 k-=lowbit(k); 12 } 13 return s; 14 } 15 void modify(int pos,int a) 16 { 17 while(pos<=MIX) { 18 tree[pos]+=a; 19 pos+=lowbit(pos); 20 } 21 } 22 int main(void) 23 { 24 int i,x,y; 25 memset(tree,0,sizeof(tree)); 26 memset(result,0,sizeof(result)); 27 scanf("%d",&n); 28 for(i=0; i<n; ++i){ 29 scanf("%d %d",&x,&y); 30 ++x; 31 ++result[get_sum(x)]; 32 modify(x,1); 33 } 34 for(i=0; i<n; ++i) printf("%d\n",result[i]); 35 return 0; 36 }
下面再来看一题目poj2299,这题目的转化后的意思就是给你一组数字,然后要你求这组数中的逆序对有多少。有两种方法做这题,一种是归并排序变形,一种就是树状数组,归并排序的方法是算法导论上的习题,在CLRS总结上面有这里就不说了,只说树状数组的方法。如何数据范围不大的话,我们就可以直接定义tree[i]代表数字i的个数,然后从后往前的遍历,这样就可以知道每个数字排在它后面却比它小的数字有多少个了,累加就可以了。也可以从前往后遍历,不过这时候get_sum(i)的值代表的是a[i]前面小于等于a[i]的有多少个,i-get_sum(i)就是大于在a[i]前面并且大于他的数了。然而这里的数据范围很打,a[i]的值能取到999999999,我们不可能开一个这么大的数组,所以在这里我们要用离散化先处理数组,离散化的意思,就是把原来的值建立一个新的一一映射,使范围减少,建立一个紧凑的范围,减少空间,在范围大而数据数量相对比较少的情况下很使用,在这题中,例如99999999 1 123 1583,我们建立的一一映射就是4 1 2 3,然后按照这个新的映射关系和之前的做法一样。代码如下:
1 #include<stdio.h> 2 #include<memory.h> 3 #define MIX 500001 4 struct node { 5 int v,idx; 6 } a[MIX]; 7 int tree[MIX],f[MIX]; 8 int lowbit(int x) {return x&(-x);} 9 void modify(int x,int b) 10 { 11 while(x<=MIX) { 12 tree[x] += b; 13 x += lowbit(x); 14 } 15 } 16 int get_sum(int x) 17 { 18 int sum = 0; 19 while(x) { 20 sum += tree[x]; 21 x -= lowbit(x); 22 } 23 return sum; 24 } 25 26 void swap(struct node* a,struct node *b) 27 { 28 struct node temp = *a; *a = *b; *b = temp; 29 } 30 31 int med(struct node *a,int low,int hight) 32 { 33 int center = (low+hight)>>1; 34 if(a[center].v > a[hight].v) 35 swap(&a[center],&a[hight]); 36 if(a[low].v > a[hight].v) 37 swap(&a[low],&a[hight]); 38 if(a[low].v < a[center].v ) 39 swap(&a[low],&a[center]); 40 return a[low].v; 41 } 42 43 void myqsort(struct node *a,int low,int hight) 44 { 45 if(low<hight) { 46 int i = low, j = hight,x=med(a,low,hight); 47 struct node tmp; 48 for(;;) { 49 while (a[++i].v < x) ; 50 while (a[--j].v > x) ; 51 if(i<j) { 52 tmp = a[i]; a[i] = a[j]; a[j] = tmp; 53 } else { 54 tmp = a[low]; a[low] = a[j]; a[j] = tmp; 55 break; 56 } 57 } 58 myqsort(a,low,j-1); 59 myqsort(a,j+1,hight); 60 } 61 } 62 int main(void) 63 { 64 int n,i; 65 long long sum; 66 a[0].v = -999999; 67 while(scanf("%d",&n),n) { 68 memset(tree,0,sizeof(tree)); 69 for(i=1; i<=n; ++i) { 70 scanf("%d",&a[i].v); 71 a[i].idx = i; 72 } 73 myqsort(a,1,n); 74 sum = 0; 75 for(i=1; i<=n; ++i) { 76 f[a[i].idx] = i; 77 for(i=n; i; --i) { 78 sum += get_sum(f[i]); 79 modify(f[i],1); 80 } 81 printf("%lld\n",sum); 82 } 83 }
上面两题都是单点更新,区间求和的,下面来看看树状数组是怎么区间更新,单点求值。Hdu1556,题意就是说总共有N个数,每次给个一个区间[a,b]给你,区间内的数全部+1,N次之后,要求输出每一个位置上的值。朴素的方法是每次遍历区间,+1,这样的复杂度是O(testcases*N*N),想不超时都难。这题的概述第一时间想到的是线段树,不过也可以用树状数组,甚至不用树状数组,直接用数组。这里要思考的是有没有一种修改方法不用遍历区间呢,不遍历就能达成遍历的效果,就相当于我修改一个值,后面的值也会受到影响,貌似用求和的思想可以达成影响后面的值,那么就是说,假设num[a]到num[b]要+1,只需要我们想一下tree[i]的定义,我们定义tree[i]代表对[i,N]的贡献,那么每当[a,b]要+1的时候,我们就可以tree[a]+=1,这是从[b+1,N]都多加了1,所以要tree[b+1]-=1。所以当要输出位置i加了多少次,就是get_sum(i),这里是抽象的来看,如果微观的看的话,要注意tree是怎么加的,不是累加,所以不会变多。在这题里是全部都输入之后,再从头输出i的值,所以就直接开一个数组就行了。可是如果是输入中间夹杂着多次查询的话,就要用树状数组了,很数组差不多[a,b]要加1,就modify(a,1),modify(b+1,-1)。下面是树状数组的代码:
1 #include<stdio.h> 2 #include<memory.h> 3 #define MIX 100001 4 int tree[MIX]; 5 void modify(int i,int c) 6 { 7 while(i<=MIX) { 8 tree[i] += c; 9 i += i & -i; 10 } 11 } 12 int get_sum(int i) 13 { 14 int sum = 0; 15 while(i) { 16 sum += tree[i]; 17 i -= i& -i; 18 } 19 return sum; 20 } 21 int main(void) 22 { 23 int n,a,b,i; 24 while(scanf("%d",&n),n) { 25 memset(tree,0,sizeof(tree)); 26 for(i=0; i<n; ++i) { 27 scanf("%d%d",&a,&b); 28 modify(a,1); 29 modify(b+1,-1); 30 } 31 for(i=1; i<n; ++i) 32 printf("%d ",get_sum(i)); 33 printf("%d\n",get_sum(n)); 34 } 35 return 0; 36 }
一维的树状数组大概就是这样了,现在说二维的树状数组。二维树状数组对应的是矩阵,是一维的扩展,一般用来快速求子矩阵的和,在二维树状数组中,tree[x][y]代表的是左上角是(x-lowbit(x)+1,y-lowbit(y)+1),右下角是(x,y)的矩阵的和。很明显求左上角是(1,1),右下角是(x,y)的求和就是二重循环枚举x,y,一个个子矩阵的叫上去。代码:
1 int get_sum(int x,int y) 2 { 3 int sum=0,y1; 4 while(x) { 5 y1 = y; 6 while(y1) { 7 sum += tree[x][y1]; 8 y1 -= y1 & -y1; 9 } 10 x -= x & -x; 11 } 12 }
modify函数也是差不多的
1 void modify(int x,int y,int val) 2 { 3 while(x<=MAX_X) { 4 int y1 = y; 5 while(y1<=MAX_Y) { 6 tree[x][y1] += val; 7 y1 += y1 & -y1; 8 } 9 x += x & -x; 10 } 11 }
查询和修改的时间复杂度是O(logMAX_X * logMAX_Y)。下面就看看怎么用了,poj2215,题目大意是一个N*N矩阵,初始0,有两个操作,一个是C x1 y1 x2 y2,就是把左上角是(x1,y1)右下角是(x2,y2)的子矩阵的每一位取反(0变1,1变0).这样和上面那题其实是差不多的,我们不用真的记录矩阵的真实值,只记录变化了多少次就可以了。因为一开始是0,所以就是说变化的次数是偶数就是0,是奇数就是1.tree[x][y]的数值代表(x,y)到(n,n)导致了多少变化,和上题一样modify时会把不应该变的也变了,所以还要变回来。代码如下:
1 #include<stdio.h> 2 #include<string.h> 3 #define N 1000 4 int tree[N+1][N+1],n; 5 void modify(int x,int y,int val) 6 { 7 while(x<=n) { 8 int y1 = y; 9 while(y1<=n) { 10 tree[x][y1] += val; 11 y1 += y1&-y1; 12 } 13 x += x&-x; 14 } 15 } 16 int get_sum(int x,int y) 17 { 18 int sum = 0,y1; 19 while(x) { 20 y1 = y; 21 while(y1) { 22 sum += tree[x][y1]; 23 y1 -= y1&-y1; 24 } 25 x -= x&-x; 26 } 27 return sum; 28 } 29 int main(void) 30 { 31 int T,X,x1,y1,x2,y2; 32 char c; 33 scanf("%d",&X); 34 getchar(); 35 while(X--) { 36 memset(tree,0,sizeof(tree)); 37 scanf("%d%d",&n,&T); 38 getchar(); 39 while(T--) { 40 scanf("%c %d %d",&c,&x1,&y1); 41 getchar(); 42 if(c=='C') { 43 scanf("%d %d",&x2,&y2); 44 getchar(); 45 modify(x1,y1,1); 46 modify(x1,y2+1,1); 47 modify(x2+1,y1,1); 48 modify(x2+1,y2+1,1); 49 } else 50 printf("%d\n",get_sum(x1,y1)&1); 51 } 52 putchar('\n'); 53 } 54 }
HDU4267,12年长春网络赛的题,是挺好的一题,题目大意是先给你一组数,然后有两个操作,一个是1 a b k c,意思是将[a,b]内符合(i-a)%k=0的位置都加上c,另一个操作是2 a,意思是查询位置a的值是多少。这题是区间内离散位置的更新,然后是单点求值。而无论是线段树还是树状数组的区间更新都是连续的更新,这里是离散,所以肯定一棵树是解决不了的。这里我们用线段树来解决这个问题,先看看有多少情况,1<=k<=10,k有10种情况,mod k有10种情况,那么总共就有100种情况了,那么我们就维护100个树状数组,每次只更新一个,查询的时候,就把相应的加起来就是了。所以我们一样用区间更新,单点取值的方法来做这题目,也就是get_sum(i)代表i位置上的值。那么就是tree[k][x%k][x],这里其实我们可以令q=(k-1)*10+x%k,也能分离出各种不同情况,只需要开一个的二维的就能代替原来的3维了。代码如下:
1 #include<stdio.h> 2 #include<memory.h> 3 #define MIX 50000 4 int n,num[MIX+1]; 5 int tree[100][MIX+1]; 6 void modify(int k,int i,int val) 7 { 8 while(i<=n) { 9 tree[k][i] += val; 10 i += i&-i; 11 } 12 } 13 14 int get_sum(int k,int i) 15 { 16 int sum = 0; 17 while(i) { 18 sum+=tree[k][i]; 19 i -= i&-i; 20 } 21 return sum; 22 } 23 24 int main(void) 25 { 26 int i,Q,a,b,k,c,p; 27 while(scanf("%d",&n)!=EOF) { 28 memset(tree,0,sizeof(tree)); 29 for(i=1; i<=n; ++i) scanf("%d",num+i); 30 scanf("%d",&Q); 31 while(Q--) { 32 scanf("%d",&p); 33 if(p==2) { 34 scanf("%d",&p); 35 a = num[p]; 36 for(i=1; i<=10; ++i) 37 a += get_sum((i-1)*10+p%i,p); 38 printf("%d\n",a); 39 } else { 40 scanf("%d%d%d%d",&a,&b,&k,&c); 41 b -= (b-a)%k; 42 modify((k-1)*10+a%k,a,c); 43 modify((k-1)*10+b%k,b+1,-c); 44 } 45 } 46 } 47 }
这个方法的空间占用比较大,因为每棵树都有大量的空间是浪费的,不会怎么用到的,按照k的范围,我们只开k课树,在第k课树更新的时候[a,b]的时候,会把区间里面不应该变化的点也变化了,如果我们把这些离散的点,映射成一个个连续的点就可以解决问题了,在第k课树,把1,1+k,1+2k...; 2,2+k,2+2k...等等都连续的放在一起,那么更新的时候就不会把不该更新的也更新了。那么映射就是1~1,1+k~2,1+2k~3...建立映射的代码如下:
1 void init() 2 { 3 int k,i,j,s; 4 for(k=1; k<=10; ++k) { 5 s = 1; 6 for(i=1;i<=k; ++i) { 7 for(j=i; j<=MIX; j+=k) 8 f[k][j] = s++; 9 } 10 } 11 }
注意,当应用了这个映射的方法之后,modify要上溯到MIX才行,因为原来的位置打乱了。
现在来总结一下树状数组
作用在统计求和,根据这个和代表的东西不同,灵活应用来解决问题。求和和更新的时间复杂度都是O(logn)。
通常有两种用法:1、单点更新,区间求和
此时tree[i]代表[i-i&-i,i]的和,get_sum(i)代表[1,i]的和
2、区间更新,单点求值 多个modify,把多加的减回去。
此时tree[i]表示[i-i&-i,i]对[i,MIX]的贡献,get_sum(i)代表i位置上的值
对于二维的树状数组,和一维的差不多,只不过是用于矩阵的求和。
———————————————————————以下是补充———————————————————————————————
树状数组也可以做到区间修改和区间查询。
给区间[l, r]同时加上x,令:
s(i) = 加上x之前的sum{a[1..i]}
s`(i) = 加上x之后的sum{a[1..i]}
那么,有:
where i < l → s`(i) = s(i)
where l ≤ i ≤ r →s`(i) = s(i) + x * (i - l + 1) = s(i) + x * i - x * (l - 1)
where r < i → s`(i) = s(i) + x * (r - l + 1)
令sum(bit, i)为树状数组bit的前 i 项和。构建两个数组bit0和bit1,并设:
sum{a[1..i]} = sum(bit1, i) * i + sum(bit0, i)
那么,要给[l, r]同时加上x,那么有:
在bit0的l位置加上-x(l-1)
在bit1的l位置加上x
在bit0的r+1位置加上xr
在bit1的r+1位置加上-x
便能在O(logn)实现对树状数组的更新和查询操作。
然后我们来看一道题,POJ3468 A Simple Problem with Integers。
题目大意:给n个数,q个询问,每次给一个区间加上同一个值,或者询问一个区间和。
然后就是这个树状数组的裸题咯,直接上代码吧。
虽然修改的时候,c*r不会爆int,但是读入a数组的时候,a*i会爆int请注意……
代码(1985MS)(为何POJ的G++会比C++慢一倍……):
1 #include <iostream> 2 #include <cstdio> 3 #include <cstring> 4 #include <algorithm> 5 using namespace std; 6 typedef long long LL; 7 8 const int MAXN = 100010; 9 10 LL bit0[MAXN], bit1[MAXN]; 11 int n, q; 12 13 inline int lowbit(int x) { 14 return x & -x; 15 } 16 17 void modify(LL *bit, int k, LL val) { 18 while(k <= n) { 19 bit[k] += val; 20 k += lowbit(k); 21 } 22 } 23 24 LL get_sum(LL *bit, int k) { 25 LL ret = 0; 26 while(k) { 27 ret += bit[k]; 28 k -= lowbit(k); 29 } 30 return ret; 31 } 32 33 void modify(int l, int r, LL val) { 34 modify(bit0, l, - val * (l - 1)); 35 modify(bit1, l, val); 36 modify(bit0, r + 1, val * r); 37 modify(bit1, r + 1, -val); 38 } 39 40 LL get_sum(int l, int r) { 41 LL sum1 = get_sum(bit1, l - 1) * (l - 1) + get_sum(bit0, l - 1); 42 LL sum2 = get_sum(bit1, r) * r + get_sum(bit0, r); 43 return sum2 - sum1; 44 } 45 46 int main() { 47 int l, r, a; 48 scanf("%d%d", &n, &q); 49 for(int i = 1; i <= n; ++i) 50 scanf("%d", &a), modify(i, i, a); 51 while(q--) { 52 char c; 53 scanf(" %c", &c); 54 if(c == 'C') { 55 scanf("%d%d%d", &l, &r, &a); 56 modify(l, r, a); 57 } else { 58 scanf("%d%d", &l, &r); 59 printf("%I64d\n", get_sum(l, r)); 60 } 61 } 62 }