树状数组学习笔记与总结
树状数组
OI Wiki
《信息学奥赛一本通》
例题
单点修改，区间查询
我的代码
点击查看代码
#include <bits/stdc++.h>
#define int long long
using namespace std;
// Capacity of every array: up to 10^6 elements plus a little slack.
const int N = 1000000 + 10;
// n = number of elements, q = number of queries (read in main).
int n, q;
// w: initial values, 1-indexed.
int w[N];
// tr: Fenwick tree storage, 1-indexed; tr[0] is unused.
int tr[N];
// Value of the lowest set bit of x (e.g. 12 -> 4), via two's-complement negation.
int lowbit(int x) {
    const int negated = ~x + 1;
    return x & negated;
}
// Point update: a[x] += k.
// Climbs from node x through each ancestor x + lowbit(x) while it stays within n.
// The tree is 1-indexed. For x > n the loop never runs, so the call is a no-op.
// x <= 0 is rejected explicitly: lowbit(0) == 0 would make the loop spin forever,
// and a negative x would write tr[x] out of bounds.
void add(int x, int k) {
    if (x <= 0)
        return;
    for (; x <= n; x += lowbit(x))
        tr[x] += k;
}
// Prefix sum a[1] + ... + a[x]. Returns 0 for x <= 0.
// Each step drops the lowest set bit of x, moving to the previous disjoint range.
int getsum(int x) {
    int total = 0;
    for (; x > 0; x -= lowbit(x))
        total += tr[x];
    return total;
}
// Range sum a[l] + ... + a[r], computed as the difference of two prefix sums.
int getsum(int l, int r) {
    const int through_r = getsum(r);
    const int before_l = getsum(l - 1);
    return through_r - before_l;
}
// Linear-time build of tr[] from w[1..n].
// When index i is visited, every child j < i with j + lowbit(j) == i has already
// pushed its subtotal into tr[i]. So tr[i] is complete once w[i] is added, and it
// can then be pushed into its own parent.
void init() {
    for (int i = 1; i <= n; ++i) {
        tr[i] += w[i];
        const int parent = i + lowbit(i);
        if (parent > n)
            continue;
        tr[parent] += tr[i];
    }
}
// Reads n, q and the n initial values, then answers q queries:
//   "1 x y"  -> add y to element x
//   any other op with "x y" -> print the sum of elements x..y
signed main() {
    scanf("%lld%lld", &n, &q);
    for (int i = 1; i <= n; ++i)
        scanf("%lld", &w[i]);
    init();
    for (int op, x, y; q--; ) {
        scanf("%lld%lld%lld", &op, &x, &y);
        if (op != 1)
            printf("%lld\n", getsum(x, y));
        else
            add(x, y);
    }
}
区间修改，单点查询
我的代码
点击查看代码
#include <bits/stdc++.h>
#define int long long
using namespace std;
// Capacity of every array: up to 10^6 elements plus a little slack.
const int N = 1000000 + 10;
// n = number of elements, q = number of queries (read in main).
int n, q;
// a: original values (1-indexed, a[0] stays 0).
int a[N];
// d: difference array, d[i] = a[i] - a[i - 1].
int d[N];
// tr: Fenwick tree over d[], 1-indexed.
int tr[N];
// Value of the lowest set bit of x. x & (x - 1) clears that bit, so XOR-ing
// the result back into x leaves only the bit itself.
int lowbit(int x) {
    return x ^ (x & (x - 1));
}
// Point update on the difference array: d[x] += k.
// Climbs from node x through each ancestor x + lowbit(x) while it stays within n.
// The tree is 1-indexed. For x > n the loop never runs (no-op), which is what
// lets the range update add(r + 1, -k) be harmless when r == n.
// x <= 0 is rejected explicitly: lowbit(0) == 0 would make the loop spin forever,
// and a negative x would write tr[x] out of bounds.
void add(int x, int k) {
    if (x <= 0)
        return;
    for (; x <= n; x += lowbit(x))
        tr[x] += k;
}
// Range update: a[l..r] += k, expressed as two point updates on the difference
// array: d[l] += k and d[r + 1] -= k.
void add(int l, int r, int k) {
    add(l, k);
    add(r + 1, -k);
}
// Prefix sum of the difference array d[1..x], which equals the current a[x].
// Returns 0 for x <= 0.
int getsum(int x) {
    int value = 0;
    while (x > 0) {
        value += tr[x];
        x -= lowbit(x);
    }
    return value;
}
// Linear-time build of tr[] from d[1..n]. Each node adds its own d[i], then
// forwards its finished subtotal to its parent i + lowbit(i), if that parent exists.
void init() {
    for (int i = 1; i <= n; ++i) {
        tr[i] += d[i];
        const int up = i + lowbit(i);
        if (up <= n)
            tr[up] += tr[i];
    }
}
// Reads n, q and the n initial values (building the difference array on the fly),
// then answers q queries:
//   "1 l r x" -> add x to every element in [l, r]
//   any other op with "l" -> print the current value of element l
signed main() {
    scanf("%lld%lld", &n, &q);
    for (int i = 1; i <= n; ++i) {
        scanf("%lld", &a[i]);
        d[i] = a[i] - a[i - 1];
    }
    init();
    int op, l, r, x;
    for (; q--; ) {
        scanf("%lld%lld", &op, &l);
        if (op != 1) {
            printf("%lld\n", getsum(l));
            continue;
        }
        scanf("%lld%lld", &r, &x);
        add(l, r, x);
    }
}
区间修改，区间查询
我的代码
点击查看代码
#include <bits/stdc++.h>
#define int long long
using namespace std;
// Capacity of every array: up to 10^6 elements plus a little slack.
const int N = 1000000 + 10;
// n = number of elements, q = number of queries (read in main).
int n, q;
// a: original values (1-indexed, a[0] stays 0).
int a[N];
// d: difference array, d[i] = a[i] - a[i - 1].
int d[N];
// t1: Fenwick tree over d[i].
int t1[N];
// t2: Fenwick tree over d[i] * i.
int t2[N];
// Value of the lowest set bit of x: x minus x-with-its-lowest-bit-cleared.
int lowbit(int x) {
    const int without_lowest = x & (x - 1);
    return x - without_lowest;
}
// Point update on the difference array: d[x] += k.
// This updates both trees: t1 accumulates k, and t2 accumulates the weighted term x * k.
// For x > n the loop never runs (no-op), which is what lets the range update
// add(r + 1, -k) be harmless when r == n.
// x <= 0 is rejected explicitly: lowbit(0) == 0 would make the loop spin forever,
// and a negative x would write t1[x] / t2[x] out of bounds.
void add(int x, int k) {
    if (x <= 0)
        return;
    const int v = x * k;  // weight uses the original index, so compute it before x moves
    for (; x <= n; x += lowbit(x)) {
        t1[x] += k;
        t2[x] += v;
    }
}
// Range update: a[l..r] += k, expressed as two point updates on the difference
// array: d[l] += k and d[r + 1] -= k.
void add(int l, int r, int k) {
    add(l, k);
    add(r + 1, -k);
}
// Prefix sum tr-tree[1..x] for the Fenwick array passed in. Returns 0 for x <= 0.
// For x > 0, x &= x - 1 removes the lowest set bit, i.e. x -= lowbit(x).
int getsum(int *tr, int x) {
    int acc = 0;
    while (x > 0) {
        acc += tr[x];
        x &= x - 1;
    }
    return acc;
}
// Prefix sum a[1] + ... + a[x] of the original array.
// Since a[i] = d[1] + ... + d[i]:
//   sum_{i<=x} a[i] = sum_{j<=x} d[j] * (x + 1 - j)
//                   = (x + 1) * sum d[j] - sum d[j] * j
// t1 supplies sum d[j]; t2 supplies sum d[j] * j.
int getsum(int x) {
    const int plain = getsum(t1, x);
    const int weighted = getsum(t2, x);
    return plain * (x + 1) - weighted;
}
// Range sum a[l] + ... + a[r], computed as the difference of two prefix sums.
int getsum(int l, int r) {
    const int through_r = getsum(r);
    const int before_l = getsum(l - 1);
    return through_r - before_l;
}
// Linear-time build of both trees from d[1..n]: t1 over d[i], t2 over d[i] * i.
// Each node first takes its own term, then forwards its finished subtotal to its
// parent i + lowbit(i), if that parent exists.
void init() {
    for (int i = 1; i <= n; ++i) {
        t1[i] += d[i];
        t2[i] += d[i] * i;
        const int p = i + lowbit(i);
        if (p > n)
            continue;
        t1[p] += t1[i];
        t2[p] += t2[i];
    }
}
// Reads n, q and the n initial values (building the difference array on the fly),
// then answers q queries, each starting with "op l r":
//   op == 1 -> read x and add x to every element in [l, r]
//   otherwise -> print the sum of elements l..r
signed main() {
    scanf("%lld%lld", &n, &q);
    for (int i = 1; i <= n; ++i) {
        scanf("%lld", &a[i]);
        d[i] = a[i] - a[i - 1];
    }
    init();
    for (int op, l, r, x; q--; ) {
        scanf("%lld%lld%lld", &op, &l, &r);
        if (op != 1) {
            printf("%lld\n", getsum(l, r));
        } else {
            scanf("%lld", &x);
            add(l, r, x);
        }
    }
}