123789456ye

已AFO

平衡树&&树套树模板

平衡树

Luogu

fhq Treap(无旋 Treap)

// Node pool capacity; node ids start at 1 and id 0 acts as the empty (null) subtree.
#define maxn 1000005
// Left/right child of node `rt`; these macros require a variable named `rt` in scope.
#define ls son[rt][0]
#define rs son[rt][1]
// rnd: heap priority (smaller rnd is kept closer to the root by merge), val: key,
// son: children, siz: subtree size.
int rnd[maxn],val[maxn],son[maxn][2],siz[maxn];
// tot: number of allocated nodes, root: treap root, x/y/z: scratch roots used by split/merge callers.
int tot,root,x,y,z;
// Reads a (possibly negative) decimal integer from stdin into x.
// Any '-' seen before the first digit makes the result negative.
inline void read(int& x)
{
	x=0;
	// int, not char: keeps EOF distinguishable and keeps isdigit()'s argument in its valid range.
	int c=getchar(),f=1;
	while(!isdigit(c))
	{
		if(c==EOF) return; // input exhausted: leave x == 0 instead of spinning forever
		if(c=='-') f=-1;
		c=getchar();
	}
	while(isdigit(c)) x=x*10+c-'0',c=getchar();
	x*=f;
}
inline void update(int rt)
{
	siz[rt]=siz[ls]+siz[rs]+1;
}
void split(int rt,int v,int& x,int& y)
{
	if(!rt) x=y=0;
	else
	{
		if(val[rt]<=v)
		{
			x=rt;
			split(rs,v,rs,y);
		}
		else
		{
			y=rt;
			split(ls,v,x,ls);
		}
		update(rt);
	}
	
}
// Merges two treaps where every key of a is <= every key of b; returns the new root.
// The node with the smaller rnd becomes the root (ties keep a on top).
int merge(int a,int b)
{
	if(!a) return b;
	if(!b) return a;
	if(rnd[a]>rnd[b])
	{
		// b stays on top; a is merged into b's left subtree.
		son[b][0]=merge(a,son[b][0]);
		update(b);
		return b;
	}
	// a stays on top; b is merged into a's right subtree.
	son[a][1]=merge(son[a][1],b);
	update(a);
	return a;
}
// Allocates a single-node treap holding key v with a random priority; returns its id.
inline int newnode(int v)
{
	int id=++tot;
	val[id]=v;
	siz[id]=1;
	rnd[id]=rand();
	return id;
}
// Inserts one copy of key v (duplicates allowed).
inline void insert(int v)
{
	split(root,v,x,y); // x: keys <= v, y: keys > v
	int fresh=newnode(v);
	root=merge(merge(x,fresh),y);
}
// Removes one copy of key v; does nothing if v is absent.
inline void pop(int v)
{
	// Isolate every copy of v in y: x ends up with keys < v, z with keys > v.
	split(root,v,x,z);
	split(x,v-1,x,y);
	// Dropping y's root and merging its children removes exactly one copy.
	y=merge(son[y][0],son[y][1]);
	int left=merge(x,y);
	root=merge(left,z);
}
// Returns the 1-based rank of v: one more than the number of keys strictly less than v.
inline int getrank(int v)
{
	split(root,v-1,x,y); // x: keys < v
	int rank=1+siz[x];
	root=merge(x,y);
	return rank;
}
inline int getkth(int rt,int k)
{
	while(true)
	{
		if(k<=siz[ls]) rt=ls;
		else if(k==siz[ls]+1) return rt;
		else k-=siz[ls]+1,rt=rs;
	}
}
// Predecessor: returns the largest key strictly less than v.
// Precondition: such a key exists. If none does, x is empty and getkth(0,0) never terminates.
inline int lower(int v)
{
	split(root,v-1,x,y);           // x: keys < v, y: keys >= v
	int ans=val[getkth(x,siz[x])]; // rightmost (largest) node of x
	// Store the merged root explicitly, as insert/pop/getrank do, instead of relying on
	// the old root happening to come back on top.
	root=merge(x,y);
	return ans;
}
// Successor: returns the smallest key strictly greater than v.
// If no such key exists, y is empty and getkth(0,1) yields node 0, so val[0] is returned.
inline int upper(int v)
{
	split(root,v,x,y);             // x: keys <= v, y: keys > v
	int ans=val[getkth(y,1)];      // leftmost (smallest) node of y
	// Store the merged root explicitly, as insert/pop/getrank do.
	root=merge(x,y);
	return ans;
}

树套树

Luogu

树状数组套权值线段树(单点修改版)

#include <bits/stdc++.h>
using namespace std;
#define maxn 100005
// Value domain is [0, inf]. Note inf is the double literal 1e8, converted to int where it is passed.
#define inf 1e8
// Printed (or negated) when a predecessor/successor does not exist; equals INT_MAX.
#define INF 2147483647
// Dynamic value-segment-tree node pool (children ls/rs, counts val), nodes used (tot),
// and the segment-tree root stored at each BIT index (root).
// Note: memory grows like n*(log n)^2. Each insertion touches O(log n) BIT indices, and each
// can create a root-to-leaf path of segment-tree nodes, hence the large 600x factor.
int ls[maxn * 600], rs[maxn * 600], val[maxn * 600], tot, root[maxn];
// n: array length, m: number of operations, a: current array values (1-indexed).
int n, m, a[maxn];
template <typename T>
inline void read(T &x)
{
	// Reads a non-negative decimal integer from stdin; '-' signs are skipped like any other non-digit.
	x = 0;
	int c = getchar(); // int, not char: keeps EOF distinct and isdigit()'s argument in range
	while (!isdigit(c))
	{
		if (c == EOF)
			return; // input exhausted: leave x == 0 instead of looping forever
		c = getchar();
	}
	while (isdigit(c))
		x = x * 10 + c - '0', c = getchar();
}

// Point add: adds v to the count of value pos in the dynamic segment tree over [l, r]
// rooted at rt, creating missing nodes along the root-to-leaf path (rt may be updated).
void modify(int &rt, int l, int r, int pos, int v)
{
	int *slot = &rt;
	for (;;)
	{
		if (!*slot)
			*slot = ++tot;
		int node = *slot;
		val[node] += v;
		if (l == r)
			return;
		int mid = (l + r) >> 1;
		if (pos > mid)
		{
			slot = &rs[node];
			l = mid + 1;
		}
		else
		{
			slot = &ls[node];
			r = mid;
		}
	}
}
// Lowest set bit of x (0 for x == 0): x with every bit but its lowest set one cleared.
inline int lowbit(const int &x)
{
	return x ^ (x & (x - 1));
}
// BIT point update: adds delta to value num in every segment tree whose BIT range covers
// array position pos.
void get_modify(int pos, int num, int delta)
{
	while (pos <= n)
	{
		modify(root[pos], 0, inf, num, delta);
		pos += lowbit(pos);
	}
}
// Query cursors: s1[1..tp1] are segment-tree nodes for prefix [1, fr-1] (subtracted) and
// s2[1..tp2] for prefix [1, to] (added). init() fills them; turnl()/turnr() move them.
int tp1, tp2, s1[maxn], s2[maxn];
// Prepares the cursors for a query on [fr, to], computed as prefix [1, to] minus prefix [1, fr-1].
inline void init(int fr, int to)
{
	tp1 = tp2 = 0;
	for (int p = fr - 1; p != 0; p -= lowbit(p))
		s1[++tp1] = root[p];
	for (int p = to; p != 0; p -= lowbit(p))
		s2[++tp2] = root[p];
}
// Moves every cursor (both prefixes) to its left child.
inline void turnl()
{
	for (int *p = s1 + 1; p <= s1 + tp1; ++p)
		*p = ls[*p];
	for (int *p = s2 + 1; p <= s2 + tp2; ++p)
		*p = ls[*p];
}
// Moves every cursor (both prefixes) to its right child.
inline void turnr()
{
	for (int *p = s1 + 1; p <= s1 + tp1; ++p)
		*p = rs[*p];
	for (int *p = s2 + 1; p <= s2 + tp2; ++p)
		*p = rs[*p];
}
// Returns the k-th smallest value in the range prepared by init(); walks the value domain [l, r].
int get_kth(int l, int r, int k)
{
	while (l < r)
	{
		int mid = (l + r) >> 1;
		// cnt: how many values of the queried range fall in the left half [l, mid].
		int cnt = 0;
		for (int i = 1; i <= tp2; ++i)
			cnt += val[ls[s2[i]]];
		for (int i = 1; i <= tp1; ++i)
			cnt -= val[ls[s1[i]]];
		if (k > cnt)
		{
			// Answer lies in the right half; skip the cnt values on the left.
			k -= cnt;
			l = mid + 1;
			turnr();
		}
		else
		{
			r = mid;
			turnl();
		}
	}
	return l;
}
// Counts the values in the queried range (cursors prepared by init()) that are strictly
// less than v, walking the value domain [l, r].
// The cursors s1/s2 end up at a leaf, so init() must be called again before reusing them.
int get_rank(int l, int r, int v)
{
	int ans = 0;
	while (true)
	{
		if (l == r)
		{
			// The walk ends on a leaf l < v only when v exceeds the domain's upper end
			// (e.g. get_nex asks for v + 1 with v == inf). The values stored in this last
			// leaf are then also < v and must be counted.
			if (v > l)
			{
				for (int i = 1; i <= tp1; ++i)
					ans -= val[s1[i]];
				for (int i = 1; i <= tp2; ++i)
					ans += val[s2[i]];
			}
			return ans;
		}
		int mid = (l + r) >> 1;
		if (v <= mid) // v is in the left half: nothing in the right half is smaller
		{
			turnl();
			r = mid;
		}
		else // v is in the right half: every value in the left half is smaller
		{
			for (int i = 1; i <= tp1; ++i)
				ans -= val[ls[s1[i]]];
			for (int i = 1; i <= tp2; ++i)
				ans += val[ls[s2[i]]];
			turnr();
			l = mid + 1;
		}
	}
}
// Predecessor of v in [fr, to]: the largest value strictly less than v, or -INF if none.
// Relies on the caller (main) having already called init(fr, to).
inline int get_pre(int fr, int to, int v)
{
	int smaller = get_rank(0, inf, v);
	if (!smaller)
		return -INF;
	// get_rank moved the cursors down to a leaf, so they must be rebuilt before get_kth.
	init(fr, to);
	return get_kth(0, inf, smaller);
}
// Successor of v in [fr, to]: the smallest value strictly greater than v, or INF if none.
// Relies on the caller (main) having already called init(fr, to).
inline int get_nex(int fr, int to, int v)
{
	int below = get_rank(0, inf, v + 1); // values below v + 1
	int len = to - fr + 1;
	if (below < len)
	{
		// Rebuild the cursors consumed by get_rank, then take the next value in order.
		init(fr, to);
		return get_kth(0, inf, below + 1);
	}
	return INF; // every value in the range is <= v
}
int main()
{
	read(n), read(m);
	// Initial build: insert a[i] at BIT position i.
	for (int i = 1; i <= n; ++i)
	{
		read(a[i]);
		get_modify(i, a[i], 1);
	}
	for (int q = 0; q < m; ++q)
	{
		int op, x, y, z;
		read(op);
		read(x);
		read(y);
		if (op == 3)
		{
			// Point assignment a[x] = y: remove the old value, then insert the new one.
			get_modify(x, a[x], -1);
			a[x] = y;
			get_modify(x, a[x], 1);
			continue;
		}
		read(z);
		init(x, y);
		switch (op)
		{
		case 1:
			// get_rank counts values < z, so the rank is that count + 1.
			printf("%d\n", get_rank(0, inf, z) + 1);
			break;
		case 2:
			printf("%d\n", get_kth(0, inf, z));
			break;
		case 4:
			printf("%d\n", get_pre(x, y, z));
			break;
		default:
			printf("%d\n", get_nex(x, y, z));
			break;
		}
	}
	return 0;
}

树状数组套权值线段树(区间修改版)

Luogu
如果你忘了树状数组如何实现区间修改,可以看看这个

#include<bits/stdc++.h>
using namespace std;
// Array bound for BIT positions (number of positions plus slack).
#define maxn 50005
#define ll long long
template<typename T>
inline void read(T& x)
{
	// Reads a (possibly negative) decimal integer from stdin; a '-' before the digits negates it.
	x=0;
	int c=getchar(),f=1; // int, not char: keeps EOF distinct and isdigit()'s argument in range
	while(!isdigit(c))
	{
		if(c==EOF) return; // input exhausted: leave x == 0 instead of looping forever
		if(c=='-') f=-1;
		c=getchar();
	}
	while(isdigit(c)) x=x*10+c-'0',c=getchar();
	x*=f;
}
// Segment-tree node pool (children ls/rs, counts val) shared by all trees.
// rt[0][i] / rt[1][i]: BIT roots for the difference array d and for i*d[i]; tot: nodes used.
// n: size of the shifted value domain (main sets it to 2*_n+1); _n: number of positions.
// NOTE(review): val[] is int, but the rt[1] trees accumulate position weights (l and -(r+1))
// on every insertion. With many operations at large positions a node's total may exceed
// INT_MAX; consider long long if memory allows. TODO confirm against the problem limits.
int ls[maxn*400],rs[maxn*400],val[maxn*400],rt[2][maxn],tot,n,_n;
int update(int rt,int l,int r,int pos,int v)
{
	if(!rt) rt=++tot;
	val[rt]+=v;
	if(l==r) return rt;
	int mid=(l+r)>>1;
	if(pos<=mid) ls[rt]=update(ls[rt],l,mid,pos,v);
	else rs[rt]=update(rs[rt],mid+1,r,pos,v);
	return rt;
}
// Query cursors: s1[t][1..tp1] for prefix l-1 and s2[t][1..tp2] for prefix r,
// where t = 0 walks the d trees and t = 1 the i*d[i] trees.
int tp1,tp2,s1[2][maxn],s2[2][maxn];
// Lowest set bit of x (0 for x == 0), via the two's-complement identity -x == ~x + 1.
inline int lowbit(int x)
{
	int neg=~x+1;
	return x&neg;
}
int query(int l,int r,ll k)
{
	register int i;
	for(tp1=0,i=l-1;i;i-=lowbit(i)) s1[0][++tp1]=rt[0][i];
	for(tp1=0,i=l-1;i;i-=lowbit(i)) s1[1][++tp1]=rt[1][i];
	for(tp2=0,i=r;i;i-=lowbit(i)) s2[0][++tp2]=rt[0][i];
	for(tp2=0,i=r;i;i-=lowbit(i)) s2[1][++tp2]=rt[1][i];//找到要查询的线段树
	int mid;
	int L=1,R=n;
	while(true)
	{
		if(L==R) return L;
		mid=(L+R)>>1;
		ll sum=0;
		for(i=1;i<=tp1;++i) sum-=1ll*l*val[rs[s1[0][i]]];//注意l和L的区别,l=(l-1)+1
		for(i=1;i<=tp1;++i) sum+=val[rs[s1[1][i]]];
		for(i=1;i<=tp2;++i) sum+=1ll*(r+1)*val[rs[s2[0][i]]];
		for(i=1;i<=tp2;++i) sum-=val[rs[s2[1][i]]];
		if(k<=sum)
		{
			for(int i=1;i<=tp1;++i) s1[0][i]=rs[s1[0][i]];
			for(int i=1;i<=tp1;++i) s1[1][i]=rs[s1[1][i]];
			for(int i=1;i<=tp2;++i) s2[0][i]=rs[s2[0][i]];
			for(int i=1;i<=tp2;++i) s2[1][i]=rs[s2[1][i]];
			L=mid+1;
		}
		else
		{
			for(int i=1;i<=tp1;++i) s1[0][i]=ls[s1[0][i]];
			for(int i=1;i<=tp1;++i) s1[1][i]=ls[s1[1][i]];
			for(int i=1;i<=tp2;++i) s2[0][i]=ls[s2[0][i]];
			for(int i=1;i<=tp2;++i) s2[1][i]=ls[s2[1][i]];
			R=mid,k-=sum;
		}
	}
}

int main()
{
	// `register` was dropped: it is ill-formed since C++17 and rejected by clang by default.
	int m,op,l,r;
	ll v;
	read(_n),read(m);
	// Shift every value by _n+1 so that values in [-_n, _n] land in [1, n]
	// (assumes |v| <= _n; TODO confirm against the problem statement).
	n=(_n<<1)+1;
	while(m--)
	{
		read(op),read(l),read(r),read(v);
		if(op==1)
		{
			int pos=(int)(v+_n+1);
			// Range insert via two difference BITs: d[l] += 1 and d[r+1] -= 1,
			// plus the companion i*d[i] (weights l and -(r+1)).
			for(int i=l;i<=_n;i+=lowbit(i))
			{
				rt[0][i]=update(rt[0][i],1,n,pos,1);
				rt[1][i]=update(rt[1][i],1,n,pos,l);
			}
			for(int i=r+1;i<=_n;i+=lowbit(i))
			{
				rt[0][i]=update(rt[0][i],1,n,pos,-1);
				rt[1][i]=update(rt[1][i],1,n,pos,-r-1);
			}
		}
		else
			printf("%d\n",query(l,r,v)-_n-1); // undo the shift
	}
	return 0;
}
posted @ 2020-01-19 16:38  123789456ye  阅读(127)  评论(0编辑  收藏  举报