树状数组Ⅰ

树状数组很有用,虽然是不如线段树那么强大,但是它写起来简单得不得了,而且它的复杂度是赤裸裸的\(O(\log n)\),没有讨厌的常数!

盗张图来说明树状数组的工作原理(图片来自oi-wiki):

树状数组的功能为单点修改区间查询,或者区间修改单点查询。

懒得废话上模板:

  • 单点修改区间查询:

struct tree{
	#define lowbit(wh) (wh&-wh)
   int c[N],m;
	void change(int wh,int num){
		for(;wh<=m;wh+=lowbit(wh))c[wh]+=num;
	}
	int work_sum(int wh){
		int ans=0;
		for(;wh;wh-=lowbit(wh))ans+=c[wh];
		return ans;
	}
};

其中change函数可以做到在\(O(\log n)\)的复杂度下给第wh个元素加上num,work_sum函数可以做到在\(O(\log n)\)的复杂度下求出前wh个元素的和(包括第wh个元素。)

  • 区间修改单点查询(与差分结合):

struct tree{
	#define lowbit(wh) (wh&-wh)
   int c[N],m;
	void change1(int wh,int num){
		for(;wh<=m;wh+=lowbit(wh))c[wh]+=num;
	}
	void change2(int l,int r,int num){
		change1(l,num);
		change1(r+1,-num);
	}
	int work(int wh){
		int ans=0;
		for(;wh;wh-=lowbit(wh))ans+=c[wh];
		return ans;
	}
};

其中change2函数可以做到\(O(\log n)\)的复杂度下给第l个和第r个元素之间(包含两端)的所有元素加上num,而work函数可以在\(O(\log n)\)的复杂度下求出第wh个元素的值。


树状数组的作用

  • 快速求区间和,不说了都是模板。

  • 求逆元或者变相逆元,因为是树状数组可以快速求出比当前数小(或者大)的数的个数,也可以在\(O(n\log n)\)的复杂度下求逆元以及一些更加复杂的问题。

比如 三元上升子序列,归并排序就无法解决了,就要用到树状数组。

思路就是暴力,每处理一个数就扫描前面的数,遇到一个比它小的数就统计比这个数小的数的个数,而这个“比这个数小的个数”可以用树状数组维护,所以复杂度就是\(O(n^2\log^2 n)\)

它依然可以优化,因为“比这个数小的数”所有的答案区间会有很严重的重叠,所以可以另开一个树状数组r,它动态更新,维护当前(假如已经处理到第now个元素了)满足\(a_i<a_j<a_{now}\)\(i<j<now\)的个数。

然后复杂度就可以降到\(O(n\log^2 n)\),(因为每个元素要查找离散化之后的位置有个log的二分查找复杂度)

代码(注意离散化):

#include<cstdio>
#include<algorithm>
#include<cstring>
#define int long long
#define sc(wh) scanf("%lld",&wh)
using namespace std;
const int N=30010;

int m,a[N],b[N],c[N];
int find(int wh){
	int l=1,r=m,mid,an;
	while(l<=r){
		mid=l+r>>1;
		if(a[mid]<wh)l=mid+1;
		else an=mid,r=mid-1;
	}
	return an;
}

struct tree{
	#define lowbit(wh) (wh&-wh)
	int d[N],num;
	void change(int wh,int val){
		for(;wh<=num;wh+=lowbit(wh))d[wh]+=val;
	}
	int worksum(int wh){
		int an=0;
		for(;wh;wh-=lowbit(wh))an+=d[wh];
		return an;
	}
	int allsum(int l,int r){
		if(r<l)return 0;
		return worksum(r)-worksum(l-1);
	}
}t,r;
int ni[N],sum[N];

signed main(){
	
	sc(m);
	for(int i=1;i<=m;i++){
		sc(a[i]);
		b[i]=a[i];
	}
	sort(a+1,a+m+1);
	for(int i=1;i<=m;i++){
		c[i]=find(b[i]);
	}
	t.num=m;
	int ans=0;
	for(int i=1;i<=m;i++){
		ni[i]=t.allsum(1,c[i]-1);
		t.change(c[i],1);
	}
	
	r.num=m;
	for(int i=1;i<=m;i++){
		ans+=r.allsum(1,c[i]-1);
        r.change(c[i],ni[i]);
	}
	printf("%lld",ans);
	
	return 0;
}
posted @ 2021-08-14 19:46  Feyn618  阅读(35)  评论(0编辑  收藏  举报