牛客网多校训练第三场 C - Shuffle Cards(Splay / rope)
链接:
https://www.nowcoder.com/acm/contest/141/C
题意:
给出一个n个元素的序列(1,2,...,n)和m个操作(1≤n,m≤1e5),
每个操作给出两个数p和s(1≤pi≤n,1≤si≤n-pi+1),表示把序列中从p开始的s个数移到最前面,
例如序列[1,2,3,4,5]在p=2,s=3时变成序列[2,3,4,1,5],输出最后的序列。
分析:
对于每个操作,直接把序列拆成三个部分,再重新拼接一下就行。
可以用Splay或rope来快速完成这个操作。
代码:
Splay版
1 #include <cstdio> 2 3 /// 元素编号从1开始,使用前要初始化(init) 4 template<typename type> // 元素类型 5 class SplaySequence { 6 public: 7 struct Node { 8 Node *ch[2]; // 左右子树 9 int s; // 结点数 10 type v; // 值 11 bool flip; // 反转标记 12 int cmp(int k) const { 13 int d = k - ch[0]->s; 14 if(d == 1) return -1; 15 return d <= 0 ? 0 : 1; 16 } 17 void maintain() { 18 s = ch[0]->s + ch[1]->s + 1; 19 } 20 void pushdown() { 21 if(!flip) return; 22 flip = false; 23 Node* t = ch[0]; ch[0] = ch[1]; ch[1] = t; 24 ch[0]->flip ^= 1; 25 ch[1]->flip ^= 1; 26 } 27 }; 28 Node* root; // 根结点 29 // 将序列分裂成s[1...k]和s[k+1...o->s],分别放于left和right 30 void split(Node* o, int k, Node* &left, Node* &right) { 31 if(k <= 0) { 32 left = null; 33 right = o; 34 return; 35 } 36 splay(o, k); 37 left = o; 38 right = o->ch[1]; 39 o->ch[1] = null; 40 left->maintain(); 41 } 42 // 合并left和right。 43 Node* merge(Node* left, Node* right) { 44 if(left == null) return right; 45 splay(left, left->s); 46 left->ch[1] = right; 47 left->maintain(); 48 return left; 49 } 50 // 在第k个元素之后插入v 51 void insert(int k, type v) { 52 Node *left, *right, *mid = newNode(); 53 mid->ch[0] = mid->ch[1] = null; 54 mid->s = 1; 55 mid->v = v; 56 mid->flip = false; 57 split(root, k, left, right); 58 root = merge(merge(left, mid), right); 59 } 60 // 删除第k个元素 61 void erase(int k) { 62 Node *left, *right, *mid, *o; 63 split(root, k-1, left, o); 64 split(o, 1, mid, right); 65 root = merge(left, right); 66 delNode(mid); 67 } 68 // 将第L...R个元素反转 69 void reverse(int L, int R) { 70 Node *left, *right, *mid, *o; 71 split(root, L-1, left, o); 72 split(o, R-L+1, mid, right); 73 mid->flip ^= 1; 74 root = merge(merge(left, mid), right); 75 } 76 // 返回第k个元素 77 Node* kth(int k) { 78 Node* o = root; 79 while(o != null) { 80 int d = o->cmp(k); 81 if(d == -1) return o; 82 if(d == 1) k -= o->ch[0]->s + 1; 83 o = o->ch[d]; 84 } 85 return null; 86 } 87 // 读取第k个元素 88 type operator[](int k) { 89 return kth(k)->v; 90 } 91 // 返回元素个数 92 int size() { 93 return root->s; 94 } 95 // 初始化、清空序列 96 void init() { 97 for(int i = 0; i < MAXS; i++) stk[i] = &mem[i]; 98 top = MAXS - 1; 99 null = newNode(); 100 null->s = 0; 101 root = null; 102 } 103 // 中序遍历输出整个序列 104 void print(Node* o) { 105 if(o == null) return; 106 o->pushdown(); 107 print(o->ch[0]); 108 printf("%d ", o->v); 109 print(o->ch[1]); 110 } 111 private: 112 static const int MAXS = 1e6 + 5; // 最大结点数 113 Node* null; 114 int top; 115 Node* stk[MAXS]; 116 Node mem[MAXS]; 117 void rotate(Node* &o, int d) { 118 Node* k = o->ch[d^1]; 119 o->ch[d^1] = k->ch[d]; 120 k->ch[d] = o; 121 o->maintain(); 122 k->maintain(); 123 o = k; 124 } 125 void splay(Node* &o, int k) { 126 o->pushdown(); 127 int d = o->cmp(k); 128 if(d == -1) return; 129 if(d == 1) k -= o->ch[0]->s + 1; 130 Node* p = o->ch[d]; 131 p->pushdown(); 132 int d2 = p->cmp(k); 133 int k2 = (d2 == 0 ? k : k - p->ch[0]->s - 1); 134 if(d2 != -1) { 135 splay(p->ch[d2], k2); 136 if(d == d2) rotate(o, d^1); 137 else rotate(o->ch[d], d); 138 } 139 rotate(o, d^1); 140 } 141 Node* newNode() { 142 return stk[top--]; 143 } 144 void delNode(Node* o) { 145 stk[++top] = o; 146 } 147 }; 148 149 SplaySequence<int> ss; 150 151 int main() { 152 int n, m; 153 scanf("%d%d", &n, &m); 154 ss.init(); 155 for(int i = 1; i <= n; i++) ss.insert(ss.size(), i); 156 for(int p, s, i = 0; i < m; i++) { 157 scanf("%d%d", &p, &s); 158 SplaySequence<int>::Node *left, *mid, *right, *o; 159 ss.split(ss.root, p-1, left, o); 160 ss.split(o, s, mid, right); 161 ss.root = ss.merge(ss.merge(mid, left), right); 162 } 163 ss.print(ss.root); 164 return 0; 165 }
rope版
1 #include <cstdio> 2 #include <ext/rope> 3 using namespace __gnu_cxx; 4 5 rope<int> r; 6 7 int main() { 8 int n, m; 9 scanf("%d%d", &n, &m); 10 for(int i = 1; i <= n; i++) r.push_back(i); 11 for(int p, s, i = 0; i < m; i++) { 12 scanf("%d%d", &p, &s); 13 p--; 14 r = r.substr(p, s) + r.substr(0, p) + r.substr(p+s, n-p-s); 15 } 16 for(int i = 0; i < n; i++) printf("%d ", r[i]); 17 return 0; 18 } 19 20 /* 21 r.push_back(x); // 在末尾添加x 22 r.insert(pos, x); // 在pos插入x 23 r.erase(pos, x); // 从pos开始删除x个 24 r.copy(pos, len, x); // 从pos开始到pos+len为止用x代替 25 r.replace(pos, x); // 从pos开始换成x 26 r.substr(pos, x); // 提取pos开始x个 27 r.at(x) / [x]; // 访问第x个元素 28 时间复杂度为n*(n^0.5) 29 */