可持久化线段树 学习笔记

引入

嗯嗯因为我打了一次测试

所以学了这个可持久化线段树(

怎么说其实这东西我很久之前学过,只是有点忘了

深刻认识到了写博客的重要性啦!(>ω・* )ノ

这东西其实很简单,也加强了对动态开点线段树的理解

思想

一个线段树,如果单点修改,只会修改一条链上的值,那你就直接每次新建一条链,就可以维护这棵线段树的历史版本啦!

具体看这个图

例题

区间第 \(k\) 小,经典问题

做法是建一个值域可持久化线段树,然后在上面二分

代码

主要是实现比较难

我直接贴代码,反正算法笔记是给自己看的ヾ(o・ω・)ノ

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 1e5+5, M = 1e9+5;
int n,m,a[N],b[N];
int rd()
{
	int x,f=0; char ch;
	while(!isdigit(ch=getchar())) if(ch=='-') f=1;
	for(x=(ch^48); isdigit(ch=getchar()); x=(x<<1)+(x<<3)+(ch^48));
	return f?-x:x;
}
struct Sagiri { int ls,rs,dat; } t[N<<5]; //动态开点要存左右儿子的编号
int rt[N],tot; //rt[i]表示 1~i 次修改后的线段树
void change(int &p, int q, int l, int r, int x, int v) //动态开点不需要 build,直接一个个 change 上去不香吗
{
	t[p=++tot]=t[q]; //这行很重要,表示新的线段树要在原来线段树的基础上改,要继承原来的左右儿子和 dat
	if(l==r) { t[p].dat+=v; return ; }
	int mid=(l+r)>>1;
	if(x<=mid) change(t[p].ls,t[q].ls,l,mid,x,v);
	else if(r>mid) change(t[p].rs,t[q].rs,mid+1,r,x,v);
	t[p].dat=t[t[p].ls].dat+t[t[p].rs].dat;
}
int query(int p, int q, int l, int r, int k) //可持久化线段树上二分 
{
	if(l==r) return l;
	int mid=(l+r)>>1;
	int lcnt=t[t[p].ls].dat-t[t[q].ls].dat; //表示的是左儿子的个数
	if(k<=lcnt) return query(t[p].ls,t[q].ls,l,mid,k); //如果询问的 k<=lcnt, 就在左子树里搜
	else return query(t[p].rs,t[q].rs,mid+1,r,k-lcnt); //否则就去右子树
}
int main()
{
	n=rd(),m=rd();
	for(int i=1; i<=n; i++) scanf("%d",&a[i]);
	for(int i=1; i<=n; i++) change(rt[i],rt[i-1],-M,M,a[i],1);
	int L,R,k;
	while(m--)
	{
		L=rd(),R=rd(),k=rd();
		printf("%d\n",query(rt[R],rt[L-1],-M,M,k));
	}
	return 0;
}

例题2

就是那次测试的 T4,

如果要维护区间前 \(k\) 大和,也是一样的

代码

struct Sagiri { int ls,rs,cnt; ll dat; } t[N<<5]; 
//日常把纱雾老婆拿来命名线段树结构体 要维护区间的个数 cnt 和区间之和 dat
int rt[N];
void change(int &p, int q, ll l, ll r, int x)
{
	t[p=++tot]=t[q];
	if(l==r) { t[p].cnt++; t[p].dat+=l; return ; }
	ll mid=(l+r)>>1;
	if(x<=mid) change(t[p].ls,t[q].ls,l,mid,x);
	else change(t[p].rs,t[q].rs,mid+1,r,x);
	t[p].dat=t[t[p].ls].dat+t[t[p].rs].dat;
	t[p].cnt=t[t[p].ls].cnt+t[t[p].rs].cnt;
}
ll query_sumk(int p, int q, ll l, ll r, int k)
{
	if(l==r) return (ll)l*k;
	ll mid=(l+r)>>1;
	int rcnt=t[t[p].rs].cnt-t[t[q].rs].cnt;
	if(k<=rcnt) return query_sumk(t[p].rs,t[q].rs,mid+1,r,k);
	else return query_sumk(t[p].ls,t[q].ls,l,mid,k-rcnt)+t[t[p].rs].dat-t[t[q].rs].dat; //这边记得要减
	
}
posted @ 2022-11-13 21:45  copper_carbonate  阅读(15)  评论(0编辑  收藏  举报