树状数组学习笔记
虽然好像是复习???算是重新再学一遍好了
数据结构原理与实现
其实树状数组感觉挺妙的,我想好久硬是想不出当时想出来的人怎么想出来的。后来一想感觉其实有点像二进制优化,通过二进制下连续 \(0\) 的长度来表示储存的信息 (说了等于没说) ,所以树状数组维护的信息必须符合一定的结合律。然后就没有然后了。
复杂度证明
因为一个数 \(n\) 的二进制有 \(log(n)\) 位,所以复杂度是 \(nlog(n)\)
题目
「模板」树状数组1
板子题,树状数组维护前缀和,然后区间查询的时候左边界和右边界减一下就行
#include<bits/stdc++.h>
using namespace std;
#define R register int
const int MAXN = 500000 + 10;
int c[MAXN], n, m;
int inline lowbit(int x) {
return x & -x;
}
void inline update(int x,int k){
while(x <= n){
c[x] += k;
x += lowbit(x);
}
return;
}
int inline sum(int x) {
int sun = 0;
while(x > 0) {
sun += c[x];
x -= lowbit(x);
}
return sun;
}
int main() {
scanf("%d%d", &n, &m);
for(R i = 1; i <= n; i ++) {
int x; scanf("%d", &x);
update(i, x);
}
for(R i = 1; i <= m; i ++){
int x, y ,tag;
scanf("%d%d%d", &tag, &x, &y);
if(tag == 1) update(x,y);
else printf("%d\n", sum(y) - sum(x - 1));
}
return 0;
}
「模板」树状数组2
板子题,利用差分的思想来进行区间修改
#include<bits/stdc++.h>
using namespace std;
#define R register int
const int MAXN = 5000000 + 10;
int a[MAXN], c[MAXN], p, n, m;
inline int lowbit(int t) {return t&(-t);}
inline void update(int i, int x) {
while(i <= n){
c[i] += x;
i += lowbit(i);
}
return;
}
inline long long sum(int x) {
long long ans=0;
while(x > 0){
ans += c[x];
x -= lowbit(x);
}
printf("%lld\n",ans);
}
int main()
{
scanf("%d%d",&n,&m);
for(R i = 1; i <= n; i ++){
scanf("%d", &a[i]);
update(i, a[i] - a[i - 1]);
}
for(R i = 1; i <= m; i ++){
scanf("%d", &p);
if(p == 1){
int x, y, k;
scanf("%d%d%d", &x, &y, &k);
update(x, k); update(y + 1, -k );
}
else{
long long x;
scanf("%lld", &x);
sum(x);
}
}
return 0;
}
[逆序对](https://www.luogu.com.cn/problem/P1908)
逆序对是树状数组的经典求解对象,先对一维进行排序,对另一维直接一个个加入树状数组,然后统计前面出现了多少个数即可
#include <bits/stdc++.h>
using namespace std;
struct node {
long long coin;
int num;
}a[500000<<1];
long long c[500000<<1], b[500000<<1], n;
long long lowbit (long long x) {return x & -x;}
bool cmp (node x, node y) {
if(x.coin==y.coin) return x.num>y.num;
return x.coin>y.coin;
}
long long sum (long long i) {
long long ans = 0;
while (i > 0) {
ans += c[i];
i -= lowbit(i);
}
return ans;
}
void update (long long i, long long x) {
while (i <= n) {
c[i] += x;
i += lowbit(i);
}
}
int main () {
long long ans=0; scanf("%lld", &n);
for(register int i = 1 ; i <= n ; i ++) {
scanf("%lld", &a[i].coin); a[i].num = i;
}
sort(a + 1, a + n + 1, cmp);
for(register int i = 1 ; i <= n ; i ++) b[a[i].num] = i;
for(register int i = 1 ; i <= n ; i ++) {
update (b[i], 1); ans += sum (b[i] - 1);
}
printf("%lld\n",ans);
return 0;
}