[DS 小计] 树套树

笔者很菜,只会最简单的树状数组套权值线段树。
不是,这玩意不就套娃吗,真 ex 啊

题目简要:

  • \(x\) 排名
  • 求排名为 \(x\) 的数
  • \(x\) 前驱后继

我们学了权值动态开点线段树就知道这些问题乱写就行了。
但是套上 \([l,r]\) 区间呢,无修呢?

我们会主席树这些乱写就行了。

但是套上有修呢?
权值线段树乱做即可。

如果这两个一结合,树套树就出来了。

在主席树中,我们使用前缀和的思想,相减得区间。
所以如果有修改,前缀和修改是 \(O(n\log n)\) 的。
所以不要前缀和了,直接树套树。

考虑使用树状数组,每次修改改树状数组上的点就好,然后合并即可。

一句话题解:
每个树状数组里面塞一棵权值线段树,然后修改暴力修改,查询在主席树中是两个点同时跳,在树套树中就是 \(\log n\) 个点同时跳。

于是就做完了。
时间复杂度 \(O(n\log n\log w)\)

点击查看代码
#include<bits/stdc++.h>
#define ls(x) tr[x].l
#define rs(x) tr[x].r
#define N 50005
#define ll long long
using namespace std;
const int M=1e8+1;
int n,m;
int a[N];
int root[N];
struct tnode{
	int l,r,v;
}tr[N*405];
int tot;
int q[N],lp,lq,p[N];
void modify(int &now,int l,int r,int p,int v)
{
	if(!now) now=++tot;
	tr[now].v+=v;
	if(l==r)return;
	int mid=(l+r)/2;
	if(p<=mid) modify(ls(now),l,mid,p,v);
	else modify(rs(now),mid+1,r,p,v);
}
int query(int now,int l,int r,int L,int R)
{
	if(!now) return 0;
	if(l>R||r<L) return 0;
	if(l>=L&&r<=R) return tr[now].v;
	int mid=(l+r)/2;
	return query(ls(now),l,mid,L,R)+query(rs(now),mid+1,r,L,R);
}
int Kquery(int l,int r,int k)
{
	if(l==r) return l;
	int mid=(l+r)/2,lv=0;
	for(int i=1;i<=lp;i++) lv-=tr[ls(p[i])].v;
	for(int i=1;i<=lq;i++) lv+=tr[ls(q[i])].v;
	if(k<=lv)
	{
		for(int i=1;i<=lp;i++) p[i]=ls(p[i]);
		for(int i=1;i<=lq;i++) q[i]=ls(q[i]);
		return Kquery(l,mid,k);
	}
	else
	{
		for(int i=1;i<=lp;i++) p[i]=rs(p[i]);
		for(int i=1;i<=lq;i++) q[i]=rs(q[i]);
		return Kquery(mid+1,r,k-lv);
	}
}
void add(int x,int v,int w)
{
	while(x<=n)
	{
		modify(root[x],0,M,v,w);
		x+=x&-x;
	}
}
void updata(int x,int v,bool tag=1)
{
	if(tag) add(x,a[x],-1); 
	add(x,v,1);
	a[x]=v;
}
void pre_query(int l,int r)
{
	lp=lq=0;
	l--;
	while(l)
	{
		p[++lp]=root[l];
		l-=l&-l;
	}
	while(r)
	{
		q[++lq]=root[r];
		r-=r&-r;
	}
}
int grank(int l,int r,int x)
{
	pre_query(l,r);
	int res=0;
	while(lp) res-=query(p[lp--],0,M,0,x-1);
	while(lq) res+=query(q[lq--],0,M,0,x-1);
	return res+1;
}
int kth(int l,int r,int k)
{
	pre_query(l,r);
	return Kquery(0,M,k);
}
int gnext(int l,int r,int x,int p)
{
	int k;
	if(p) k=grank(l,r,x+1);
	else k=grank(l,r,x)-1;
	if(k<1) return -2147483647;
	if(k>r-l+1) return 2147483647;
	return kth(l,r,k);
}
int main() 
{
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++)
	{
		int v;
		scanf("%d",&v);
		updata(i,v,0);
	}
	while(m--)
	{
		int opr,l,r,k;
		scanf("%d%d%d",&opr,&l,&r);
		if(opr^3) scanf("%d",&k);
		if(opr==1) printf("%d\n",grank(l,r,k));
		if(opr==2) printf("%d\n",kth(l,r,k));
		if(opr==3) updata(l,r);
		if(opr==4) printf("%d\n",gnext(l,r,k,0));
		if(opr==5) printf("%d\n",gnext(l,r,k,1));
	}
	return 0;
}

自我感觉已经很短了。

posted @ 2024-04-07 11:18  g1ove  阅读(8)  评论(0编辑  收藏  举报