树状数组(Binary Indexed Tree)
记录
17:30 2024-2-4
1.树状数组
树状数组(Binary Indexed Tree)可以完成以下操作
- 查询前缀和
- 增加单个元素a[i]的值
lowbit运算
lowbit(n) 定义为 非负整数n在二进制表示下 “最低位的1及其后边所有的0”
lowbit(n) = n & (~n+1) = n & (-n)
比如 10(1010) lowbit(10) = 2(0010)
记得补码的公式吗 找到最后一个不为0的数 保证这位数不变 前面的位取反
1111 0010 的补码为 0000 1110
这公式(lowbit(n))成立的原因就是因为 数 与 他的补码 求并后 剩下的就是最后一个不为0的数了
点击查看代码
// [1, n]
int bit[MAX_N + 1], n;
//查询前缀和
int sum(int i) {
int s = 0;
while (i > 0) {
s += bit[i];
i -= i & -i;
}
return s;
}
//单点增加
void add(int i, int x) {
while (i <= n) {
bit[i] += x;
i += i & -i;
}
}
1.区间增加 + 单点查询
树状数组仅仅支持单点增加 和 区间查询
将区间增加 + 单点查询 转化为 单点增加 + 区间查询
利用一个数组b,用来单点增加,并且利用树状数组记录前缀和,前缀和b[1~i] + a[i] 就是经过操作后的值
将”C l r d" 转化为
- 把b[l]加上d
- 把b[r+1]减去d (区间增加 -> 单点增加)
查询的话变为 b[1~i] + a[i] (单点查询 -> 区间查询)
点击查看代码
// N个数据 Q个操作
int N, Q;
int b[MAX_N + 1], a[MAX_N + 1];
// 操作是什么 区间增加 或 查询
char T[MAX_Q];
// L R 表示区间 X 表示区间增加时候的值
int L[MAX_Q], R[MAX_Q], X[MAX_Q];
int ask(int i) {
int s = 0;
while (i > 0) {
s += b[i];
i -= i & -i;
}
return s;
}
void add(int i, int x) {
while (i <= n) {
b[i] += x;
i += i & -i;
}
}
void solve() {
for (int i = 0; i < Q; i++) {
if (T[i] == 'C' ) {
add(L[i], X[i]);
add(R[i] + 1, -X[i]);
} else {
printf("%lld\n" , ask(i) + a[i]);
}
}
}
2.区间增加 + 区间查询
把区间增加转变为单点增加,利用两个树状数组\(c_0 和 c_1\)
将”C l r d" 转化为
- 在树状数组\(c_0\)中,把位置l上的数加d
- 在树状数组\(c_0\)中,把位置r + 1上的数减d
- 在树状数组\(c_1\)中,把位置l上的数加l * d
- 在树状数组\(c_1\)中,把位置r + 1上的数减(r + 1) * d
建立sum存储a的原始前缀和
将“Q l r” 转化为 1~r 和 1~l-1两部分进行相减
$ (sum[r] + (r + 1) * ask(c_0, r) - ask(c_1, r)) - (sum[l - 1] + l * ask(c_0, l - 1) - ask(c_1, l - 1)) $
点击查看代码
#include<iostream>
#include<cstdio>
using namespace std;
typedef long long ll;
const int MAXN = 100005;
int N, Q;
ll a[MAXN], sum[MAXN];
ll c[2][MAXN];
// k表示是哪个树状数组,i表示位置, v表示加入的值
void add(int k, int i, int v) {
while (i <= N){
c[k][i] += v;
i += i & -i;
}
}
// k表示是哪个树状数组,i表示位置
ll ask(int k, int i) {
ll s = 0;
while (i > 0) {
s += c[k][i];
i -= i & -i;
}
return s;
}
int main() {
cin >> N >> Q;
for(int i = 1; i <= N; i++) {
scanf("%lld", &a[i]);
sum[i] = sum[i - 1] + a[i];
}
char c[2];
int l, r, v;
for(int i = 0; i < Q; i++) {
// %s 它会读入一个不含空格、TAB和回车符的字符串,存入字符数组
// %c 会读入\n
scanf("%s%d%d", c, &l, &r);
if(c[0] == 'C') {
scanf("%d", &v);
add(0, l, v);
add(0, r + 1, -v);
add(1, l, l * v);
add(1, r + 1, -(r + 1) * v);
} else {
ll result = (sum[r] + (r + 1) * ask(0, r) - ask(1, r))
- (sum[l - 1] + l * ask(0, l - 1) - ask(1, l - 1));
printf("%lld\n", result);
}
}
}