Splay 伸展树

废话不说,有篇论文可供参考:杨思雨:《伸展树的基本操作与应用》

Splay的好处可以快速分裂和合并。

===============================14.07.26更新=============================

实在看不惯那充满bug的指针树了!动不动就re!动不动就re!调试调个老半天,谁有好的调试技巧为T_T

好不容易写了个模板splay出来,指针的,好写,核心代码rotate和splay能压缩到10行。

#include <cstdio>
using namespace std;

class Splay {
public:
//这个版本的splay会更改null的fa和ch,但是对结果没影响,这点要牢记(其实处理好rot里的fa->fa->setc那里就行了,也就是特判一下)
	struct node {
		node* ch[2], *fa;
		int key, size;
		node() { ch[0]=ch[1]=fa=0; key=size=0; }
		void pushup() { size=ch[0]->size+ch[1]->size+1; }
		bool d() { return fa->ch[1]==this; }
		void setc(node* c, bool d) { ch[d]=c; c->fa=this; }
	}*null, * root;
	Splay() { 
		null=new node;
		null->ch[0]=null->ch[1]=null->fa=null;
		root=null;
	}
	void rot(node* rt) {
		node* fa=rt->fa; bool d=rt->d();
		fa->fa->setc(rt, fa->d()); //这里要注意,因为d()返回的是路径,所以不要高反了。
		fa->setc(rt->ch[!d], d); 
		rt->setc(fa, !d);
		fa->pushup();
		if(root==fa) root=rt;
	}
	node* newnode(const int &key) {
		node* ret=new node();
		ret->key=key; ret->size=1;
		ret->ch[0]=ret->ch[1]=ret->fa=null;
		return ret;
	}
	void splay(node* rt, node* fa) {
		while(rt->fa!=fa)
			if(rt->fa->fa==fa) rot(rt);
			else rt->d()==rt->fa->d()?(rot(rt->fa), rot(rt)):(rot(rt), rot(rt));
		rt->pushup();
	}
	void insert(const int &key) {
		if(root==null) { root=newnode(key); return; }
		node* t=root;
		while(t->ch[key>t->key]!=null) t=t->ch[key>t->key];
		node* c=newnode(key);
		t->setc(c, key>t->key);
		t->pushup();
		splay(c, null);
	}
	void remove(const int &key) {
	//remove是各种坑,各种注意细节判断null
		node* t=root;
		while(t!=null && t->key!=key) t=t->ch[key>t->key];
		if(t==null) return;
		splay(t, null);
		node* rt=root->ch[0];
		if(rt==null) rt=root->ch[1];
		else {
			node* m=rt->ch[1];
			while(m!=null && m->ch[1]!=null) m=m->ch[1];
			if(m!=null) splay(m, root);
			rt=root->ch[0];
			rt->setc(root->ch[1], 1);
		}
		delete root;
		root=rt;
		root->fa=null;
		if(root!=null) root->pushup(); //这里一定要注意,因为咱这是会修改null的孩子的!!
	}
	int rank(const int &key, node* rt) {
		if(rt==null) return 0;
		int s=rt->ch[0]->size+1;
		if(key==rt->key) return s;
		if(key>rt->key) return s+rank(key, rt->ch[1]);
		else return rank(key, rt->ch[0]);
	}
	node* select(const int &k, node* rt) {
		if(rt==null) return null;
		int s=rt->ch[0]->size+1;
		if(s==k) return rt;
		if(k>s) return select(k-s, rt->ch[1]);
		else return select(k, rt->ch[0]);
	}
};

int main() {
	
	return 0;
}

 

 

==============================================================

上代码(指针版,可能有潜在bug)

#include <iostream>
#include <string>
using namespace std;

#define F(rt) rt-> pa
#define K(rt) rt-> key
#define CH(rt, d) rt-> ch[d]
#define C(rt, d) (K(rt) > d ? 0 : 1)
#define NEW(d) new Splay(d)
#define PRE(rt) F(rt) = CH(rt, 0) = CH(rt, 1) = null

struct Splay {
	Splay* ch[2], *pa;
	int key;
	Splay(int d = 0) : key(d) { ch[0] = ch[1] = pa = NULL; }
};

typedef Splay* tree;
tree null = new Splay, root = null;

void rot(tree& rt, int d) {
	tree k = CH(rt, d^1), u = F(rt); int flag = CH(u, 1) == rt;
	CH(rt, d^1) = CH(k, d); if(CH(k, d) != null) F(CH(k, d)) = rt;
	CH(k, d) = rt; F(rt) = k; rt = k; F(rt) = u;
	if(u != null) CH(u, flag) = k; //这里不能等于rt,否则会死的
}

void splay(tree nod, tree& rt) {
	if(nod == null) return;
	tree pa = F(rt); //要先记录rt的父亲节点,因为rt会变的,所以判断条件里不能有rt
	while(F(nod) != pa) {
		if(F(nod) == rt)
			rot(rt, CH(rt, 0) == nod); //当这步完后,rt会改变
		else {
			int d  = CH(F(F(nod)), 0) == F(nod); //记录nod父亲是nod父亲的父亲的哪个儿子,1为左儿子,0为右儿子
			int d2 = CH(F(nod), 0)	  == nod; //同上
			if(d == d2) { rot(F(F(nod)), d); rot(F(nod), d2); } //当路线相同时,先转父亲的父亲,后转父亲
			else 		{ rot(F(nod), d2);	 rot(F(nod), d); } //当路线不同时,先转父亲,再转父亲的父亲 (在这里,第一次转后,F(nod)就是父亲的父亲了,因为第一次转后,nod改变了啦。)
		}
	}
	rt = nod; //在这里还要重新赋值给引用的指针,不然照样会死
}

tree maxmin(tree rt, int d) { //d=0时找最小,d=1时找最大
	if(rt == null) return null;
	while(CH(rt, d) != null) rt = CH(rt, d);
	return rt;
}

tree ps(tree rt, int d) { //d=0时找前驱,d=1时找后继
	if(rt == null) return null;
	rt = CH(rt, d);
	return maxmin(rt, d^1);
}

tree search(tree& rt, int d) {
	if(rt == null) return null;
	tree t = rt;
	while(t != null && K(t) != d) t = CH(t, C(t, d));
	splay(t, rt); //搜索到记得splay
	return t;
}

void insert(tree& rt, int d) {
	tree q = NULL, t = rt;
	while(t != null) q = t, t = CH(t, C(t, d));
	t = new Splay(d);
	PRE(t); //初始化他的3个指针
	if(q) F(t) = q, CH(q, C(q, d)) = t; //如果他有父亲,那么就要初始化他的父亲指针和他父亲的儿子指针
	else rt = t; //如果没有,那么t就是以rt为根的这棵树的根
	splay(t, rt);
}

void del(tree& rt, int d) {
	if(search(rt, d) == null) return; //搜索的这一次,如果有就顺便将他splay到根了,所以下面不需要再splay
	tree t = rt; //这个根就是要删除的
	if(CH(t, 0) == null) t = CH(rt, 1); //如果没有左子女,就直接赋值喽
	else { //如果有,就要将树的根设置为左子女最大元素,然后将右子树赋值为新树的右子女
		t = CH(rt, 0); //t就是新根
		splay(ps(rt, 0), t); //找到旧根的后继(其实就是新根的最大值)并设置为新树的根
		CH(t, 1) = CH(rt, 1); //然后将右子树设置为原根的右子树
		if(CH(rt, 1) != null) F(CH(rt, 1)) = t; //在这里,如果有这个右子树,还要将右子树的父亲指针设置为新根,否则又要挂
	}
	delete rt; //删除
	F(t) = null; //设置新根的父亲指针
	rt = t; 
}

void out(string str) {
	cout << str;
}

int main() {
	out("1: insert\n2: del\n3: search\n4: max\n5: min\n");
	PRE(null); //一定记得初始化null指针的其它指针,使他们全部指向自己
	int c, t;
	tree a;
	while(cin >> c) {
		switch(c) {
		case 1: cin >> t;
				insert(root, t);
				break;
		case 2: cin >> t;
				del(root, t);
				break;
		case 3: cin >> t;
				if(search(root, t) == null) out("Not here\n");
				else out("Is here!\n");
				break;
		case 4: a = maxmin(root, 1);
				if(a != null) cout << a-> key << endl;
				else out("Warn!\n");
				break;
		case 5: a = maxmin(root, 0);
				if(a != null) cout << a-> key << endl;
				else out("Warn!\n");
				break;
		default:
				break;
		}
	}

	return 0;
}

  

另写了一个数组的,比指针的好写多了,而且很精简。

自底向上的splay,还没有自带回收内存,支持区间修改,区间求和的模板。(是http://www.wikioi.com/problem/1082/的题)

#include <cstdio>
using namespace std;

#define K(x) key[x]
#define S(x) size[x]
#define C(x, d) ch[x][d]
#define F(x) fa[x]
#define L(x) ch[x][0]
#define R(x) ch[x][1]
#define keytree L(R(root))
#define LL long long

const int maxn = 222222;
int size[maxn], key[maxn], fa[maxn], ch[maxn][2], add[maxn];
LL sum[maxn];
int tot, root;
int arr[maxn];

void newnode(int& x, int k, int f) {
	x = ++tot; F(x) = f; S(x) = 1;
	K(x) = sum[x] = k;
}

void pushup(int x) {
	S(x) = S(R(x)) + S(L(x)) + 1;
	sum[x] = sum[R(x)] + sum[L(x)] + K(x) + add[x];
}

void pushdown(int x) {
	if(add[x]) {
		K(x) += add[x];
		add[R(x)] += add[x];
		add[L(x)] += add[x];
		sum[R(x)] += (LL)add[x] * (LL)S(R(x));
		sum[L(x)] += (LL)add[x] * (LL)S(L(x));
		add[x] = 0;
	}
}

void rot(int x, int c) {
	int y = F(x);
	pushdown(y); pushdown(x);
	C(y, !c) = C(x, c); F(C(x, c)) = y;
	C(x, c) = y; F(x) = F(y); F(y) = x;
	if(F(x)) C(F(x), R(F(x)) == y) = x;
	pushup(y);
}

void splay(int x, int y) {
	if(!x) return;
	pushdown(x);
	while(F(x) != y) {
		if(F(F(x)) == y) rot(x, L(F(x)) == x);
		else {
			int d1 = L(F(F(x))) == F(x);
			int d2 = L(F(x)) == x;
			if(d1 == d2) { rot(F(x), d1); rot(x, d2); }
			else { rot(x, d2); rot(x, d1); }
		}
	}
	pushup(x);
	if(!y) root = x;
}

void insert(int k) {
	int x = root;
	while(C(x, k > K(x))) x = C(x, k > K(x));
	newnode(C(x, k > K(x)), k, x);
	splay(C(x, k > K(x)), 0);
}

int sel(int k, int x) {
	for(pushdown(x); S(L(x))+1 != k; pushdown(x))
		if(k <= S(L(x))) x = L(x);
		else k -= (S(L(x))+1), x = R(x);
	return x;
}

//特定的select,因为多插了2个边界节点,所以和原版的select有区别
int vsel(int k, int x) {
	for(pushdown(x); S(L(x)) != k; pushdown(x))
		if(k < S(L(x))) x = L(x);
		else k -= (S(L(x)) + 1), x = R(x);
	return x;
}

void build(int l, int r, int& rt, int f) {
	if(l > r) return;
	int mid = (l+r) >> 1;
	newnode(rt, arr[mid], f);
	build(l, mid-1, L(rt), rt); build(mid+1, r, R(rt), rt);
	pushup(rt);
}

void query() {
	int l, r;
	scanf("%d%d", &l, &r);
	splay(vsel(l-1, root), 0);
	splay(vsel(r+1, root), root);
	printf("%lld\n", sum[keytree]);
}

void updata() {
	int l, r, _add;
	scanf("%d%d%d", &l, &r, &_add);
	splay(vsel(l-1, root), 0);
	splay(vsel(r+1, root), root);
	sum[keytree] += (LL)_add * (LL)S(keytree);
	add[keytree] += (LL)_add;
}

int n, q, t;

void init() {
	for(int i = 1; i <= n; ++i) scanf("%d", &arr[i]);
	newnode(root, -1, 0);
	newnode(R(root), -1, root);
	S(root) = 2;
	build(1, n, keytree, R(root));
	pushup(R(root));
	pushup(root);
}

int main() {
	scanf("%d", &n);
	init();
	scanf("%d", &q);
	while(q--) {
		scanf("%d", &t);
		if(t == 1) updata();
		else query();
	}
	return 0;
}

  

posted @ 2014-02-01 18:55  iwtwiioi  阅读(729)  评论(0编辑  收藏  举报