树状数组Ⅰ
树状数组很有用,虽然是不如线段树那么强大,但是它写起来简单得不得了,而且它的复杂度是赤裸裸的\(O(\log n)\),没有讨厌的常数!
盗张图来说明树状数组的工作原理(图片来自oi-wiki):
树状数组的功能为单点修改区间查询,或者区间修改单点查询。
懒得废话上模板:
-
单点修改区间查询:
struct tree{
#define lowbit(wh) (wh&-wh)
int c[N],m;
void change(int wh,int num){
for(;wh<=m;wh+=lowbit(wh))c[wh]+=num;
}
int work_sum(int wh){
int ans=0;
for(;wh;wh-=lowbit(wh))ans+=c[wh];
return ans;
}
};
其中change函数可以做到在\(O(\log n)\)的复杂度下给第wh个元素加上num,work_sum函数可以做到在\(O(\log n)\)的复杂度下求出前wh个元素的和(包括第wh个元素。)
-
区间修改单点查询(与差分结合):
struct tree{
#define lowbit(wh) (wh&-wh)
int c[N],m;
void change1(int wh,int num){
for(;wh<=m;wh+=lowbit(wh))c[wh]+=num;
}
void change2(int l,int r,int num){
change1(l,num);
change1(r+1,-num);
}
int work(int wh){
int ans=0;
for(;wh;wh-=lowbit(wh))ans+=c[wh];
return ans;
}
};
其中change2函数可以做到\(O(\log n)\)的复杂度下给第l个和第r个元素之间(包含两端)的所有元素加上num,而work函数可以在\(O(\log n)\)的复杂度下求出第wh个元素的值。
树状数组的作用
-
快速求区间和,不说了都是模板。
-
求逆元或者变相逆元,因为是树状数组可以快速求出比当前数小(或者大)的数的个数,也可以在\(O(n\log n)\)的复杂度下求逆元以及一些更加复杂的问题。
比如 三元上升子序列,归并排序就无法解决了,就要用到树状数组。
思路就是暴力,每处理一个数就扫描前面的数,遇到一个比它小的数就统计比这个数小的数的个数,而这个“比这个数小的个数”可以用树状数组维护,所以复杂度就是\(O(n^2\log^2 n)\)。
它依然可以优化,因为“比这个数小的数”所有的答案区间会有很严重的重叠,所以可以另开一个树状数组r,它动态更新,维护当前(假如已经处理到第now个元素了)满足\(a_i<a_j<a_{now}\)且\(i<j<now\)的个数。
然后复杂度就可以降到\(O(n\log^2 n)\),(因为每个元素要查找离散化之后的位置有个log的二分查找复杂度)
代码(注意离散化):
#include<cstdio>
#include<algorithm>
#include<cstring>
#define int long long
#define sc(wh) scanf("%lld",&wh)
using namespace std;
const int N=30010;
int m,a[N],b[N],c[N];
int find(int wh){
int l=1,r=m,mid,an;
while(l<=r){
mid=l+r>>1;
if(a[mid]<wh)l=mid+1;
else an=mid,r=mid-1;
}
return an;
}
struct tree{
#define lowbit(wh) (wh&-wh)
int d[N],num;
void change(int wh,int val){
for(;wh<=num;wh+=lowbit(wh))d[wh]+=val;
}
int worksum(int wh){
int an=0;
for(;wh;wh-=lowbit(wh))an+=d[wh];
return an;
}
int allsum(int l,int r){
if(r<l)return 0;
return worksum(r)-worksum(l-1);
}
}t,r;
int ni[N],sum[N];
signed main(){
sc(m);
for(int i=1;i<=m;i++){
sc(a[i]);
b[i]=a[i];
}
sort(a+1,a+m+1);
for(int i=1;i<=m;i++){
c[i]=find(b[i]);
}
t.num=m;
int ans=0;
for(int i=1;i<=m;i++){
ni[i]=t.allsum(1,c[i]-1);
t.change(c[i],1);
}
r.num=m;
for(int i=1;i<=m;i++){
ans+=r.allsum(1,c[i]-1);
r.change(c[i],ni[i]);
}
printf("%lld",ans);
return 0;
}
一如既往,万事胜意