poj 2892 Tunnel Warfare (Splay Tree instead of Segment Tree)
poj上的一道数据结构题,这题正解貌似是Segment Tree,不过我用了Splay Tree来写,而且我个人认为,这题用Splay Tree会更好写!
先简单解释一下题意:有n个连续的村庄,有以下几种操作(1)破坏一个村庄(2)问某个村庄与多少个村庄相连(包括它本身)(3)重建之前破坏了的村庄。
这道题用Splay Tree做要用到一个类似DLX(Dancing Links)的操作,就是结点的假删除,用到这个题上可以说是相当的巧妙的。我们在删除结点前把结点Splay到根的位置,然后删除的只是子结点指向父结点(也就是现在要删除的结点)的指针。不过父结点不回收,也就是他的两个子结点的指针并不删除。这样我们就可以快速的找回恢复用到的子结点了!每种操作借助splay来完成,十分简洁。时间还是可以接受的400ms。
可能有人会问,恢复一个结点,假设编号为x,用到它的两个子结点不是就是处于它两侧的x-1和x+1吗?当然,这也是可以的。这时,我们就要将0和n+1指向Null了。
代码如下:
View Code
1 #include <cstdio> 2 #include <cstdlib> 3 #include <cstring> 4 #include <algorithm> 5 6 using namespace std; 7 8 const int maxn = 5e4 + 5; 9 10 struct Node { 11 int cnt; 12 bool ex; 13 Node *c[2], *p; 14 15 Node(int _cnt = 0, bool _ex = false) { 16 cnt = _cnt; 17 c[0] = c[1] = p = NULL; 18 ex = _ex; 19 } 20 } *Null, *Root; 21 22 void up(Node *rt) { 23 rt->cnt = rt->c[0]->cnt + rt->c[1]->cnt + 1; 24 } 25 26 void rotate(Node *x, bool right) { 27 Node *y = x->p; 28 29 y->c[!right] = x->c[right]; 30 if (x->c[right] != Null) x->c[right]->p = y; 31 x->p = y->p; 32 if (y->p != Null) { 33 if (y->p->c[0] == y) y->p->c[0] = x; 34 else y->p->c[1] = x; 35 } 36 x->c[right] = y; 37 y->p = x; 38 up(y); 39 40 if (Root == y) Root = x; 41 } 42 43 void splay(Node *x, Node *f) { 44 while (x->p != f) { 45 if (x->p->p == f) { 46 if (x->p->c[0] == x) rotate(x, 1); 47 else rotate(x, 0); 48 } else { 49 Node *y = x->p, *z = y->p; 50 51 if (z->c[0] == y) { 52 if (y->c[0] == x) { 53 rotate(y, 1); 54 rotate(x, 1); 55 } else { 56 rotate(x, 0); 57 rotate(x, 1); 58 } 59 } else { 60 if (y->c[0] == x) { 61 rotate(x, 1); 62 rotate(x, 0); 63 } else { 64 rotate(y, 0); 65 rotate(x, 0); 66 } 67 } 68 } 69 } 70 up(x); 71 } 72 73 Node *village[maxn]; 74 int s[maxn], top; 75 76 void build(int n) { 77 Null = new Node(); 78 Root = new Node(1, true); 79 Root->p = Root->c[0] = Root->c[1] = Null; 80 village[1] = Root; 81 82 for (int i = 2; i <= n; i++){ 83 village[i] = new Node(1, true); 84 village[i]->p = Root; 85 village[i]->c[0] = village[i]->c[1] = Null; 86 splay(village[i], Null); 87 } 88 top = -1; 89 } 90 91 void destroy(int x) { 92 splay(village[x], Null); 93 village[x]->c[0]->p = village[x]->c[1]->p = Null; 94 village[x]->ex = false; 95 s[++top] = x; 96 } 97 98 int query(int x) { 99 if (!village[x]->ex) return 0; 100 splay(village[x], Null); 101 return village[x]->cnt; 102 } 103 104 void rebuild() { 105 if (top < 0) return ; 106 107 int x = s[top]; 108 top--; 109 110 if (village[x]->c[0] != Null) splay(village[x]->c[0], Null); 111 if (village[x]->c[1] != Null) splay(village[x]->c[1], Null); 112 village[x]->c[0]->p = village[x]; 113 village[x]->c[1]->p = village[x]; 114 village[x]->ex = true; 115 up(village[x]); 116 } 117 118 int main() { 119 char op[3]; 120 int n, m, a; 121 122 // freopen("in", "r", stdin); 123 while (~scanf("%d%d", &n, &m)) { 124 build(n); 125 // printf("built %d\n", n); 126 while (m--) { 127 scanf("%s", op); 128 switch (op[0]) { 129 case 'D' : 130 scanf("%d", &a); 131 destroy(a); 132 break; 133 case 'Q' : 134 scanf("%d", &a); 135 printf("%d\n", query(a)); 136 break; 137 case 'R' : 138 rebuild(); 139 break; 140 } 141 // printf("m %d\n", m); 142 } 143 } 144 145 return 0; 146 }
做了这么多题,debug Splay Tree的时候主要是要检查是否每个修改结点的操作都伴随着spaly操作,另外一个就是更新函数有没有错。总的来说,一个splay就可以满足你的各种要求了!不过就像论文里说的,有些题目线段树能做就不要用Splay Tree,因为Splay Tree的常数大,代码量也是不少的,这个看回之前几个入门级的Splay Tree就可以发现了,每个都少则200+,多则300+行。
UPD:
同样的题,用线段树做的:
1 #include <cstdio> 2 #include <iostream> 3 #include <cstring> 4 #include <algorithm> 5 6 using namespace std; 7 8 const int N = 55555; 9 int lc[N << 2], rc[N << 2], pos[N], len[N << 2]; 10 11 #define lson l, m, rt << 1 12 #define rson m + 1, r, rt << 1 | 1 13 #define root 1, n, 1 14 15 void up(int rt) { 16 int ls = rt << 1, rs = rt << 1 | 1; 17 lc[rt] = lc[ls]; 18 if (lc[ls] == len[ls]) lc[rt] += lc[rs]; 19 rc[rt] = rc[rs]; 20 if (rc[rs] == len[rs]) rc[rt] += rc[ls]; 21 } 22 23 void build(int l, int r, int rt) { 24 len[rt] = r - l + 1; 25 // cout << l << ' ' << r << ' ' << rt << ' ' << len[rt] << endl; 26 if (l >= r) { 27 lc[rt] = rc[rt] = 1; 28 pos[l] = rt; 29 return ; 30 } 31 int m = l + r >> 1; 32 build(lson); 33 build(rson); 34 up(rt); 35 } 36 37 int st[N], top; 38 39 void destroy(int x) { 40 st[++top] = x; 41 // cout << "fuck " << x << endl; 42 x = pos[x]; 43 lc[x] = rc[x] = 0; 44 while ((x >>= 1) > 0) up(x); 45 } 46 47 void rebuild() { 48 int x = st[top--]; 49 // cout << "shit " << x << endl; 50 x = pos[x]; 51 lc[x] = rc[x] = 1; 52 while ((x >>= 1) > 0) up(x); 53 } 54 55 typedef pair<bool, bool> PBB; 56 typedef pair<int, PBB> PIBB; 57 58 PIBB query(int x, int l, int r, int rt) { 59 if (l >= r) return PIBB(lc[rt], PBB(lc[rt], lc[rt])); 60 int m = l + r >> 1; 61 PIBB tmp; 62 int ls = rt << 1, rs = rt << 1 | 1; 63 if (x <= m) { 64 tmp = query(x, lson); 65 if (tmp.second.second) { 66 tmp.first += lc[rs]; 67 tmp.second.second = lc[rs] == len[rs]; 68 } 69 } else { 70 tmp = query(x, rson); 71 if (tmp.second.first) { 72 tmp.first += rc[ls]; 73 tmp.second.first = rc[ls] == len[ls]; 74 } 75 } 76 return tmp; 77 } 78 79 int main() { 80 // freopen("in", "r", stdin); 81 int n, m, x; 82 char buf[3]; 83 while (~scanf("%d%d", &n, &m)) { 84 top = -1; 85 build(root); 86 while (m--) { 87 scanf("%s", buf); 88 if (buf[0] == 'D') { 89 scanf("%d", &x); 90 destroy(x); 91 } 92 if (buf[0] == 'R') rebuild(); 93 if (buf[0] == 'Q') { 94 scanf("%d", &x); 95 printf("%d\n", query(x, root).first); 96 } 97 } 98 } 99 return 0; 100 }
UPD:
还有一种威武霸气的好方法。想法源自LRC。
1 #include <cstdio> 2 #include <iostream> 3 #include <algorithm> 4 #include <cstring> 5 #include <set> 6 #include <stack> 7 8 using namespace std; 9 10 set<int> pos, neg; 11 stack<int> st; 12 int main() { 13 // freopen("in", "r", stdin); 14 int n, m, x; 15 while (~scanf("%d%d", &n, &m)) { 16 pos.clear(), neg.clear(); 17 while (!st.empty()) st.pop(); 18 pos.insert(0), pos.insert(n + 1); 19 neg.insert(0), neg.insert(-n - 1); 20 char buf[3]; 21 while (m--) { 22 scanf("%s", &buf); 23 if (buf[0] == 'D') { 24 scanf("%d", &x); 25 pos.insert(x), neg.insert(-x); 26 st.push(x); 27 } 28 if (buf[0] == 'R') { 29 pos.erase(st.top()), neg.erase(-st.top()); 30 st.pop(); 31 } 32 if (buf[0] == 'Q') { 33 scanf("%d", &x); 34 printf("%d\n", max(0, *neg.lower_bound(-x) + *pos.lower_bound(x) - 1)); 35 } 36 } 37 } 38 return 0; 39 }
——written by Lyon