树状数组从入门到某个奇怪的水准
一维树状数组
树状数组基本原理
我们借 OI-wiki 上的图用用:
原理是对于每个节点,存储一段连续的长度为 \(2\) 的次幂的区间。
而具体来说,对于一个数 \(x\) ,其管理 \(lowbit(x)\) 长度的区间,而这具体是那个区间也可以写出: \((x-lowbit(x),x]\)。
然后还有一些性质:
性质1:
对于一个节点 \(x\) 来说 \(x+lowbit(x)\) 为其父节点。
证明:
简单证一下,发现 \(x+lowbit(x)\) 实际上是将 \(x\) 的最低位抹除了,并且往高位进位。然后直接进入到最近的满的整幂次位,树形意义上就是进入父亲。
Q.E.D.
lowbit
然后我们看如何快速求得 \(lowbit\) 。
其实很好做, \(lowbit(x)=x\&-x\)。这是为什么?
实际上是利用了补码:负数的反码是原码符号位不动,其他位按位取反。补码是在其反码的基础上 \(+1\) ,那么它会连续进位直到遇到第一个 \(1\) 。高位和原来完全不同,因此有 \(x\&-x\) 为 \(lowbit\) 。
单点修改区间查询
最经典的树状数组。需要注意的是树状数组可以存储的信息需要有区间可减性或者只需要查询从 \(1\) 开始的,否则不能用树状数组。因为其区间查询是通过两次前缀查询做差分得到的。
单点修改简单,只需要先找到要修改的位置,然后直接跳父亲修改即可。修改的位置就是其本身。注意下标不能开到 \(0\),应该说不能对 \(0\) 位置进行修改操作,不然会死循环。可以通过下标整体 \(+1\) 之类的操作解决。
区间查询稍微复杂一些。
我们将查询的前缀分为若干段。具体来说,我们把前缀用尽可能大的树状数组上的区间覆盖。对于我们要查询的前缀,先可以统计 \((x-lowbit(x),x]\) 部分的答案,然后直接 \(x-=lowbit(x)\) ,达到一个较小的子任务,如此循环求解即可。
要询问一段区间就类似使用前缀和的使用树状数组即可。
#define lowbit(x) (x&-x)
struct BitTree{
ll c[N];
inline void add(int x,int k){
for(;x<=n;x+=lowbit(x) ) c[x]+=k;
}
inline ll find(int y){
ll res=0;
for(;y;y-=lowbit(y) ) res+=c[y];
return res;
}
}t;
//使用例:
int op;read(op);
if(op==1){
int i,x;
read(i,x);
t.add(i,x);
}else{
int l,r;
read(l,r);
printf("%lld\n",t.find(r)-t.find(l-1) );
}
树状数组上二分
给一段序列 \(a\) 和一个数 \(k\),求最小的 \(d\) 使得 \(\sum_{i\le d}a_i\le k\) 。
这个用线段树可以线段树上二分做到 \(O(n\log n)\),在此不再赘述。当然也可以二分然后使用树状数组 \(check\) 做到 \(O(n\log^2 n)\),尽管它理论复杂度比较高,但是实际上会比线段树二分快(可见线段树的巨大常数)。
我们还不满意!其实这个使用树状数组上二分也是可以达到 \(O(n\log n)\) 的,再加上其巨小常数,直接快到起飞。
如果我们每次直接将 \(mid\) 拉到一个整的区间 \([l,mid]\),并且维护 \(l\) 左边的前缀和(或者直接 \(k\) 每次减掉这个前缀和),就可以做到每次 \(check\) 达到 \(O(1)\) 的复杂度,以到达最后总复杂 \(O(n\log n)\)。怎么操作?
我们把 \(c\) 数组总大小开大到 \(2\) 的幂次,这样每次二分到的 \(mid\) 就都是 \([l,mid]\) 的整幂次区间,正常二分即可。
所以实际上离散化后,其也可以支持维护普通平衡树。(这不是把权值线段树吊打了)
普通平衡树 参考代码:
#include<bits/stdc++.h>
#define ll long long
#define db double
#define filein(a) freopen(#a".in","r",stdin)
#define fileot(a) freopen(#a".out","w",stdout)
#define sky fflush(stdout);
#define gc getchar
#define pc putchar
namespace IO{
inline bool blank(const char &c){
return c==' ' or c=='\n' or c=='\t' or c=='\r' or c==EOF;
}
inline void gs(char *s){
char ch=gc();
while(blank(ch) ) {ch=gc();}
while(!blank(ch) ) {*s++=ch;ch=gc();}
*s=0;
}
inline void gs(std::string &s){
char ch=gc();s+='#';
while(blank(ch) ) {ch=gc();}
while(!blank(ch) ) {s+=ch;ch=gc();}
}
inline void ps(char *s){
while(*s!=0) pc(*s++);
}
inline void ps(const std::string &s){
for(auto it:s)
if(it!='#') pc(it);
}
template<class T>
inline void read(T &s){
s=0;char ch=gc();bool f=0;
while(ch<'0'||'9'<ch) {if(ch=='-') f=1;ch=gc();}
while('0'<=ch&&ch<='9') {s=s*10+(ch^48);ch=gc();}
if(ch=='.'){
db p=0.1;ch=gc();
while('0'<=ch&&ch<='9') {s=s+p*(ch^48);p*=0.1;ch=gc();}
}
s=f?-s:s;
}
template<class T,class ...A>
inline void read(T &s,A &...a){
read(s);read(a...);
}
};
using IO::read;
using IO::gs;
using IO::ps;
const int N=1e5+3;
#define lowbit(x) (x&-x)
struct BitTree{
int n;
int c[(1<<17)+3];
inline void add(int x,int k){
for(;x<=n;x+=lowbit(x) ) c[x]+=k;
}
inline int find(int y){
int res=0;
for(;y;y-=lowbit(y) ) res+=c[y];
return res;
}
inline int kth(int k){
int l=1,r=n;
while(l<=r){
int mid=(l+r)>>1;
if(c[mid]<k){
k-=c[mid];
l=mid+1;
}else{
r=mid-1;
}
}
return l;
}
inline int rk(int x){
return find(x-1)+1;
}
inline int pre(int x){
return kth(find(x-1) );
}
inline int nxt(int x){
return kth(find(x)+1);
}
}t;
int Q;
struct ques{
int op,x;
}qu[N];
int a[N],c[N];
namespace Discrete{
int id[N];
inline void work(){
int tot=0;
std::sort(id+1,id+1+Q,[](int x,int y){
return qu[x].x<qu[y].x;
});
int la=-1e9;
int top=0;
for(int i=1;i<=Q;++i){
if(qu[id[i] ].op==4) continue;
if(qu[id[i] ].x!=la) ++top;
la=qu[id[i] ].x;c[top]=qu[id[i] ].x;
qu[id[i] ].x=top;
}
}
};
int main(){
filein(a);fileot(a);
read(Q);
for(int i=1;i<=Q;++i){
Discrete::id[i]=i;
read(qu[i].op,qu[i].x);
}
Discrete::work();
t.n=1;
while(t.n<Q) t.n<<=1;
for(int i=1;i<=Q;++i){
int op=qu[i].op,x=qu[i].x;
if(op==1){
t.add(x,1);
}else if(op==2){
t.add(x,-1);
}else if(op==3){
printf("%d\n",t.rk(x) );
}else if(op==4){
printf("%d\n",c[t.kth(x)]);
}else if(op==5){
printf("%d\n",c[t.kth(t.find(x-1) )]);
}else if(op==6){
printf("%d\n",c[t.nxt(x)]);
}
}
return 0;
}
还有一种实现是不需要扩大至 \(2\) 的整数幂次的。我们倍增从大到小跳整数幂次即可。
时间上几乎没有差异,但是优化了空间。
#include<bits/stdc++.h>
#define ll long long
#define db double
#define filein(a) freopen(#a".in","r",stdin)
#define fileot(a) freopen(#a".out","w",stdout)
#define sky fflush(stdout);
#define gc getchar
#define pc putchar
namespace IO{
inline bool blank(const char &c){
return c==' ' or c=='\n' or c=='\t' or c=='\r' or c==EOF;
}
inline void gs(char *s){
char ch=gc();
while(blank(ch) ) {ch=gc();}
while(!blank(ch) ) {*s++=ch;ch=gc();}
*s=0;
}
inline void gs(std::string &s){
char ch=gc();s+='#';
while(blank(ch) ) {ch=gc();}
while(!blank(ch) ) {s+=ch;ch=gc();}
}
inline void ps(char *s){
while(*s!=0) pc(*s++);
}
inline void ps(const std::string &s){
for(auto it:s)
if(it!='#') pc(it);
}
template<class T>
inline void read(T &s){
s=0;char ch=gc();bool f=0;
while(ch<'0'||'9'<ch) {if(ch=='-') f=1;ch=gc();}
while('0'<=ch&&ch<='9') {s=s*10+(ch^48);ch=gc();}
if(ch=='.'){
db p=0.1;ch=gc();
while('0'<=ch&&ch<='9') {s=s+p*(ch^48);p*=0.1;ch=gc();}
}
s=f?-s:s;
}
template<class T,class ...A>
inline void read(T &s,A &...a){
read(s);read(a...);
}
};
using IO::read;
using IO::gs;
using IO::ps;
const int N=1e5+3;
#define lowbit(x) (x&-x)
struct BitTree{
int n;
int c[N];
inline void add(int x,int k){
for(;x<=n;x+=lowbit(x) ) c[x]+=k;
}
inline int find(int y){
int res=0;
for(;y;y-=lowbit(y) ) res+=c[y];
return res;
}
inline int kth(int k){
int p=0;
for(int i=log2(n);i>=0;--i){
int step=1<<i;
if(p+step<=n and c[p+step]<k){
p+=step;
k-=c[p];
}
}
return p+1;
}
inline int rk(int x){
return find(x-1)+1;
}
inline int pre(int x){
return kth(find(x-1) );
}
inline int nxt(int x){
return kth(find(x)+1);
}
}t;
int Q;
struct ques{
int op,x;
}qu[N];
int a[N],c[N];
namespace Discrete{
int id[N];
inline void work(){
int tot=0;
std::sort(id+1,id+1+Q,[](int x,int y){
return qu[x].x<qu[y].x;
});
int la=-1e9;
int top=0;
for(int i=1;i<=Q;++i){
if(qu[id[i] ].op==4) continue;
if(qu[id[i] ].x!=la) ++top;
la=qu[id[i] ].x;c[top]=qu[id[i] ].x;
qu[id[i] ].x=top;
}
}
};
int main(){
filein(a);fileot(a);
read(Q);
for(int i=1;i<=Q;++i){
Discrete::id[i]=i;
read(qu[i].op,qu[i].x);
}
Discrete::work();
t.n=Q;
for(int i=1;i<=Q;++i){
int op=qu[i].op,x=qu[i].x;
if(op==1){
t.add(x,1);
}else if(op==2){
t.add(x,-1);
}else if(op==3){
printf("%d\n",t.rk(x) );
}else if(op==4){
printf("%d\n",c[t.kth(x)]);
}else if(op==5){
printf("%d\n",c[t.kth(t.find(x-1) )]);
}else if(op==6){
printf("%d\n",c[t.nxt(x)]);
}
}
return 0;
}
区间修改单点查询
我们记 \(b[i]=a[i]-a[i-1]\) ,那么 \(a[i]=\sum_{j=1}^i b[i]\)。
所以我们把差分数组用树状数组维护的话,就可以达到目的了。
修改区间 \([l,r]\) 加 \(k\) 时,只需修改 \(l\) 位置 \(+k\),\(r+1\) 位置 \(-k\) 即可。
#define lowbit(x) (x&-x)
struct BitTree{
ll c[N];
inline void add(int x,int k){
for(;x<=n;x+=lowbit(x) ) c[x]+=k;
}
inline ll find(int y){
ll res=0;
for(;y;y-=lowbit(y) ) res+=c[y];
return res;
}
inline void modify(int l,int r,int k){
add(l,k);add(r+1,-k);
}
}t;
//使用例:
int op;read(op);
if(op==1){
int l,r,k;
read(l,r,k);
t.modify(l,r,k);
}else{
int p;
read(p);
printf("%lld\n",t.find(p) );
}
区间修改区间查询
推个式子:
那么我们只要维护两个树状数组分别维护 \(b_i\) 的前缀和,和 \(b_i\times i\) 的前缀和。(当然也可以放在一起维护)
#define lowbit(x) (x&-x)
struct BitTree{
struct node{
ll c1,c2;
}c[N];
inline void add(int x,int k){
for(int i=x;i<=n;i+=lowbit(i) ){
c[i].c1+=k;
c[i].c2+=1ll*x*k;
}
}
inline void modify(int l,int r,int k){
add(l,k);add(r+1,-k);
}
inline ll find(int y){
ll res1=0,res2=0;
for(int i=y;i;i-=lowbit(i) ){
res1+=c[i].c1;
res2+=c[i].c2;
}
return 1ll*(y+1)*res1-res2;
}
inline ll query(int l,int r){
return find(r)-find(l-1);
}
}t;
//使用例:
int op;read(op);
if(op==1){
int l,r,k;
read(l,r,k);
t.modify(l,r,k);
}else{
int l,r;
read(l,r);
printf("%lld\n",t.query(l,r) );
}
小小优化
数组数组的O(n)建树
直接 \(1\) 到 \(n\) 跑一遍,每次更新直接父亲,这一路下来所有节点都更新完了。如果把上界扩大到了 \(2\) 的幂次就多更新一点而已。
for(int i=1;i<=n;++i){
int x;read(x);
t.c[i]+=x;
int j=i+lowbit(i);
if(j<=n) t.c[j]+=t.c[i];
}
时间戳优化
我们对于每个节点打一个标记,记录其上一次使用的时间(数据组数)。然后发现和目前的不一样就清空。这样就不需要每组数据都要暴力清空。(树状数组这种每个父节点与儿子节点不实时绑定更新的就可以这样,还有类似的像平衡树啊,trie啊,AC自动机啊也都可以,线段树就不行。判断方法就是比如说看你写的 pushup 里面是\(t[x].cnt=t[lc(x)].cnt+t[rc(x)].cnt+1\) 还是没有这个 \(+1\) ,没有 \(+1\) 的就不行)
#define lowbit(x) (x&-x)
int Tag;
struct BitTree{
struct node{
int c,tag;
}c[N];
inline void add(int x,int k){
for(;x<=n;x+=lowbit(x) ){
if(c[x].tag!=Tag) c[x].c=0;
c[x].c+=k;c[x].tag=Tag;
}
}
inline ll find(int y){
ll res=0;
for(;y;y-=lowbit(y) ){
if(c[y].tag==Tag) res+=c[y].c;
}
return res;
}
}t;
二维树状数组
一定要记得 \(add\) 操作两层的范围,别再写成两层 \(n\) 了!
单点修改区间查询
就是每一行维护一个树状数组,然后对于这些行再统一维护一个树状数组(而这个树状数组中的元素是行的树状数组)。简单地扩展到二维即可。
#define lowbit(x) (x&-x)
struct BitTree2D{
ll c[N][N];
inline void add(int x,int y,int k){
for(int i=x;i<=n;i+=lowbit(i) ){
for(int j=y;j<=m;j+=lowbit(j) ){
c[i][j]+=k;
}
}
}
inline ll find(int x,int y){
ll res=0;
for(int i=x;i;i-=lowbit(i) ){
for(int j=y;j;j-=lowbit(j) ){
res+=c[i][j];
}
}
return res;
}
inline ll query(int sx,int sy,int fx,int fy){
return find(fx,fy)-find(fx,sy-1)-find(sx-1,fy)+find(sx-1,sy-1);
}
}t;
区间修改单点查询
考虑类似一维树状数组的差分做法,用树状数组维护差分数组 \(b[i][j]=a[i][j]-a[i-1][j]-a[i][j-1]+a[i-1][j-1]\)。二维的差分就是二维前缀和的差分,这个也很简单,不多赘述。
#define lowbit(x) (x&-x)
struct BitTree2D{
ll c[N][N];
inline void add(int x,int y,int k){
for(int i=x;i<=n;i+=lowbit(i) ){
for(int j=y;j<=m;j+=lowbit(j) ){
c[i][j]+=k;
}
}
}
inline void modify(int sx,int sy,int fx,int fy,int k){
add(sx,sy,k);add(fx+1,fy+1,k);
add(sx,fy+1,-k);add(fx+1,sy,-k);
}
inline ll find(int x,int y){
ll res=0;
for(int i=x;i;i-=lowbit(i) ){
for(int j=y;j;j-=lowbit(j) ){
res+=c[i][j];
}
}
return res;
}
}t;
区间修改区间查询
还是推个式子:
四个值 \(b[i][j],b[i][j]\times i,b[i][j]\times j,b[i][j]\times ij\) 都维护一下即可。
#define lowbit(x) (x&-x)
struct BitTree2D{
struct node{
ll c1,c2,c3,c4;
inline void inc(int x,int y,int k){
c1+=k;c4+=1ll*x*y*k;
c2+=1ll*x*k;c3+=1ll*y*k;
}
}c[N][N];
inline void add(int x,int y,int k){
for(int i=x;i<=n;i+=lowbit(i) ){
for(int j=y;j<=m;j+=lowbit(j) ){
c[i][j].inc(x,y,k);
}
}
}
inline ll find(int x,int y){
ll res1=0,res2=0,res3=0,res4=0;
for(int i=x;i;i-=lowbit(i) ){
for(int j=y;j;j-=lowbit(j) ){
res1+=c[i][j].c1;res2+=c[i][j].c2;
res3+=c[i][j].c3;res4+=c[i][j].c4;
}
}
return res1*(x+1)*(y+1)-res2*(y+1)-res3*(x+1)+res4;
}
inline void modify(int sx,int sy,int fx,int fy,int k){
add(sx,sy,k);add(fx+1,fy+1,k);
add(sx,fy+1,-k);add(fx+1,sy,-k);
}
inline ll query(int sx,int sy,int fx,int fy){
return find(fx,fy)-find(sx-1,fy)-find(fx,sy-1)+find(sx-1,sy-1);
}
}t;