[codevs 1514] 书架
http://codevs.cn/problem/1514/
题解:
Splay,因为没办法根据给出的书的编号确定书的位置,也就没办法做到log(n)的查询,所以采用自底向上的伸展方式,那么就需要用pa数组记录节点的上个结点。
算法实现上建立了两个虚拟节点来避免溢出——据HZWER。
加了很多注解。
体会了一下,因为move操作的统一性,所以加上两个虚拟节点后有利于继续维护插入操作的统一性。否则除非给移动到顶和底层的另写函数,没办法用一个move函数解决了否则溢出,因为空节点不能作为根节点。
写了蚱蜢以后才逐渐理解。
第一个用的HZWER的splay,第二个是我写的。
代码:
版本1:
总时间耗费: 1637ms
总内存耗费: 2 kB
总内存耗费: 2 kB
#include<cstdio> #include<cstring> using namespace std; const int INF = 1e9 + 7; const int maxn = 80000 + 10; int n, m, root; int ch[maxn][2], p[maxn], a[maxn], s[maxn], v[maxn], id[maxn]; void update(int k) { s[k] = s[ch[k][0]] + s[ch[k][1]] + 1; } void rotate(int x, int &k) { int y = p[x], z = p[y]; int d = ch[y][0] == x ? 0 : 1; int d2 = ch[z][0] == y ? 0 : 1; if(y == k) k = x; else ch[z][d2] = x; p[x] = z; p[y] = x; p[ch[x][d^1]] = y; ch[y][d] = ch[x][d^1]; ch[x][d^1] = y; update(y); update(x); } //自底向上的伸展 void splay(int x, int &k) { while(x != k) { int y = p[x], z = p[y]; if(y != k) { if(ch[y][0] == x ^ ch[z][0] == y) rotate(x, k); else rotate(y, k); } rotate(x, k); } } //类似线段树的build void build(int l, int r, int pa) { if(l > r) return; if(l == r) { v[l] = a[l]; s[l] = 1; p[l] = pa; if(l < pa) ch[pa][0] = l; else ch[pa][1] = l; return; } int mid = (l+r) >> 1; build(l, mid-1, mid); build(mid+1, r, mid); v[mid] = a[mid]; p[mid] = pa; update(mid); if(mid < pa) ch[pa][0] = mid; else ch[pa][1] = mid; } //在这个树中查询第rank个 int find(int k, int rank) { int l = ch[k][0], r = ch[k][1]; if(s[l]+1 == rank) return k; else if(s[l] >= rank) return find(l, rank); else return find(r, rank-s[l]-1); } //删除排在第k位的 void remove(int k) { int x, y, z; x = find(root, k-1); y = find(root, k+1); splay(x, root); splay(y, ch[x][1]); z = ch[y][0]; ch[y][0] = 0; p[z] = s[z] = 0; update(y); update(x); } void move(int k, int val) { int o = id[k], x, y, rank; splay(o, root); rank = s[ch[o][0]] + 1; remove(rank); //x是要插入的位置的上一个点,y是要插入的位置 //最后插入到y的左儿子也就是取代了y原来的位置 if(val == INF) x = find(root, n), y = find(root, n+1); //插入到底部 此时共n+1个点 成为第n+1个点 else if(val == -INF) x = find(root, 1), y = find(root, 2); //插入到顶部 成为第2个点 else x = find(root, rank+val-1), y = find(root, rank+val); //插入到中间 成为第rank+val个点(rank的值是原来的排名+1) splay(x, root); splay(y, ch[x][1]); s[o] = 1; p[o] = y; ch[y][0] = o; update(y); update(x); } int main() { scanf("%d%d", &n, &m); //加入1和n+2两个虚拟结点,避免溢出 for(int i = 2; i <= n+1; i++) scanf("%d", &a[i]), id[a[i]] = i; build(1, n+2, 0); root = (n+3) >> 1; char cmd[10]; int S, T; for(int i = 1; i <= m; i++) { scanf("%s%d", cmd, &S); //避免了标准函数读char读到回车什么奇怪的东西 switch(cmd[0]) { case 'T': move(S, -INF); break; case 'B': move(S, INF); break; case 'I': scanf("%d", &T); move(S, T); break; case 'A': splay(id[S], root); printf("%d\n", s[ch[id[S]][0]]-1); break; case 'Q': printf("%d\n", v[find(root, S+1)]); break; } } return 0; }
版本2:
总时间耗费: 1067ms
总内存耗费: 3 MB
#include<cstdio> #include<cstring> using namespace std; const int INF = 1e9 + 7; const int maxn = 80000 + 10; int n, m, root; int ch[maxn][2], p[maxn], a[maxn], s[maxn], v[maxn], id[maxn]; void update(int k) { s[k] = s[ch[k][0]] + s[ch[k][1]] + 1; } void rotate(int& px, int& x, int d) { int t = ch[x][d]; ch[x][d] = px; ch[px][d^1] = t; p[x] = p[px]; p[px] = x; p[t] = px; update(px); update(x); px = x; } void splay(int x, int& k) { while(x != k) { int y = p[x], z = p[y]; int d = ch[y][0] == x ? 0 : 1; int d2 = ch[z][0] == y ? 0 : 1; if(y != k) rotate(ch[z][d2], x, d^1); else rotate(k, x, d^1); } } void build(int L, int R, int P, int d) { if(L == R) { s[L] = 1; p[L] = P; ch[P][d] = L; return; } int M = (L+R) >> 1; p[M] = P; ch[P][d] = M; if(M-1 >= L) build(L, M-1, M, 0); if(R >= M+1) build(M+1, R, M, 1); update(M); } int find(int k, int rank) { int l = ch[k][0], r = ch[k][1]; if(s[l]+1 == rank) return k; else if(s[l] >= rank) return find(l, rank); else return find(r, rank-s[l]-1); } void remove(int k) { int x, y, z; x = find(root, k-1); y = find(root, k+1); splay(x, root); splay(y, ch[x][1]); z = ch[y][0]; ch[y][0] = 0; p[z] = s[z] = 0; update(y); update(x); } void move(int k, int val) { int o = id[k], x, y, rank; splay(o, root); rank = s[ch[o][0]] + 1; remove(rank); if(val == INF) x = find(root, n), y = find(root, n+1); else if(val == -INF) x = find(root, 1), y = find(root, 2); else x = find(root, rank+val-1), y = find(root, rank+val); splay(x, root); splay(y, ch[x][1]); s[o] = 1; p[o] = y; ch[y][0] = o; update(y); update(x); } int main() { scanf("%d%d", &n, &m); for(int i = 2; i <= n+1; i++) scanf("%d", &v[i]), id[v[i]] = i; build(1, n+2, 0, 1); root = (n+3) >> 1; char cmd[10]; int S, T; for(int i = 1; i <= m; i++) { scanf("%s%d", cmd, &S); switch(cmd[0]) { case 'T': move(S, -INF); break; case 'B': move(S, INF); break; case 'I': scanf("%d", &T); move(S, T); break; case 'A': splay(id[S], root); printf("%d\n", s[ch[id[S]][0]]-1); break; case 'Q': printf("%d\n", v[find(root, S+1)]); break; } } return 0; }