树状数组的扩展应用
「观前提醒」
「文章仅供学习和参考,如有问题请在评论区提出」
这里主要讲树状数组的各种扩展应用,至于树状数组的具体实现原理可以看下面的博客。
O(N) 建树#
对于树状数组最基本的建树方式,就是每个点加值。
时间复杂度:
代码实现
int tr[N]; // tr[] 存储树状数组数据
int a[N]; // a[] 存储原数组数据
int n; // 数列长度
int lowbit(int x) { return x & -x; }
void add(int x, c) {
for (int i = x; i <= n; x += lowbit(x))
tr[i] += c;
}
// 建树
void build() {
for (int i = 1; i <= n; i++)
add(i, a[i]);
}
对于
方法一#
我们知道对于树状数组
代码实现
int tr[N]; // 树状数组数据
int a[N]; // 原数组数据
int sum[N]; // sum[] 存储 a[] 的前缀和
int n; // 数列长度
int lowbit(int x) { return x & -x; }
// 建树
void build() {
// 求 a[] 的前缀和 sum[]
for (int i = 1; i <= n; i++)
sum[i] = sum[i - 1] + a[i];
// 利用前缀和求出区间和,O(N)建树
for (int i = 1; i <= n; i++)
tr[i] = sum[i] - sum[i - lowbit(i)];
}
方法二#

观察上图我们发现,对于
我们还知道,对于
用这种方式,同样也可以实现
代码实现
int tr[N]; // 树状数组数据
int a[N]; // 原数组数据
int n; // 数列长度
int lowbit(int x) { return x & -x; }
// 建树
void build() {
for (int i = 1; i <= n; i++) {
tr[i] += a[i];
int fa = i + lowbit(i); // 获得父节点下标
if (fa <= n) // 判断父节点是否超出数列范围
tr[fa] += tr[i];
}
}
维护区间和#
单点修改,区间查询#
给定一个长度为
的数列,要对数列进行 次以下两种操作:
1 x y
:将位置的数加上 (或者减去 、变成 、乘以 )。 2 x y
:查询区间的和。
这是树状数组最基本的用法。
时间复杂度
- 单点修改
- 区间查询
代码实现
int tr[N];
int a[N];
int n;
int lowbit(int x) { return x & -x; }
// 给 x 位置的数加上 c
void add(int x, int c) {
for (int i = x; i <= n; i += lowbit(i))
tr[i] += c;
}
// 查询 1 ~ x 的区间和
void query(int x) {
int res = 0;
for (int i = x; i; i -= lowbit(i))
res += tr[i];
return res;
}
// 使用
add(x, c); // 给 x 位置的数加上 c
add(x, y - (query(x) - query(x - 1))); // 讲 x 位置的数改为 y
int val1 = query(x); // 查询 [1, x] 的区间和
int val2 = query(r) - query(l - 1); // 查询 [l, r] 的区间和
int val3 = query(x) - query(x - 1); // 查询 x 位置的值
区间修改,单点查询#
给定一个长度为
的数列,要对数列进行 次以下两种操作:
1 x y k
:将区间里的数都加上 (或者都减去 )。 2 x
:查询位置的值
这里我们需要用到差分,从而利用树状数组来维护差分数组。
- 区间修改:
add(l, k), add(r + 1, -k);
- 单点查询:
query(y) - query(x - 1);
时间复杂度
- 区间修改:
- 单点查询:
代码实现
int tr[N];
int a[N];
int n;
int lowbit(int x) { return x & -x; }
// 给 x 位置的数加上 c
void add(int x, int c) {
for (int i = x; i <= n; i += lowbit(i)) tr[i] += c;
}
// 查询 1 ~ x 的区间和
void query(int x) {
int res = 0;
for (int i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}
// 使用
add(r, c), add(l - 1, c); // 讲区间 [l, r] 都加上 c
int val = query(x) + a[x]; // 查询 x 位置的值
区间修改,区间查询#
给定一个长度为
的数列,要对数列进行 次以下两种操作:
1 x y k
:将区间里的数都加上 (或者都减去 )。 2 x y
:查询的区间和。
平时遇到这种问题,我们一般都会选择用线段树来解决,但是树状数组也能实现。
这里我们首先想到要用差分数组来实现,但是怎么才能查询区间和呢?
对于数列
如果我们对所列出的式子进行补充,变成一个矩阵,如下图所示。

如果我们根据列进行求和,那么前缀和的表示公式就能变形为,
这样我们就能把问题转化成维护
- 区间查询:获取前缀和,直接根据公式计算。
- 时间复杂度:
- 时间复杂度:
- 区间修改:分别对
和 所维护的前缀和做出相应的修改。- 时间复杂度:
- 对于
,执行add(x, k), add(y + 1, -k);
- 对于
,执行add(x, x * k), add(y + 1, (y + 1) * k);
- 时间复杂度:
代码实现
#define int long long
int tr1[N]; // 维护 b[i] 的前缀和
int tr2[N]; // 维护 i * b[i] 的前缀和
int a[N]; // 原数组
int n;
int lowbit(int x) { return x & -x; }
// 对树状数组 tr[] 执行加和操作
void add(int tr[], int x, int c) {
for (int i = x; i <= n; i += lowbit(i)) tr[i] += c;
}
// 对树状数组 tr[] 执行查询前缀和的操作
int query(int tr[], int x) {
int res = 0;
for (int i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}
// 建树
void build() {
for (int i = 1; i <= n; i++) {
int b = a[i] - a[i - 1]; // 差分 b[i]
add(tr1, i, b);
add(tr2, i, i * b);
}
}
// 查询数列的前缀和
int pre_sum(int x) {
return query(tr1, x) * (x + 1) - query(tr2, x);
}
// 执行操作
// 建树(初始化)
build();
// 区间查询
int val = pre_sum(y) - pre_sum(x - 1); // [x, y] 的区间和
// 区间修改
add(tr1, x, k), add(tr1, y + 1, -k); // 修改 tr1[]
add(tr2, x, x * k), add(tr2, y + 1, (y + 1) * -k); // 修改 tr2[]
整合的维护区间和的完成代码,支持区间修改和区间查询(函数封装好)
#include <bits/stdc++.h> using namespace std; const int N = 1e5 + 10; #define int long long int tr1[N], tr2[N]; int a[N], n; int lowbit(int x) { return x & -x; } void add(int tr[], int x, int c) { for (int i = x; i <= n; i += lowbit(i)) tr[i] += c; } int query(int tr[], int x) { int res = 0; for (int i = x; i; i -= lowbit(i)) res += tr[i]; return res; } // 建树 void build() { for (int i = 1; i <= n; i++) { int b = a[i] - a[i - 1]; add(tr1, i, b); add(tr2, i, i * b); } } // 查询 [l, r] 的区间和 int sum(int l, int r) { int sum1 = query(tr1, r) * (r + 1) - query(tr2, r); int sum2 = query(tr1, l - 1) * l - query(tr2, l - 1); return sum1 - sum2; } // 将 [l, r] 里的数加 k void add(int l, int r, int k) { add(tr1, l, k), add(tr1, r + 1, -k); add(tr2, l, l * k), add(tr2, r + 1, (r + 1) * -k); } signed main() { int q; cin >> n >> q; for (int i = 1; i <= n; i++) scanf("%lld", &a[i]); build(); while (q--) { char op[2]; int l, r, k; scanf("%s", op); if (op[0] == 2) { scanf("%lld%lld", &l, &r); printf("%lld\n", sum(l, r)); } else { scanf("%lld%lld%lld", &l, &r, &k); add(l, r, k); } } return 0; }
维护二维子矩阵和(二维树状数组)#
单点修改,子矩阵查询#
给定一个
的矩阵 ,要对矩阵进行 次以下两种操作:
1 x y k
:将元素加上 (或者都减去 )。 2 a b c d
:查询左上角为,右上角为 的子矩阵内所有数的和。
二维树状数组就是树状数组套树状数组。就是在原先一维树状数组的基础上,用此树状数组的节点再来建立树状数组,从而实现维护矩阵和的功能。
我们思考树状数组的修改逻辑,就是当某一个节点被修改时,有多少的节点会被影响到,然后再修改这些被影响的节点。所以对于矩阵
一维树状数组的修改是
所以修改操作的时间复杂度为
而对于二维前缀和的初始化,有 sum[i][j] = sum[i - 1][j] + sum[i][j - 1] - sum[i - 1][j - 1] + a[i][j];
(不做具体解释,不会的可以先学一学,下面的也一样)。
同理,对于查询操作,我们知道通过二维前缀和来求子矩阵的式子为,Sum = sum[x2][y2] - sum[x1 - 1][y2] - sum[x2][y1 - 1] + sum[x1 - 1][y1 - 1];
。
那么只需要获取它们维护的前缀和,然后根据公式计算出结果,时间复杂度也为
那么这就是二维树状数组的基本逻辑,从而实现维护矩阵和的功能。
时间复杂度
-
初始化:
-
单点修改:
-
子矩阵查询:
代码实现
#define int long long
int tr[N][N]; // 二维树状数组
int a[N][N]; // 原数组
int n, m; // 行高和列宽
int lowbit(int x) { return x & -x; }
// 给 (x, y) 位置的数加上 c
void add(int x, int y, int c) {
for (int i = x; i <= n; i += lowbit(i))
for (int j = y; j <= m; j += lowbit(j))
tr[i][j] += c;
}
// 查询 (x, y) 位置的二维前缀和
int query(int x, int y) {
int res = 0;
for (int i = x; i; i -= lowbit(i))
for (int j = y; j; j -= lowbit(j))
res += tr[i][j];
return res;
}
// // 查询左上角为(x1, y1), 右下角为(x2, y2) 的子矩阵的和
int query(int x1, int y1, int x2, int y2) {
return query(x2, y2)
- query(x1 - 1, y2)
- query(x2, y1 - 1)
+ query(x1 - 1, y1 - 1);
}
// 使用
add(x, y, c); // 给 (x, y) 位置的数加上 c
add(x, y, -c); // 给 (x, y) 位置的数减去 c
int sum1 = query(x, y); // 查询左上角为(1, 1), 右下角为(x, y) 的子矩阵的和
int sum2 = query(a, b, c, d); // 查询左上角为(a, b), 右下角为(c, d) 的子矩阵的和
子矩阵修改,单点查询#
给定一个
的矩阵 ,要对矩阵进行 次以下两种操作:
1 a b c d k
:将左上角为,右上角为 的子矩阵里的每个元素都加上 (或者都减去 )。 2 x y
:询问元素的值。
和上面进行区间修改,单点查询的相同,这个是用一维树状数组来维护一维差分数组。那么同理,我们也可以用二维树状数组来维护二维差分数组。
对于二维差分数组,我们每次的矩阵修改操作为,b[x1][y1] += c, b[x2 + 1, y1] -= c, b[x1, y2 + 1] -= c, b[x2 + 1][y2 + 1] += c;
,每次的单点查询操作就是求一次二维前缀和。
时间复杂度
- 子矩阵修改:
- 单点查询:
代码实现
#define int long long
int tr[N][N]; // 二维树状数组
int a[N][N]; // 原数组
int n, m; // 行高和列宽
int lowbit(int x) { return x & -x; }
void add(int x, int y, int c) {
for (int i = x; i <= n; i += lowbit(i))
for (int j = y; j <= m; j += lowbit(j))
tr[i][j] += c;
}
void query(int x, int y) {
int res = 0;
for (int i = x; i; i -= lowbit(i))
for (int j = y; j; j -= lowbit(j))
res += tr[i][j];
return res;
}
// 将左上角为 (x1, y1), 右下角为 (x2, y2) 的子矩阵的每个元素都加上 c
void add(int x1, int y1, int x2, int y2, int c) {
add(x1, y1, c);
add(x2 + 1, y1, -c);
add(x1, y2 + 1, -c);
add(x2 + 1, y2 + 1, c);
}
// 使用
add(x1, y1, x2, y2, c); // 将左上角为 (x1, y1), 右下角为 (x2, y2) 的子矩阵的每个元素都加上 c
int val = query(x, y) + a[x][y]; // 查询 (x, y) 位置的元素值
子矩阵修改,子矩阵查询#
给定一个
的矩阵 ,要对矩阵进行 次以下两种操作:
1 a b c d k
:将左上角为,右上角为 的子矩阵里的每个元素都加上 (或者都减去 )。 2 a b c d
:查询左上角为,右上角为 的子矩阵内所有数的和。
我们可以像上面处理一维区间和那样思考,通过维护二维前缀和数组来解决问题。
具体思路和推导过程就不赘述了,要想了解的可以看这篇博客:数据结构学习笔记-二维树状数组 - 知乎
具体想法是用四个二维树状数组来分别维护
然后通过推导出来的公式来计算前缀和,
代码实现
#define int long long
int a[N][N], b[N][N], c[N][N], d[N][N]; // 二维树状数组
int n, m;
int lowbit(int x) { return x & -x; }
void add(int x, int y, int v) {
for (int i = x; i <= n; i += lowbit(i)) {
for (int j = y; j <= m; j += lowbit(j)) {
a[i][j] += v;
b[i][j] += (x - 1) * v;
c[i][j] += (y - 1) * v;
d[i][j] += (x - 1) * (y - 1) * v;
}
}
}
int query(int x, int y) {
int res = 0;
for (int i = x; i; i -= lowbit(i)) {
for (int j = y; j; j -= lowbit(j)) {
res += x * y * a[i][j]
- y * b[i][j]
- x * c[i][j]
+ d[i][j];
}
}
return res;
}
// 将左上角为 (x1, y1), 右上角 (x2, y2) 的子矩阵的所有元素加上 c
void add(int x1, int y1, int x2, int y2, int c) {
add(x1, y1, c);
add(x1, y2 + 1, -c);
add(x2 + 1, y1, -c);
add(x2 + 1, y2 + 1, c);
}
// 查询左上角为 (x1, y1), 右上角 (x2, y2) 的子矩阵的元素和
int query(int x1, int y1, int x2, int y2) {
return query(x2, y2)
- query(x1 - 1, y2)
- query(x2, y1 - 1)
+ query(x1 - 1, y1 - 1);
}
// 使用
add(x1, y1, x2, y2, c); // 将左上角为 (x1, y1), 右上角 (x2, y2) 的子矩阵的所有元素加上 c
int sum = query(x1, y1, x2, y2);// 查询左上角为 (x1, y1), 右上角 (x2, y2) 的子矩阵的元素和
求逆序对个数#
给定一个长度为
的数列,求其中逆序对的个数。 逆序对:对于
,有 。
归并排序是可以求一个数列中逆序对的个数的,时间复杂度为
对于逆序对个数的求解,树状数组是通过求每个
从







这样我们遍历
还有,这样的做法是把
如果这样的话,树状数组的时间和空间消耗相对于归并排序都会更多点(虽然总的时空复杂度是相同的)。其实这样就体现出了归并排序求逆序对的好处,它并不用考虑
代码实现
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1e5 + 10;
int tr[N];
int a[N];
int n;
int lowbit(int x) { return x & -x; }
void add(int x, int c) {
for (int i = x; i <= n; i += lowbit(i)) tr[i] += c;
}
int query(int x) {
int res = 0;
for (int i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}
int main() {
cin >> n;
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
LL res = 0;
for (int i = 1; i <= n; i++) {
add(a[i], 1);
// 求逆序对的个数
res += i - query(a[i]);
}
cout << res << "\n";
return 0;
}
需要离散化操作的代码
#include <bits/stdc++.h> using namespace std; typedef long long LL; const int N = 1e5 + 10; int tr[N]; // 树状数组 LL a[N]; // 原数组 int L[N]; // 离散化后的序列 int n; int lowbit(int x) { return x & -x; } void add(int x, int c) { for (int i = x; i <= n; i += lowbit(i)) tr[i] += c; } int query(int x) { int res = 0; for (int i = x; i; i -= lowbit(i)) res += tr[i]; return res; } // 离散化,时间复杂度 O(NlogN) void Unique() { vector<LL> t; for (int i = 1; i <= n; i++) t.push_back(a[i]); sort(t.begin(), t.end()); t.erase(unique(t.begin(), t.end()), t.end()); for (int i = 1; i <= n; i++) L[i] = lower_bound(t.begin(), t.end(), a[i]) - t.begin() + 1; } int main() { cin >> n; for (int i = 1; i <= n; i++) scanf("%lld", &a[i]); // 离散化 Unique(); LL res = 0; for (int i = 1; i <= n; i++) { add(L[i], 1); res += i - query(L[i]); } cout << res << "\n"; return 0; }
求数列中小于 x 的元素个数#
根据上面求逆序对的思路,我们可以求出数列中小于(大于、小于或等于、大于或等于)
同样的,如果数列中有负数或者数很大,就还得需要
这里注意,这种方法只支持离线查询,预处理的时间复杂度为
代码实现
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10;
int tr[N];
int a[N];
int n;
int lowbit(int x) { return x & -x; }
void add(int x, int c) {
for (int i = x; i <= n; i += lowbit(i)) tr[i] += c;
}
int query(int x) {
int res = 0;
for (int i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}
int main() {
cin >> n;
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
// 预处理
for (int i = 1; i <= n; i++)
add(a[i], 1);
// 查询
int x;
cin >> x;
int num1 = query(x - 1); // 查询小于 x 的元素个数
int num2 = query(x); // 查询小于等于 x 的元素个数
int num3 = n - query(x); // 查询大于 x 的元素个数
int num4 = n - query(x - 1);// 查询大于等于 x 的元素个数
return 0;
}
需要离散化操作的代码
#include <bits/stdc++.h> using namespace std; typedef long long LL; const int N = 1e5 + 10; int tr[N]; // 树状数组 LL a[N]; // 原数组 int L[N]; // 离散化后的数列 int n; int lowbit(int x) { return x & -x; } void add(int x, int c) { for (int i = x; i <= n; i += lowbit(i)) tr[i] += c; } int query(int x) { int res = 0; for (int i = x; i; i -= lowbit(i)) res += tr[i]; return res; } // 离散化 void Unique() { vector<LL> t; for (int i = 1; i <= n; i++) t.push_back(a[i]); sort(t.begin(), t.end()); t.erase(unique(t.begin(), t.end()), t.end()); for (int i = 1; i <= n; i++) L[i] = lower_bound(t.begin(), t.end(), a[i]) - t.begin() + 1; } int main() { cin >> n; for (int i = 1; i <= n; i++) scanf("%lld", &a[i]); // 离散化 Unique(); // 预处理 for (int i = 1; i <= n; i++) add(L[i], 1); // 查询 int x; cin >> x; int num1 = query(x - 1); // 查询小于 x 的元素个数 int num2 = query(x); // 查询小于等于 x 的元素个数 int num3 = n - query(x); // 查询大于 x 的元素个数 int num4 = n - query(x - 1);// 查询大于等于 x 的元素个数 return 0; }
参考资料#
树状数组 - OI Wiki:https://oi-wiki.org/ds/fenwick/
树状数组O(n)建树 荼白777的博客-CSDN博客:https://blog.csdn.net/weixin_45724872/article/details/120110911
算法学习笔记(2) : 树状数组 - 知乎:https://zhuanlan.zhihu.com/p/93795692
数据结构学习笔记-二维树状数组 - 知乎:https://zhuanlan.zhihu.com/p/571255016
作者:Oneway
出处:https://www.cnblogs.com/oneway10101/p/17587242.html
版权:本作品采用「署名-非商业性使用-相同方式共享 4.0 国际」许可协议进行许可。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· winform 绘制太阳,地球,月球 运作规律
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人