[算法学习] 平衡树 splay

基本信息维护

struct Splay {
	int sz, fa, val, cnt, ch[2];
	//sz 表示子树大小
	//ch[0 / 1]表示左右两个儿子
	//fa 表示父亲
	//cnt 表示该点的值出现了几次
	//val 节点的值
} t[N];

#define ls(p) (t[p].ch[0])
#define rs(p) (t[p].ch[1])

新建节点(采用回收空间的写法)

int top, rb[N];

inline int newnode(ll _val) {
	int x = top ? rb[top -- ] : ++ sz;
	t[x].ch[0] = t[x].ch[1] = t[x].fa = 0; t[x].sz = 1;
	t[x].val = t[x].sum = _val; t[x].cnt = _val < 0 ? 1 : 0;
	return x;
}

Rotate

1.X变到原来Y的位置
2.Y变成了 X原来在Y的 相对的那个儿子
3.Y的非X的儿子不变 X的 X原来在Y的 那个儿子不变
4.X的 X原来在Y的 相对的 那个儿子 变成了 Y原来是X的那个儿子
inline void rotate(int p) {
	int f = t[p].fa, d = t[f].ch[0] == p ? 1 : 0;
	t[p].fa = t[f].fa; t[f].fa = p; t[t[p].ch[d]].fa = f;
	t[f].ch[d ^ 1] = t[p].ch[d]; t[p].ch[d] = f;
	if(t[p].fa) t[t[p].fa].ch[t[t[p].fa].ch[0] == f ? 0 : 1] = p;
	pushup(f); pushup(p);
}

Splay

inline void splay(int p, int goal = 0) {
	while(t[p].fa != goal) {
		int f = t[p].fa, g = t[f].fa;
		if(g != goal) rotate((t[f].ch[0] == p) == (t[g].ch[0] == f) ? f : p);
		rotate(p);
	}
	if(!goal) rt = p;
}

建树

inline int build(int l, int r, int *a) {
	if(l > r) return 0;
	if(l == r) return newnode(a[l]);
	int mid = (l + r) >> 1, p = newnode(a[mid]);
	ls(p) = build(l, mid - 1, a);
	rs(p) = build(mid + 1, r, a);
	t[ls(p)].fa = t[rs(p)].fa = p;
	pushup(p);
	return p;
}

int main() {
	rt = build(1, n, a);
}

在第\(x\)位置后插入长度为\(tot\)的区间

inline void insert(int x, int tot, int *a) {
	int p = build(1, tot, a);
	int u = find_kth(rt, x); splay(u);
	int v = find_kth(rt, x + 1); splay(v, u);
	t[v].ch[0] = p; t[p].fa = v; 
	pushup(v); pushup(u);
}

删除区间[l,r]

inline void recycle(int p) {
	if (ls(p)) recycle(ls(p));
	if (rs(p)) recycle(rs(p));
	rb[++top] = p;
}

inline void erase(int l, int r) {
	int u = find_kth(rt, l - 1); splay(u);
	int v = find_kth(rt, r + 1); splay(v, u);
	recycle(t[v].ch[0]); t[v].ch[0] = 0;
	pushup(v); pushup(u);
}

区间求和

inline int query_sum(int l, int r) {
	int u = find_kth(rt, l - 1); splay(u);
	int v = find_kth(rt, r + 1); splay(v, u);
	printf("%lld\n", t[ls(v)].sum);
}

找到区间排名为\(k\)的数

inline int find_kth(int p, int k) {
//	pushdown(p); (有lazy tag时这么写)
	if(t[ls(p)].sz + 1 == k) return p;
	else if(t[ls(p)].sz >= k) return find_kth(ls(p), k);
	else return find_kth(rs(p), k - t[ls(p)].sz - 1);
}

循环版找到区间排名为\(k\)的数

inline int find_kth(int k) {
	int now = rt;
	while(now) {
		if(k <= t[ls(now)].sz) now = ls(now);
		else {
			k -= t[ls(now)].sz + t[now].cnt;
			if(k <= 0) return t[now].val;
			else now = rs(now);
		}
	}
}

循环版找前驱和后缀

inline int getpre(int v, bool getid = false) {
	int ans = -INF, id = 0, now = rt;
	while(now) {
		if(t[now].val >= v) now = ls(now);
		else {
			if(t[now].val > ans) ans = t[now].val, id = now;
			now = rs(now);
		} 
	}
	return getid ? id : ans;
}

inline int getsuf(int v, bool getid = false) {
	int ans = INF, id = 0, now = rt;
	while(now) {
		if(t[now].val <= v) now = rs(now);
		else {
			if(t[now].val < ans) ans = t[now].val, id = now;
			now = ls(now);
		}
	}
	return getid ? id : ans;
}

找节点的前驱和后缀

inline int pre() {
	int x = t[rt].ch[0];
	while(t[x].ch[1]) x = t[x].ch[1];
	return x;
}

inline int suf() {
	int x = t[rt].ch[1];
	while(t[x].ch[0]) x = t[x].ch[0];
	return x;
}

\(rt\)根下加入一个权值为\(v\)的节点

inline void insert(int rt, int v) {
	int now = rt, f = 0;
	while(now && t[now].val != v) f = now, now = t[now].ch[v > t[now].val];
	if(now) t[now].cnt ++ ;
	else {now = newnode(v, f); if(f) t[f].ch[v > t[f].val] = now;}
	splay(now);
}

\(rt\)根下加入\(p\)节点

inline void insert(int p, int x) {
	int now = p, f = 0; reset(x);
	while(now) f = now, now = t[now].ch[t[now].val < t[x].val];
	t[f].ch[t[f].val < t[x].val] = x; 
	t[x].fa = f; splay(x);
}

\(splay\) 启发式合并

inline void pushup(int p) {
	t[p].sz = t[ls(p)].sz + t[rs(p)].sz + 1;
}

inline int root(int x) {
	while(t[x].fa) x = t[x].fa;
	return x;
}

inline void insert(int p, int x) { // 将x节点接在p树上 
	int now = p, f = 0; reset(x);
	while(now) f = now, now = t[now].ch[t[now].val < t[x].val];
	t[f].ch[t[f].val < t[x].val] = x; 
	t[x].fa = f; splay(x);
}

inline void dfs(int p, int u) { //把p节点弄到u上去
	if(ls(p)) dfs(ls(p), u);
	if(rs(p)) dfs(rs(p), u);
	insert(root(u), p);
}

inline void merge(int x, int y) {
	if(x == y) return;
	if(t[x].sz > t[y].sz) swap(x, y);
	dfs(x, y); //把x并到y树上去
}
int main() {
      merge(root(a), root(b));
}
posted @ 2020-08-04 23:41  Hock  阅读(239)  评论(0编辑  收藏  举报