ch_g

ECUST_ACMer —— ch_g
  博客园  :: 首页  :: 新随笔  :: 联系 :: 订阅 订阅  :: 管理

[POJ][3580][SuperMemo][伸展树]

Posted on 2011-08-16 19:59  ch_g  阅读(769)  评论(0编辑  收藏  举报

伸展树真是一种犀利的数据结构阿。一开始用数组模拟建树,先WA再TLE,代码很挫。然后问曾哥要了他的伸展树代码。略加改动和优化后就成了自己的模板。

/*
题目:SuperMemo
题目来源:POJ 3580
题目内容或思路:
伸展树
模板题
做题日期:2011.8.16
*/
#include
<iostream>
#include
<cstdio>
#include
<string>
#include
<cstring>
#include
<string>
#include
<queue>
#include
<cmath>
#include
<set>
#include
<map>
#include
<algorithm>
using namespace std;

#define forn(i, n) for (int i = 0; i < (int)(n); ++i)
#define clr(a, b) memset(a, b, sizeof(a))
#define SZ(a) ((int)a.size())
#define MP make_pair
#define PB push_back
#define inf 0x3f3f3f3f
typedef pair
<int, int> pii;
typedef vector
<int> vi;
typedef
long long ll;

/*
* Splay Tree
* 所处理的数组下标为1-N,为实现方便,在0和N+1的位置增加一个key为inf的结点
* select()函数中的kth与实际下边的关系如下
* inf - num - num - num - num - inf
* 0 1 2 3 4 5
* 另外用null节点替换空指针
*/

const int MAX = 200005;

struct node {
int key, size;
int rev, minv, delta;
node
*ch[2], *pre;
void add(int v) {
if (size == 0) return;
delta
+= v;
minv
+= v;
key
+= v;
}
void reverse() {
if (size == 0) return;
rev
^= 1;
swap(ch[
0], ch[1]);
}
void update() {
size
= ch[0]->size + ch[1]->size + 1;
minv
= min(key, min(ch[0]->minv, ch[1]->minv));
}
void pushdown() {
if (delta) {
ch[
0]->add(delta);
ch[
1]->add(delta);
}
if (rev) {
ch[
0]->reverse();
ch[
1]->reverse();
}
delta
= rev = 0;
}
};

int arr[MAX];

#define keytree root->ch[1]->ch[0]
class Splay {
int cnt, top;
node
*stk[MAX], data[MAX];
public:
node
*root, *null;
/*
* 获得一个新的节点,之前删除的节点会放到stk中以便再利用
*/
node
*Newnode(int var) {
node
*p;
if (top) p = stk[top--];
else p = &data[cnt++];
p
->key = p->minv = var;
p
->size = 1;
p
->delta = p->rev = 0;
p
->ch[0] = p->ch[1] = p->pre = null;
return p;
}

void init() {
top
= cnt = 0;
null = Newnode(inf);
null->size = 0;
root
= Newnode(inf);
root
->ch[1] = Newnode(inf);
root
->ch[1]->pre = root;
root
->update();
}

/*
* 用arr数组中[l,r]区间内的值建树
*/
node
*build(int l, int r) {
if (l > r) return null;
int mid = (l + r) >> 1;
node
*p = Newnode(arr[mid]);
p
->ch[0] = build(l, mid - 1);
p
->ch[1] = build(mid + 1, r);
if (p->ch[0] != null)
p
->ch[0]->pre = p;
if (p->ch[1] != null)
p
->ch[1]->pre = p;
p
->update();
return p;
}

/*
* 旋转操作, c=0 表示左旋, c=1 表示右旋
*/
void rotate(node *x, int c) {
node
*y = x->pre;
y
->pushdown();
x
->pushdown();
y
->ch[!c] = x->ch[c];
if (x->ch[c] != null)
x
->ch[c]->pre = y;
x
->pre = y->pre;
if (y->pre != null)
y
->pre->ch[ y == y->pre->ch[1] ] = x;
x
->ch[c] = y;
y
->pre = x;
y
->update();
if (y == root) root = x;
}

/*
* 旋转使x成为f的子节点,若f为null则x旋转为根节点
*/
void splay(node *x, node *f) {
x
->pushdown();
while (x->pre != f) {
if (x->pre->pre == f) {
rotate(x, x
->pre->ch[0] == x);
break;
}
node
*y = x->pre;
node
*z = y->pre;
int c = (y == z->ch[0]);
if (x == y->ch[c]) {
rotate(x,
!c); rotate(x, c); // 之字形旋转
} else {
rotate(y, c); rotate(x, c);
// 一字形旋转
}
}
// x->pushdown();
x->update();
}

/*
* 找到位置为k的节点,并将其升至x的儿子
*/
void select(int kth, node *x) {
node
*cur = root;
while (true) {
cur
->pushdown();
// cur->update();
int tmp = cur->ch[0]->size;
if (tmp == kth) break;
else if (tmp < kth) {
kth
-= tmp + 1;
cur
= cur->ch[1];
}
else {
cur
= cur->ch[0];
}
}
splay(cur, x);
}

/*
* 区间增加key值 "add(2, 4, 1)" on {1, 2, 3, 4, 5} results in {1, 3, 4, 5, 5}
*/
void add(int x, int y, int d) {
select(x
- 1, null);
select(y
+ 1, root);
keytree
->add(d);
splay(keytree,
null);
}

/*
* 区间倒序 "reverse(2,4)" on {1, 2, 3, 4, 5} results in {1, 4, 3, 2, 5}
*/
void reverse(int x, int y) {
select(x
- 1, null);
select(y
+ 1, root);
keytree
->reverse();
}

/*
* 区间[x, y]循环右移d,实质是交换区间[a, b]和[b + 1, c],其中b = y - d % (y - x + 1)
* "revolve(2, 4, 2)" on {1, 2, 3, 4, 5} results in {1, 3, 4, 2, 5}
* 做法:将b+1位置的节点x升至根节点,将c+1位置的节点y升至x的右儿子,将c位置的节点z升至y的左儿子
* 将a-1位置的节点v升至x的左儿子,此时v的右儿子即是[a, b],将其赋给z的右儿子。
* 当d = 1时,节点x与节点z是同一节点,特殊处理。
*/
void revolve(int x, int y, int d) {
int len = y - x + 1;
d
= (d % len + len) % len;
if (d == 0) return;

// if (d == 1) {
// select(y, null);
// select(y + 1, root);
// select(x - 1, root);
// node *p = root->ch[0]->ch[1];
// root->ch[0]->ch[1] = null;
// root->ch[0]->update();
// root->ch[1]->ch[0] = p;
// p->pre = root->ch[1];
// splay(p, null);
// }

if (d == 1) {
del(y);
insert(x
- 1, stk[top]->key);
}
else {
select(y
- d + 1, null);
select(y
+ 1, root);
select(x
- 1, root);
select(y, root
->ch[1]);
node
*p = root->ch[0]->ch[1];
root
->ch[0]->ch[1] = null;
root
->ch[0]->update();
root
->ch[1]->ch[0]->ch[1] = p;
p
->pre = root->ch[1]->ch[0];
splay(p,
null);
}
}

/*
* 在X位置后插入值为Y的节点。
* "insert(2,4)" on {1, 2, 3, 4, 5} results in {1, 2, 4, 3, 4, 5}
* 做法:将X位置的节点a升至根节点,再将X+1位置的节点b升至a的右儿子
* 此时b的左儿子一定为空, 将新插入的节点作为b的左儿子。
*/
void insert(int x, int y) {
select(x,
null);
select(x
+ 1, root);
keytree
= Newnode(y);
keytree
->pre = root->ch[1];
root
->ch[1]->update();
splay(keytree,
null);
}

/*
* 删除X位置的数。
* "DELETE(2)" on {1, 2, 3, 4, 5} results in {1, 3, 4, 5}
* 做法:找到并将其升至根节点,以其右子树的最左边节点替换之
*/
void del(int x) {
select(x,
null);
node
*oldRoot = root;
root
= root->ch[1];
root
->pre = null;
select(
0, null);
root
->ch[0] = oldRoot->ch[0];
root
->ch[0]->pre = root;
root
->update();
stk[
++top] = oldRoot;
}

/*
* 求区间最小值
* "MIN(2,4)" on {1, 2, 3, 4, 5} is 2
* 做法:找到X-1位置上的节点a并将其升至根节点,再找到Y+1位置上的
* 的节点b并将其作为a的右儿子。则b的左儿子即所求区间。
*/
int getMin(int x, int y) {
select(x
- 1, null);
select(y
+ 1, root);
return keytree->minv;
}

void debug() {vis(root);}
void vis(node* t) {
if (t == null) return;
vis(t
->ch[0]);
printf(
"node%2d:lson %2d,rson %2d,pre %2d,sz=%2d,key=%2d\n",
t
- data, t->ch[0] - data, t->ch[1] - data,
t
->pre - data, t->size, t->key);
vis(t
->ch[1]);
}
} spt;

int main() {
#ifdef CHEN_PC
freopen(
"in", "r", stdin);
#endif
int n;
while (scanf("%d", &n) != EOF) {
for (int i = 1; i <= n; i++) {
scanf(
"%d", &arr[i]);
}
spt.init();
if (n > 0) {
node
*troot = spt.build(1, n);
spt.keytree
= troot;
troot
->pre = spt.root->ch[1];
spt.splay(troot, spt.
null);
}
int q, x, y, d;
scanf(
"%d", &q);
char cmd[20];
// spt.debug();
while (q--) {
scanf(
"%s", cmd);
if (!strcmp(cmd, "ADD")) {
scanf(
"%d%d%d", &x, &y, &d);
spt.add(x, y, d);
}
else if (!strcmp(cmd, "REVOLVE")) {
scanf(
"%d%d%d", &x, &y, &d);
spt.revolve(x, y, d);
}
else if (!strcmp(cmd, "REVERSE")) {
scanf(
"%d%d", &x, &y);
spt.reverse(x, y);
}
else if (!strcmp(cmd, "INSERT")) {
scanf(
"%d%d", &x, &y);
spt.insert(x, y);
}
else if (!strcmp(cmd, "DELETE")) {
scanf(
"%d", &x);
spt.del(x);
}
else if (!strcmp(cmd, "MIN")) {
scanf(
"%d%d", &x, &y);
printf(
"%d\n", spt.getMin(x, y));
}
// puts("=====================");
// spt.debug();
}
}
}