伸展树真是一种犀利的数据结构阿。一开始用数组模拟建树,先WA再TLE,代码很挫。然后问曾哥要了他的伸展树代码。略加改动和优化后就成了自己的模板。
/*
题目:SuperMemo
题目来源:POJ 3580
题目内容或思路:
伸展树
模板题
做题日期:2011.8.16
*/
#include <iostream>
#include <cstdio>
#include <string>
#include <cstring>
#include <string>
#include <queue>
#include <cmath>
#include <set>
#include <map>
#include <algorithm>
using namespace std;
#define forn(i, n) for (int i = 0; i < (int)(n); ++i)
#define clr(a, b) memset(a, b, sizeof(a))
#define SZ(a) ((int)a.size())
#define MP make_pair
#define PB push_back
#define inf 0x3f3f3f3f
typedef pair<int, int> pii;
typedef vector<int> vi;
typedef long long ll;
/*
* Splay Tree
* 所处理的数组下标为1-N,为实现方便,在0和N+1的位置增加一个key为inf的结点
* select()函数中的kth与实际下边的关系如下
* inf - num - num - num - num - inf
* 0 1 2 3 4 5
* 另外用null节点替换空指针
*/
const int MAX = 200005;
struct node {
int key, size;
int rev, minv, delta;
node *ch[2], *pre;
void add(int v) {
if (size == 0) return;
delta += v;
minv += v;
key += v;
}
void reverse() {
if (size == 0) return;
rev ^= 1;
swap(ch[0], ch[1]);
}
void update() {
size = ch[0]->size + ch[1]->size + 1;
minv = min(key, min(ch[0]->minv, ch[1]->minv));
}
void pushdown() {
if (delta) {
ch[0]->add(delta);
ch[1]->add(delta);
}
if (rev) {
ch[0]->reverse();
ch[1]->reverse();
}
delta = rev = 0;
}
};
int arr[MAX];
#define keytree root->ch[1]->ch[0]
class Splay {
int cnt, top;
node *stk[MAX], data[MAX];
public:
node *root, *null;
/*
* 获得一个新的节点,之前删除的节点会放到stk中以便再利用
*/
node *Newnode(int var) {
node *p;
if (top) p = stk[top--];
else p = &data[cnt++];
p->key = p->minv = var;
p->size = 1;
p->delta = p->rev = 0;
p->ch[0] = p->ch[1] = p->pre = null;
return p;
}
void init() {
top = cnt = 0;
null = Newnode(inf);
null->size = 0;
root = Newnode(inf);
root->ch[1] = Newnode(inf);
root->ch[1]->pre = root;
root->update();
}
/*
* 用arr数组中[l,r]区间内的值建树
*/
node *build(int l, int r) {
if (l > r) return null;
int mid = (l + r) >> 1;
node *p = Newnode(arr[mid]);
p->ch[0] = build(l, mid - 1);
p->ch[1] = build(mid + 1, r);
if (p->ch[0] != null)
p->ch[0]->pre = p;
if (p->ch[1] != null)
p->ch[1]->pre = p;
p->update();
return p;
}
/*
* 旋转操作, c=0 表示左旋, c=1 表示右旋
*/
void rotate(node *x, int c) {
node *y = x->pre;
y->pushdown();
x->pushdown();
y->ch[!c] = x->ch[c];
if (x->ch[c] != null)
x->ch[c]->pre = y;
x->pre = y->pre;
if (y->pre != null)
y->pre->ch[ y == y->pre->ch[1] ] = x;
x->ch[c] = y;
y->pre = x;
y->update();
if (y == root) root = x;
}
/*
* 旋转使x成为f的子节点,若f为null则x旋转为根节点
*/
void splay(node *x, node *f) {
x->pushdown();
while (x->pre != f) {
if (x->pre->pre == f) {
rotate(x, x->pre->ch[0] == x);
break;
}
node *y = x->pre;
node *z = y->pre;
int c = (y == z->ch[0]);
if (x == y->ch[c]) {
rotate(x, !c); rotate(x, c); // 之字形旋转
} else {
rotate(y, c); rotate(x, c); // 一字形旋转
}
}
// x->pushdown();
x->update();
}
/*
* 找到位置为k的节点,并将其升至x的儿子
*/
void select(int kth, node *x) {
node *cur = root;
while (true) {
cur->pushdown();
// cur->update();
int tmp = cur->ch[0]->size;
if (tmp == kth) break;
else if (tmp < kth) {
kth -= tmp + 1;
cur = cur->ch[1];
} else {
cur = cur->ch[0];
}
}
splay(cur, x);
}
/*
* 区间增加key值 "add(2, 4, 1)" on {1, 2, 3, 4, 5} results in {1, 3, 4, 5, 5}
*/
void add(int x, int y, int d) {
select(x - 1, null);
select(y + 1, root);
keytree->add(d);
splay(keytree, null);
}
/*
* 区间倒序 "reverse(2,4)" on {1, 2, 3, 4, 5} results in {1, 4, 3, 2, 5}
*/
void reverse(int x, int y) {
select(x - 1, null);
select(y + 1, root);
keytree->reverse();
}
/*
* 区间[x, y]循环右移d,实质是交换区间[a, b]和[b + 1, c],其中b = y - d % (y - x + 1)
* "revolve(2, 4, 2)" on {1, 2, 3, 4, 5} results in {1, 3, 4, 2, 5}
* 做法:将b+1位置的节点x升至根节点,将c+1位置的节点y升至x的右儿子,将c位置的节点z升至y的左儿子
* 将a-1位置的节点v升至x的左儿子,此时v的右儿子即是[a, b],将其赋给z的右儿子。
* 当d = 1时,节点x与节点z是同一节点,特殊处理。
*/
void revolve(int x, int y, int d) {
int len = y - x + 1;
d = (d % len + len) % len;
if (d == 0) return;
// if (d == 1) {
// select(y, null);
// select(y + 1, root);
// select(x - 1, root);
// node *p = root->ch[0]->ch[1];
// root->ch[0]->ch[1] = null;
// root->ch[0]->update();
// root->ch[1]->ch[0] = p;
// p->pre = root->ch[1];
// splay(p, null);
// }
if (d == 1) {
del(y);
insert(x - 1, stk[top]->key);
} else {
select(y - d + 1, null);
select(y + 1, root);
select(x - 1, root);
select(y, root->ch[1]);
node *p = root->ch[0]->ch[1];
root->ch[0]->ch[1] = null;
root->ch[0]->update();
root->ch[1]->ch[0]->ch[1] = p;
p->pre = root->ch[1]->ch[0];
splay(p, null);
}
}
/*
* 在X位置后插入值为Y的节点。
* "insert(2,4)" on {1, 2, 3, 4, 5} results in {1, 2, 4, 3, 4, 5}
* 做法:将X位置的节点a升至根节点,再将X+1位置的节点b升至a的右儿子
* 此时b的左儿子一定为空, 将新插入的节点作为b的左儿子。
*/
void insert(int x, int y) {
select(x, null);
select(x + 1, root);
keytree = Newnode(y);
keytree->pre = root->ch[1];
root->ch[1]->update();
splay(keytree, null);
}
/*
* 删除X位置的数。
* "DELETE(2)" on {1, 2, 3, 4, 5} results in {1, 3, 4, 5}
* 做法:找到并将其升至根节点,以其右子树的最左边节点替换之
*/
void del(int x) {
select(x, null);
node *oldRoot = root;
root = root->ch[1];
root->pre = null;
select(0, null);
root->ch[0] = oldRoot->ch[0];
root->ch[0]->pre = root;
root->update();
stk[++top] = oldRoot;
}
/*
* 求区间最小值
* "MIN(2,4)" on {1, 2, 3, 4, 5} is 2
* 做法:找到X-1位置上的节点a并将其升至根节点,再找到Y+1位置上的
* 的节点b并将其作为a的右儿子。则b的左儿子即所求区间。
*/
int getMin(int x, int y) {
select(x - 1, null);
select(y + 1, root);
return keytree->minv;
}
void debug() {vis(root);}
void vis(node* t) {
if (t == null) return;
vis(t->ch[0]);
printf("node%2d:lson %2d,rson %2d,pre %2d,sz=%2d,key=%2d\n",
t - data, t->ch[0] - data, t->ch[1] - data,
t->pre - data, t->size, t->key);
vis(t->ch[1]);
}
} spt;
int main() {
#ifdef CHEN_PC
freopen("in", "r", stdin);
#endif
int n;
while (scanf("%d", &n) != EOF) {
for (int i = 1; i <= n; i++) {
scanf("%d", &arr[i]);
}
spt.init();
if (n > 0) {
node *troot = spt.build(1, n);
spt.keytree = troot;
troot->pre = spt.root->ch[1];
spt.splay(troot, spt.null);
}
int q, x, y, d;
scanf("%d", &q);
char cmd[20];
// spt.debug();
while (q--) {
scanf("%s", cmd);
if (!strcmp(cmd, "ADD")) {
scanf("%d%d%d", &x, &y, &d);
spt.add(x, y, d);
} else if (!strcmp(cmd, "REVOLVE")) {
scanf("%d%d%d", &x, &y, &d);
spt.revolve(x, y, d);
} else if (!strcmp(cmd, "REVERSE")) {
scanf("%d%d", &x, &y);
spt.reverse(x, y);
} else if (!strcmp(cmd, "INSERT")) {
scanf("%d%d", &x, &y);
spt.insert(x, y);
} else if (!strcmp(cmd, "DELETE")) {
scanf("%d", &x);
spt.del(x);
} else if (!strcmp(cmd, "MIN")) {
scanf("%d%d", &x, &y);
printf("%d\n", spt.getMin(x, y));
}
// puts("=====================");
// spt.debug();
}
}
}