浅谈树状数组套主席树
话说主席树还没写就先写这一篇了\(qwq\)
回顾一下主席树的实现过程:类似查分思想,将线段树的每次修改看做函数式以支持可持久化。因为这样的线段树是可减的。
那么我们维护信息的时候,就要维护每一次新形成的信息。但是我们可以根据前一个信息的基础上进行改动,而不必要去再建一棵树。
所以总而言之,是前缀和的思想。
那么,当需要修改的时候,怎么做呢?
考虑普通的区间操作,当做单点修改的时候,一般用树状数组,线段树和分块。最好实现的就是树状数组。
考虑用树状数组来维护主席树的信息。
树状数组中维护了每一次加入一个数的新形成的主席树的信息。
对于修改一个值,考虑树状数组的单点修改,若修改点\(p\),则把一路上的\(p+lowbit(p)\)全部修改即可,这样是\(O(log)\)的,主席树修改是\(O(log)\)的,即总复杂度\((\log^2{n})\).
查询的时候,我们可以用树状数组提前将第\(r\)棵树和第\(l-1\)棵树的信息预处理出来,是\(log\)的,然后在树上二分,跳左右子树的时候,这\(log\)个信息也一起跳。大概复杂度也是\(O(\log^2{n})\)的。
总时间复杂度:\(O(n\log^2{n})\)
代码:
#include<bits/stdc++.h>
using namespace std;
const int MAXN=1e6+10;
int n,m,a[MAXN],u[MAXN],x[MAXN];
int l[MAXN],r[MAXN],k[MAXN],cur;
int cur1,cur2,q1[MAXN],q2[MAXN],v[MAXN];
char op[MAXN];
set<int>ST;
map<int,int>mp;
struct SGT{
int cur,rt[MAXN<<2],sum[MAXN<<5],lc[MAXN<<5],rc[MAXN<<5];
void build(int &o){o=++cur;}
void print(int o,int l,int r){
if(!o)return;
if(l==r&&sum[o])printf("%d",l);
int mid=l+r>>1;
print(lc[o],l,mid);
print(rc[o],mid+1,r);
}
void update(int &o,int l,int r,int x,int v){
if(!o)o=++cur;
sum[o]+=v;
if(l==r)return ;
int mid=l+r>>1;
if(x<=mid)update(lc[o],l,mid,x,v);
else update(rc[o],mid+1,r,x,v);
}
}st;
inline int lbt(int x){return x&(-x);}
void upd(int o,int x,int v){
for(;o<=n;o+=lbt(o))st.update(st.rt[o],1,n,x,v);
}
void gtv(int o,int *A,int &p){
p=0;
//A数组维护了树状数组o所控制的所有根信息
for(;o;o-=lbt(o))A[++p]=st.rt[o];
}
int query(int l,int r,int k){
if(l==r)return l;
int mid=l+r>>1,siz=0;
for(int i=1;i<=cur1;++i)siz+=st.sum[st.lc[q1[i]]];
//q1是r树的信息
for(int i=1;i<=cur2;++i)siz-=st.sum[st.lc[q2[i]]];
//q2是l-1树的信息,通过gtv处理
//siz代表信息的区间[l,r]信息
if(siz>=k){
for(int i=1;i<=cur1;++i)q1[i]=st.lc[q1[i]];
for(int i=1;i<=cur2;++i)q2[i]=st.lc[q2[i]];
return query(l,mid,k);//q1q2一样改,走左子树
}
else{
for(int i=1;i<=cur1;++i)q1[i]=st.rc[q1[i]];
for(int i=1;i<=cur2;++i)q2[i]=st.rc[q2[i]];
return query(mid+1,r,k-siz);
}
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i)scanf("%d",a+i),ST.insert(a[i]);
for(int i=1;i<=m;++i){
scanf(" %c",op+i);
if(op[i]=='C')scanf("%d%d",u+i,x+i),ST.insert(x[i]);
else scanf("%d%d%d",l+i,r+i,k+i);
}
for(set<int>::iterator it=ST.begin();it!=ST.end();++it)
mp[*it]=++cur,v[cur]=*it;//unique
for(int i=1;i<=n;++i)a[i]=mp[a[i]];
for(int i=1;i<=m;++i)if(op[i]=='C')x[i]=mp[x[i]];
n+=m;
for(int i=1;i<=n;++i)upd(i,a[i],1);
for(int i=1;i<=m;++i){
if(op[i]=='C'){
upd(u[i],a[u[i]],-1);
upd(u[i],x[i],1);
a[u[i]]=x[i];
}
else{
gtv(r[i],q1,cur1);//预处理
gtv(l[i]-1,q2,cur2);//预处理
printf("%d\n",v[query(1,n,k[i])]);
}
}
return 0;
}
代码中的\(set\)是用来去重的,\(map\)是用来离散化的。
注意询问的字符输入,不要漏了前面的空格。
例题:二逼平衡树
就一树套树,但是树状数组套主席树比线段树套\(Splay\)好写多了\(qwq\)
思考,既然和\(dynamic\) \(rankings\)那个题差不多,就是多了一些操作,所以一样的操作查询和修改。
我们提前把要改的值都加了进去,提前弄好了所有可能出现的新版本,后面只管查询了。
考虑:查询\(x\)在区间\([l,r]\)的排名。
显然,我们知道区间\([l,r]\),那我们把\(x\)扔进去二分就行了。
因为主席树(或者说动态开点权值线段树)维护的是前缀信息,所以——
我们先二分查出来区间\([1,r]\)的\(x\)排名,再二分查出区间\([1,l-1]\)的排名,相减就行了。
注意的是,这里的排名要加一个1.
那么,\(1,2,3\)操作就这样解决了,考虑剩下俩操作。
考虑前驱:我们查出\(x\)的当前排名,看它排名\(rk\)前面是谁就行了,再查询排名为\(rk-1\)的数。
考虑后继:由于不太好判无解,我们就查出\(x+1\)的排名,看它是不是大于区间长度——即存不存在,存在则查询这个排名对应的值即可。
最后细节:对于所有涉及区间中值的询问,都需要离散化。之后要注意数组中维护的是对应值的排名,最后输出也要注意。后继查询的是\(x+1\),以及注意两种无解情况各自输出什么。
一个是\(2147483647\),一个是\(-2147483647\),别用\(INT\)_\(MIN\),那是\(-2147483648\).
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int MAXN=2e5+10;
const int inf=2147483647;
inline int read(){
int s=0,w=1;
char ch=getchar();
while(ch<'0'||ch>'9'){
if(ch=='-')w=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9'){
s=(s<<1)+(s<<3)+(ch^48);
ch=getchar();
}
return w==-1?-s:s;
}
int a[MAXN],n,m,v[MAXN];
int cur,t1,t2,q1[MAXN];
int q2[MAXN];
set<int>ST;
map<int,int>mp;
struct problem{
int opt,a,b,c;
}q[MAXN];
namespace SGT{
int rt[MAXN<<2],lc[MAXN<<5];
int tot,rc[MAXN<<5],sum[MAXN<<5];
inline int build(int &x){x=++tot;}
void update(int &x,int l,int r,int pos,int v){
if(!x)build(x);
sum[x]+=v;
if(l==r)return;
int mid=l+r>>1;
if(pos<=mid)update(lc[x],l,mid,pos,v);
else update(rc[x],mid+1,r,pos,v);
}
}
using namespace SGT;
//树状数组
inline int lowbit(int x){return x&(-x);}
void upd(int x,int pos,int v){for(;x<=n;x+=lowbit(x))update(rt[x],1,n,pos,v);}
void gtv(int x,int *A,int &p){p=0;for(;x;x-=lowbit(x))A[++p]=rt[x];}
int query(int l,int r,int k){
//ask for kth
if(l==r)return l;
int mid=l+r>>1;
int siz=0;
for(int i=1;i<=t1;++i)siz+=sum[lc[q1[i]]];
for(int i=1;i<=t2;++i)siz-=sum[lc[q2[i]]];
if(k<=siz){
for(int i=1;i<=t1;++i)q1[i]=lc[q1[i]];
for(int i=1;i<=t2;++i)q2[i]=lc[q2[i]];
return query(l,mid,k);
}
else{
for(int i=1;i<=t1;++i)q1[i]=rc[q1[i]];
for(int i=1;i<=t2;++i)q2[i]=rc[q2[i]];
return query(mid+1,r,k-siz);
}
}
int qrank(int l,int r,int k){
int ans=0;
for(int i=r;i;i-=lowbit(i)){
int u=rt[i];
int L=1,R=n;
while(sum[u]&&(L^R)){
int mid=L+R>>1;
if(k<=mid)R=mid,u=lc[u];
else ans+=sum[lc[u]],L=mid+1,u=rc[u];
}
}
for(int i=l-1;i;i-=lowbit(i)){
int u=rt[i];
int L=1,R=n;
while(sum[u]&&(L^R)){
int mid=L+R>>1;
if(k<=mid)R=mid,u=lc[u];
else ans-=sum[lc[u]],L=mid+1,u=rc[u];
}
}
return ans+1;
}
signed main(){
n=read(),m=read();
for(int i=1;i<=n;++i)a[i]=read(),ST.insert(a[i]);
for(int i=1;i<=m;++i){
q[i].opt=read();
if(q[i].opt!=3)q[i].a=read(),q[i].b=read(),q[i].c=read();
else q[i].a=read(),q[i].b=read(),ST.insert(q[i].b);
if(q[i].opt==1)ST.insert(q[i].c);
if(q[i].opt>=4)ST.insert(q[i].c);
}
for(set<int>::iterator it=ST.begin();it!=ST.end();++it)
mp[*it]=++cur,v[cur]=*it;
for(int i=1;i<=n;++i)a[i]=mp[a[i]];
for(int i=1;i<=m;++i){
if(q[i].opt==1)q[i].c=mp[q[i].c];
if(q[i].opt==3)q[i].b=mp[q[i].b];
if(q[i].opt>=4)q[i].c=mp[q[i].c];
}
n+=m;
for(int i=1;i<=n;++i)upd(i,a[i],1);
for(int i=1;i<=m;++i){
if(q[i].opt==1){
gtv(q[i].b,q1,t1);
gtv(q[i].a-1,q2,t2);
printf("%lld\n",qrank(q[i].a,q[i].b,q[i].c));
}
else if(q[i].opt==2){
gtv(q[i].b,q1,t1);
gtv(q[i].a-1,q2,t2);
printf("%lld\n",v[query(1,n,q[i].c)]);
}
else if(q[i].opt==3){
upd(q[i].a,a[q[i].a],-1);
upd(q[i].a,q[i].b,1);
a[q[i].a]=q[i].b;
}
else if(q[i].opt==4){
gtv(q[i].b,q1,t1);
gtv(q[i].a-1,q2,t2);
int rk=qrank(q[i].a,q[i].b,q[i].c);
if(rk<=1)printf("-2147483647\n");
else printf("%lld\n",v[query(1,n,rk-1)]);
}
else {
gtv(q[i].b,q1,t1);
gtv(q[i].a-1,q2,t2);
int rk=qrank(q[i].a,q[i].b,q[i].c+1);
if(rk>q[i].b-q[i].a+1)printf("2147483647\n");
else printf("%lld\n",v[query(1,n,rk)]);
}
}
return 0;
}