二叉索引树(树状数组)
二叉索引树(Binary Indexed Tree),又叫树状数组,主要是用于解决动态连续和查询问题。
给定一个n个元素的数组A1,A2,....,An,你的任务是设计一个数据结构,支持以下两种操作。
- Add(x,d)操作:让Ax增加d。
- Query(L,R):计算AL+AL+1+...AR。
对于正整数X,我们 定义lowbit(x)为想的二进制表达式中最右边所对应的值。比如由38288的二进制是1001010110010000可知lowbit(38288) = 16。我们将从0到n-1这n个数取lowbit,将数值相同的节点放在同一行上就可以得到一棵BIT树,而且lowbit值越大,越靠近根。
根据整数补码表示的原理得到计算lowbit的方法是
int lowbit(int x) { return x & -x; }
对于每一个结点i,如果它是左子结点,那么它的父节点编号是i+lowbit(i);如果它是右子结点,那么它的父节点编号是i-lowbit(i)。这样我们可以构造一个辅助数组C,其中
Ci = Ai - lowbit(i) + 1 + Ai - lowbit(i) + 2 +...+ Ai
也即C的每个元素都是A数组中的一段连续和。
有了C数组之后怎么计算前缀和呢?
有了它和父结点的关系,我们可以顺着结点i往左走,将沿途的C累加起来即可。具体实现如下:
int sum(int x) { int res = 0; while(x > 0) { res += C[x]; x -= lowbit(x); } return res; }
对于第一种操作又怎么更新C数组中的元素呢?
同样有了它和父节点的关系,修改了一个Ai,我们从Ci开始往左走,沿途修改所有节点对应的Ci即可。具体实现如下:
int add(int x, int d) { while(x <= n) { C[x] += d; x += lowbit(x); } }
可以看出两个操作的时间复杂度均为O(logn)。预处理时将A和C数组清空,然后执行n次add操作,总时间复杂度为O(n*logn)。
直接看一道例题1:LA4329 Ping pong
题意是有n个人,每个人都有一个技能值,问能够举行多少场比赛,组成比赛的条件是三个人,两个选手,一个裁判,这个裁判必须住在两个选手的中间,而且技能值必须也在两者之间。
乍一看好像不是直接动态区间求和的题,我们仔细分析一下,设a1到ai-1有ci个小于ai,那么就有i - 1 - ci个大于ai的数字,同样道理设ai到an有di个比ai小,就有n-i-di个比ai大。然后根据乘法原理总的方法数等于ci * (n - i - di) + di * (i - 1 - ci)。现在问题的关键是怎么求解ci和di。
其实我们将ai作为一个下标,每次将这个位置上的值修改为1,然后查询Cai-1,也就是前i - 1项的和,就是前面有几个比ai小的数。类似的倒序遍历ai查询,得到bi。最后一个循环计算所有的比赛数。通过本题可以知道,树状数组用于快速计算前缀和,此处用于快速计算有多少个比某值小的数的个数,可以说是一种妙用吧。
1 #include <cstdio> 2 #include <cstring> 3 #include <algorithm> 4 #include <vector> 5 using namespace std; 6 7 const int maxn = 20000 + 10; 8 int n, a[maxn], c[maxn], d[maxn]; 9 10 struct BIT { 11 int n; 12 vector<int> C; 13 14 void resize(int n) { 15 this->n = n; 16 C.resize(n); 17 } 18 void clear() { 19 fill(C.begin(), C.end(), 0); 20 } 21 int lb(int k) { 22 return k & -k; 23 } 24 25 void add(int x, int d) { 26 while(x <= n) { 27 C[x] += d; 28 x += lb(x); 29 } 30 } 31 32 int sum(int x) { 33 int res = 0; 34 while(x > 0) { 35 res += C[x]; 36 x -= lb(x); 37 } 38 return res; 39 } 40 }f; 41 int main() 42 { 43 int T; 44 scanf("%d", &T); 45 while(T--) { 46 scanf("%d", &n); 47 int maxa = 0; 48 for(int i = 1; i <= n; i++) { 49 scanf("%d", &a[i]); 50 maxa = max(maxa, a[i]); 51 } 52 53 f.resize(maxa); 54 f.clear(); 55 for(int i = 1; i <= n; i++) { 56 f.add(a[i], 1); 57 c[i] = f.sum(a[i] - 1); 58 } 59 f.clear(); 60 for(int i = n; i >= 1; i--) { 61 f.add(a[i], 1); 62 d[i] = f.sum(a[i] - 1); 63 } 64 65 long long ans = 0; 66 for(int i = 1; i <= n; i++) { 67 ans += (long long)c[i] * (n - i - d[i]) + (long long)d[i] * (i - 1 - c[i]); 68 } 69 printf("%lld\n", ans); 70 } 71 return 0; 72 }
例题2:HDU 1166 敌兵布阵
题意很简单,动态查询连续和。
1 #include <cstdio> 2 #include <cstring> 3 4 const int maxn = 50000 + 10; 5 int n, s[maxn]; 6 int lb(int k) { 7 return k & (-k); 8 } 9 10 void add(int k, int v) { 11 while(k <= n) { 12 s[k] += v; 13 k += lb(k); 14 } 15 } 16 17 int sum(int k) { 18 int res = 0; 19 while(k > 0) { 20 res += s[k]; 21 k -= lb(k); 22 } 23 return res; 24 } 25 26 int main() 27 { 28 int T; 29 char op[10]; 30 int ca = 1; 31 scanf("%d", &T); 32 while(T--) { 33 scanf("%d", &n); 34 memset(s, 0, sizeof(s)); 35 for(int i = 1; i <= n; i++) { 36 int j; 37 scanf("%d", &j); 38 add(i, j); 39 } 40 printf("Case %d:\n", ca++); 41 while(scanf("%s", op), op[0] != 'E') { 42 int i, j; 43 scanf("%d%d", &i, &j); 44 if(op[0] == 'Q') { 45 printf("%d\n", sum(j) - sum(i - 1)); 46 }else if(op[0] == 'A') { 47 add(i, j); 48 }else if(op[0] == 'S') { 49 add(i, -j); 50 } 51 } 52 } 53 return 0; 54 }
例题3:HDU 1394 Minimum Inversion Number
给出一个n的全排列,问通过循环移动组成的n个序列中形成的逆序对数最小是多少。
先看怎么计算一个序列的逆序对数,从序列的第一个数开始,假如该数字是3,n=10,易得该数是第7大的数,我们使用树状数组查询之前加入的(虽然之前加入了0个数)数的前7项和,就是3的逆序对数。不要忘记的是查询后在树状数组中标记。这样依次计算后可以的到整个序列的逆序对数。
如果对每一个序列都这么做,时间复杂度是O(n2logn),时间是不允许的。由前面的计算我们知道了第一个序列的逆序对数,思考能不能从第一个逆序对数中得出后面序列的逆序对数。
答案是可以的。第二组序列可以看成是第一个数字放到了数列的末尾,首先我们可以按照之前的方法,计算出前n-num[i]大的和,就是新增加的逆序对数,那减少了多少呢?
其实就是移动到末尾的数字的大小。
1 #include <cstdio> 2 #include <cstring> 3 4 const int maxn = 5000 + 10; 5 int n, s[maxn], num[maxn]; 6 7 int lb(int k) { 8 return k & (-k); 9 } 10 11 void add(int k, int v) { 12 while(k <= n) { 13 s[k] += v; 14 k += lb(k); 15 } 16 } 17 18 int sum(int k) { 19 int res = 0; 20 while(k > 0) { 21 res += s[k]; 22 k -= lb(k); 23 } 24 return res; 25 } 26 27 int main() 28 { 29 while(scanf("%d", &n) != EOF) { 30 memset(s, 0, sizeof(s)); 31 for(int i = 0; i < n; i++) { 32 scanf("%d", &num[i]); 33 } 34 35 int tm = 0; 36 for(int i = 0; i < n; i++) { 37 tm += sum(n - num[i]); 38 add(n - num[i], 1); 39 } 40 int ans = tm; 41 42 for(int i = 0; i < n - 1; i++) { 43 int tp = tm; 44 add(n - num[i], -1); 45 tp += sum(n - num[i]) - num[i]; 46 add(n - num[i], 1); 47 if(ans > tp) 48 ans = tp; 49 tm = tp; 50 } 51 printf("%d\n", ans); 52 } 53 return 0; 54 }