树套树学习笔记
树套树是处理区间问题/二维数点问题的一种常见的数据结构。
树套树也有很多种。最常见的一般就是线段树(树状数组/平衡树)套线段树(平衡树)共6种。
其实树套树的原理很简单,就是利用外层树的树高为 \(O(\log n)\) 和内层树允许动态开点的性质。依次保证空间复杂度 \(O(n\log^2 n)\)。
具体来说,进行一次单点修改,对应就是在外层树对应点的所有祖先分别进行一次内层树的修改。由于树高的限制,这样一次的时间复杂度为 \(O(\log^2 n)\)。
对于一次区间查询,利用线段树/平衡树的性质分到外层树的 \(O(\log n)\) 个节点上,对于这些节点对应的内层树分别进行查询,同样时间复杂度也为 \(O(\log^2 n)\)。
由此也可以看出,树套树处理的问题的局限性在于要求询问可以被分成 \(\log n\) 个区间分别处理后合并。
应用:
1.二维数点
由于区间 \([l,r]\) 本质就是一个二元组,区间之间的包含/相交关系也对应一个矩形,所以很多的区间问题都可以转换成二维数点问题。
具体来说这类问题通常可以转换成:动态修改一个点/矩形,查询某个矩形内点的信息。
这个可以通过树套树实现。具体例题:[ZJOI2017] 树状数组
当然,你硬要说离线大法好我也没什么话说。
2.动态区间第k大
题目看着很吓人,但仔细一想好想也没什么。区间第k大可以用主席树完成,动态第k大可以用平衡树/值域线段树完成。
那么把两个套一起不就好了。考虑线段树套平衡树,某个外层节点的内层平衡树维护的是该节点对应区间的信息。
考虑之前那个性质,我们可以把区间分到若干外层点上,这样可以 \(O(\log^2 n)\) 求出比k大的值。
但是我们要求的不是区间第 \(k\) 大吗?一种可行的方案要用线段树套值域线段树,是把这 \(\log n\) 外层点揪出来,然后通过比较左子树大小和与k的关系判断范围。
这样是 \(O(\log n\log a)\) 的,不过由于权值线段树的常数跑的并不快。事实上有一种偷懒的写法,直接二分区间第 \(k\) 大,然后求出区间比 \(k\) 大的数量即可。复杂度 \(O(n\log^3 n)\),可以勉强卡过。
例题:二逼平衡树_
这题中多了两个pre和nxt操作。这个其实很好处理,pre就是将 \(\log n\) 个区间分别询问pre,然后取最大值即可。nxt同理。
总时间复杂度 \(O(n\log^3 n)\)。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#define N 50010
#define M N*40
#define inf 2147483647
using namespace std;
int val[M],rnd[M],ch[M][2],siz[M],cnt;
void update(int u){siz[u]=siz[ch[u][0]]+siz[ch[u][1]]+1;}
void rot(int &u,int lf)
{
int v=ch[u][lf];
ch[u][lf]=ch[v][!lf],ch[v][!lf]=u;
update(u),update(v);u=v;
}
int new_node(int v){int u=++cnt;siz[u]=1;rnd[u]=rand();val[u]=v;return u;}
void insert(int &u,int v)
{
if(!u){u=new_node(v);return;}
siz[u]++;
if(v<=val[u]){insert(ch[u][0],v);if(rnd[ch[u][0]]<rnd[u]) rot(u,0);}
else{insert(ch[u][1],v);if(rnd[ch[u][1]]<rnd[u]) rot(u,1);}
}
void erase(int &u,int v)
{
if(val[u]==v)
{
if(!ch[u][0] || !ch[u][1]){u=ch[u][0]|ch[u][1];return;}
if(rnd[ch[u][0]]>rnd[ch[u][1]]) rot(u,1),erase(ch[u][0],v);
else rot(u,0),erase(ch[u][1],v);
}
else if(val[u]>v) erase(ch[u][0],v);
else erase(ch[u][1],v);
update(u);
}
int rnk(int u,int v)
{
if(!u) return 1;
if(val[u]>=v) return rnk(ch[u][0],v);
else return rnk(ch[u][1],v)+siz[ch[u][0]]+1;
}
int pre(int u,int v)
{
if(!u) return -inf;
if(val[u]<v) return max(val[u],pre(ch[u][1],v));
else return pre(ch[u][0],v);
}
int nxt(int u,int v)
{
if(!u) return inf;
if(val[u]>v) return min(val[u],nxt(ch[u][0],v));
else return nxt(ch[u][1],v);
}
int root[N<<2],a[N];
void build(int u,int l,int r)
{
for(int i=l;i<=r;i++) insert(root[u],a[i]);
if(l==r) return;
int mid=(l+r)>>1;
build(u<<1,l,mid),build(u<<1|1,mid+1,r);
}
void insert(int u,int l,int r,int p,int v)
{
erase(root[u],a[p]);
insert(root[u],v);
if(l==r){a[p]=v;return;}
int mid=(l+r)>>1;
if(p<=mid) insert(u<<1,l,mid,p,v);
else insert(u<<1|1,mid+1,r,p,v);
}
int rnk(int u,int l,int r,int L,int R,int k)
{
if(L<=l && r<=R) return rnk(root[u],k)-1;
int mid=(l+r)>>1,ans=0;
if(L<=mid) ans+=rnk(u<<1,l,mid,L,R,k);
if(R>mid) ans+=rnk(u<<1|1,mid+1,r,L,R,k);
return ans;
}
int pre(int u,int l,int r,int L,int R,int k)
{
if(L<=l && r<=R) return pre(root[u],k);
int mid=(l+r)>>1,ans=-inf;
if(L<=mid) ans=max(ans,pre(u<<1,l,mid,L,R,k));
if(R>mid) ans=max(ans,pre(u<<1|1,mid+1,r,L,R,k));
return ans;
}
int nxt(int u,int l,int r,int L,int R,int k)
{
if(L<=l && r<=R) return nxt(root[u],k);
int mid=(l+r)>>1,ans=inf;
if(L<=mid) ans=min(ans,nxt(u<<1,l,mid,L,R,k));
if(R>mid) ans=min(ans,nxt(u<<1|1,mid+1,r,L,R,k));
return ans;
}
int main()
{
int n,m;
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
build(1,1,n);
for(int i=1;i<=m;i++)
{
int opt,l,r,k;
scanf("%d%d%d",&opt,&l,&r);
if(opt==3) insert(1,1,n,l,r);
else
{
scanf("%d",&k);
if(opt==1) printf("%d\n",rnk(1,1,n,l,r,k)+1);
else if(opt==2)
{
int lf=0,rf=1e8,ans=0;
while(lf<=rf)
{
int mid=(lf+rf)>>1,p=rnk(1,1,n,l,r,mid);
if(p<k) lf=mid+1,ans=mid;
else rf=mid-1;
}
printf("%d\n",ans);
}
else if(opt==4) printf("%d\n",pre(1,1,n,l,r,k));
else if(opt==5) printf("%d\n",nxt(1,1,n,l,r,k));
}
}
return 0;
}