Splay 伸展树 (Splay Tree) 代码实现
Splay 概念文章: http://blog.csdn.net/naivebaby/article/details/1357734
叉姐 数组实现: https://github.com/ftiasch/mithril/blob/master/2012-10-24/I.cpp#L43
Vani 指针实现: https://github.com/Azure-Vani/acm-icpc/blob/master/spoj/SEQ2.cpp
hdu 1890 写法: http://blog.csdn.net/fp_hzq/article/details/8087431
HH splay写法: http://www.notonlysuccess.com/index.php/splay-tree/
POJ 3468 HH 写法
View Code
/*
 * POJ 3468: http://acm.pku.edu.cn/JudgeOnline/problem?id=3468
 * Range add, range sum, using an array-based splay tree (HH style).
 *
 * The sequence is kept in in-order position. Two sentinel nodes bracket the
 * real elements. To isolate the 1-based element range [l, r], splay rank l-1
 * to the root and rank r+1 to the root's right child. The range is then that
 * child's left subtree, named keyTree below.
 */
#include <cstdio>
#define keyTree (ch[ ch[root][1] ][0])
const int maxn = 222222;
struct SplayTree {
    int sz[maxn];             // subtree size; sz[0] == 0 for the null node 0
    int ch[maxn][2];          // left / right child, 0 means none
    int pre[maxn];            // parent, 0 means none
    int root, top1, top2;     // root; last fresh index handed out; recycle-stack height
    int ss[maxn], que[maxn];  // recycle stack of freed nodes; BFS queue used by erase

    // Rotate x above its parent. f == 1 is a right rotation (x was a left
    // child) and f == 0 a left rotation. Only the old parent is pushed up here;
    // Splay pushes x up once x stops moving.
    inline void Rotate(int x, int f) {
        int y = pre[x];
        push_down(y);
        push_down(x);
        ch[y][!f] = ch[x][f];
        pre[ch[x][f]] = y;
        pre[x] = pre[y];
        if (pre[x]) ch[pre[y]][ch[pre[y]][1] == y] = x;
        ch[x][f] = y;
        pre[y] = x;
        push_up(y);
    }

    // Splay x until its parent is goal. goal == 0 makes x the new root.
    inline void Splay(int x, int goal) {
        push_down(x);
        while (pre[x] != goal) {
            if (pre[pre[x]] == goal) {
                Rotate(x, ch[pre[x]][0] == x);
            } else {
                int y = pre[x], z = pre[y];
                int f = (ch[z][0] == y);
                if (ch[y][f] == x) {
                    Rotate(x, !f), Rotate(x, f);  // zig-zag
                } else {
                    Rotate(y, f), Rotate(x, f);   // zig-zig
                }
            }
        }
        push_up(x);
        if (goal == 0) root = x;
    }

    // Splay the node with 0-based in-order rank k (rank 0 is the left
    // sentinel) to just below goal. Assumes 0 <= k < sz[root].
    inline void RotateTo(int k, int goal) {
        int x = root;
        push_down(x);
        while (sz[ch[x][0]] != k) {
            if (k < sz[ch[x][0]]) {
                x = ch[x][0];
            } else {
                k -= (sz[ch[x][0]] + 1);
                x = ch[x][1];
            }
            push_down(x);
        }
        Splay(x, goal);
    }

    // Detach the subtree rooted at x and push every node in it onto the
    // recycle stack so that NewNode can reuse them.
    inline void erase(int x) {
        int father = pre[x];
        int head = 0, tail = 0;
        for (que[tail++] = x; head < tail; head++) {
            ss[top2++] = que[head];
            if (ch[que[head]][0]) que[tail++] = ch[que[head]][0];
            if (ch[que[head]][1]) que[tail++] = ch[que[head]][1];
        }
        if (father) {
            ch[father][ch[father][1] == x] = 0;
            push_up(father);
        } else {
            root = 0;  // x had no parent: never write into the null node 0
        }
    }
    // Everything above is generic and rarely needs changing. ///////////////

    void debug() { printf("%d\n", root); Treaval(root); }
    // In-order dump of every node's links and fields.
    void Treaval(int x) {
        if (x) {
            Treaval(ch[x][0]);
            printf("结点%2d:左儿子 %2d 右儿子 %2d 父结点 %2d size = %2d ,val = %2lld\n", x, ch[x][0], ch[x][1], pre[x], sz[x], val[x]);
            Treaval(ch[x][1]);
        }
    }
    // Debug helpers end here.

    // ---- Problem-specific part ----

    // Allocate a node into x holding value c. Recycled nodes are preferred.
    inline void NewNode(int &x, int c) {
        if (top2) x = ss[--top2];  // reuse a node freed by erase
        else x = ++top1;           // otherwise take the next fresh index
        ch[x][0] = ch[x][1] = pre[x] = 0;
        sz[x] = 1;
        val[x] = sum[x] = c;
        add[x] = 0;
    }

    // Push the lazy tag down: apply x's pending add to val[x] and hand it
    // on to the children. The null node 0 is skipped so its fields stay zero.
    inline void push_down(int x) {
        if (add[x]) {
            val[x] += add[x];
            for (int d = 0; d < 2; d++) {
                int c = ch[x][d];
                if (c) {
                    add[c] += add[x];
                    sum[c] += (long long)sz[c] * add[x];
                }
            }
            add[x] = 0;
        }
    }

    // Pull the children's state up into sz[x] and sum[x]. A pending add[x]
    // is not yet in val[x] or in the children's sums, so it is counted once
    // per node of the subtree. Callers normally push down first, so add[x] is
    // usually 0 here; the formula stays exact either way.
    inline void push_up(int x) {
        sz[x] = 1 + sz[ch[x][0]] + sz[ch[x][1]];
        sum[x] = val[x] + sum[ch[x][0]] + sum[ch[x][1]] + add[x] * sz[x];
    }

    // Build a balanced subtree over num[l..r] into x, with parent f.
    inline void makeTree(int &x, int l, int r, int f) {
        if (l > r) return;
        int m = (l + r) >> 1;
        NewNode(x, num[m]);  // replace num[m] with the problem's initial value
        makeTree(ch[x][0], l, m - 1, x);
        makeTree(ch[x][1], m + 1, r, x);
        pre[x] = f;
        push_up(x);
    }

    // Reset the tree and read n values into it, between two sentinels.
    inline void init(int n) {
        ch[0][0] = ch[0][1] = pre[0] = sz[0] = 0;
        val[0] = add[0] = sum[0] = 0;
        // Reset top2 together with top1. Otherwise recycled indices could
        // collide with freshly numbered ones after a re-init.
        root = top1 = top2 = 0;
        // Two boundary nodes (value -1) so every real range has neighbours.
        NewNode(root, -1);
        NewNode(ch[root][1], -1);
        pre[ch[root][1]] = root;
        sz[root] = 2;

        for (int i = 0; i < n; i++) scanf("%d", &num[i]);
        makeTree(keyTree, 0, n - 1, ch[root][1]);
        push_up(ch[root][1]);
        push_up(root);
    }

    // Update command: read l r c and add c to every element in [l, r].
    inline void update() {
        int l, r, c;
        scanf("%d%d%d", &l, &r, &c);
        RotateTo(l - 1, 0);
        RotateTo(r + 1, root);
        add[keyTree] += c;
        sum[keyTree] += (long long)c * sz[keyTree];
        // Refresh keyTree's two ancestors so that every stored sum stays exact.
        push_up(ch[root][1]);
        push_up(root);
    }

    // Query command: read l r and print the sum of [l, r].
    inline void query() {
        int l, r;
        scanf("%d%d", &l, &r);
        RotateTo(l - 1, 0);
        RotateTo(r + 1, root);
        printf("%lld\n", sum[keyTree]);
    }

    // Problem-specific state. val/add are long long because an element plus
    // its accumulated adds may exceed int range.
    int num[maxn];        // input values consumed by makeTree
    long long val[maxn];  // node value, excluding its own pending add
    long long add[maxn];  // pending add for this node's whole subtree
    long long sum[maxn];  // subtree sum, already including add[x]
} spt;


int main() {
    int n, m;
    scanf("%d%d", &n, &m);
    spt.init(n);
    while (m--) {
        char op[8];
        scanf("%7s", op);  // bounded read of the command letter
        if (op[0] == 'Q') {
            spt.query();
        } else {
            spt.update();
        }
    }
    return 0;
}
叉姐 数组实现
View Code
/*
 * Euler-tour splay forest (ftiasch, array implementation).
 *
 * Reads a tree of n vertices with n-1 weighted edges. The Euler tour of the
 * edges (each edge appears twice) is stored as one splay tree. Each of the
 * n-1 queries names an edge. For each query:
 *   - print the edge's current weight w;
 *   - cut the tour into the two resulting components;
 *   - multiply every weight in the "smaller" component (fewer tour nodes,
 *     ties broken by the smaller minimum id) by w;
 *   - add w to every weight in the other component;
 *   - all arithmetic is modulo MOD.
 */
#include <cstdio>
#include <cstring>
#include <vector>
#include <climits>
#include <algorithm>
using namespace std;

const int N = 200000;
const int M = 1 + (N << 1);
const int EMPTY = M - 1;  // shared "null" node index

const int MOD = 99990001;

// type[x]: 0 = left child, 1 = right child, 2 = root of its splay tree.
int nodeCount, type[M], parent[M], children[M][2], id[M];

// (scale, delta) is the pending affine tag w -> scale * w + delta for the
// children. subtreeSize and minimum are subtree aggregates.
// The array is named subtreeSize rather than size: with `using namespace std`,
// a global `size` is ambiguous with std::size (C++17) and fails to compile.
int scale[M], delta[M], weight[M], subtreeSize[M], minimum[M];

// Recompute x's aggregates from its children.
void update(int x) {
    subtreeSize[x] = subtreeSize[children[x][0]] + 1 + subtreeSize[children[x][1]];
    minimum[x] = min(min(minimum[children[x][0]], minimum[children[x][1]]), id[x]);
}

// Apply w -> k * w + b to x's own weight and compose it into x's pending tag.
void modify(int x, int k, int b) {
    if (x == EMPTY) {
        return;  // empty component: never write into the shared null node
    }
    weight[x] = ((long long)k * weight[x] + b) % MOD;
    scale[x] = (long long)k * scale[x] % MOD;
    delta[x] = ((long long)k * delta[x] + b) % MOD;
}

// Hand x's pending tag to its children and clear it.
void pushDown(int x) {
    for (int i = 0; i < 2; ++ i) {
        if (children[x][i] != EMPTY) {
            modify(children[x][i], scale[x], delta[x]);
        }
    }
    scale[x] = 1;
    delta[x] = 0;
}

// Rotate x above its parent. Only the old parent is updated here.
void rotate(int x) {
    int t = type[x];
    int y = parent[x];
    int z = children[x][1 ^ t];
    type[x] = type[y];
    parent[x] = parent[y];
    if (type[x] != 2) {
        children[parent[x]][type[x]] = x;
    }
    type[y] = 1 ^ t;
    parent[y] = x;
    children[x][1 ^ t] = y;
    if (z != EMPTY) {
        type[z] = t;
        parent[z] = y;
    }
    children[y][t] = z;
    update(y);
}

// Splay x to the root of its tree, pushing pending tags down from the root
// first.
void splay(int x) {
    if (x == EMPTY) {
        return;
    }
    static vector <int> path;  // reused across calls to avoid reallocating
    path.clear();
    path.push_back(x);
    for (int i = x; type[i] != 2; i = parent[i]) {
        path.push_back(parent[i]);
    }
    while (!path.empty()) {
        pushDown(path.back());
        path.pop_back();
    }
    while (type[x] != 2) {
        int y = parent[x];
        if (type[x] == type[y]) {
            rotate(y);  // zig-zig
        } else {
            rotate(x);  // zig-zag (or single zig when y is the root)
        }
        if (type[x] == 2) {
            break;
        }
        rotate(x);
    }
    update(x);
}

// Leftmost node of the tree rooted at x. Tags are affine values only, so
// they do not affect the shape.
int goLeft(int x) {
    while (children[x][0] != EMPTY) {
        x = children[x][0];
    }
    return x;
}

// Concatenate tree x before tree y and return the new root.
int join(int x, int y) {
    if (x == EMPTY || y == EMPTY) {
        return x != EMPTY ? x : y;
    }
    y = goLeft(y);
    splay(y);
    splay(x);
    type[x] = 0;
    parent[x] = y;
    children[y][0] = x;
    update(y);
    return y;
}

// Remove x from its tree. Returns (everything before x, everything after x)
// as two independent roots.
pair <int, int> split(int x) {
    splay(x);
    int a = children[x][0];
    int b = children[x][1];
    children[x][0] = children[x][1] = EMPTY;
    if (a != EMPTY) {
        type[a] = 2;
        parent[a] = EMPTY;
    }
    if (b != EMPTY) {
        type[b] = 2;
        parent[b] = EMPTY;
    }
    return make_pair(a, b);
}

// Create a standalone node with weight init and id vid.
int newNode(int init, int vid) {
    int x = nodeCount ++;
    type[x] = 2;
    parent[x] = children[x][0] = children[x][1] = EMPTY;
    id[x] = vid;
    weight[x] = init;
    scale[x] = 1;
    delta[x] = 0;
    update(x);
    return x;
}

int n;
int edgeCount, firstEdge[N], to[M], nextEdge[M], initWeight[N], position[M];

int root;

void addEdge(int u, int v) {
    to[edgeCount] = v;
    nextEdge[edgeCount] = firstEdge[u];
    firstEdge[u] = edgeCount ++;
}

// One pending vertex of the explicit DFS stack.
struct DfsFrame {
    int from;  // parent vertex (-1 for the start vertex)
    int at;    // current vertex
    int iter;  // next adjacency edge to examine, -1 when exhausted
    int via;   // edge used to enter `at`, -1 for the start vertex
};

// Append the Euler tour of u's subtree (entered from p) to root. Each tree
// edge contributes one node on the way down (position[e]) and one on the way
// back up (position[e ^ 1]). The walk is iterative so that a path-shaped tree
// with N vertices cannot overflow the call stack. Node creation order
// matches the recursive formulation.
void dfs(int p, int u) {
    vector <DfsFrame> frames;
    DfsFrame start = {p, u, firstEdge[u], -1};
    frames.push_back(start);
    while (!frames.empty()) {
        DfsFrame &top = frames.back();
        if (top.iter == -1) {
            // All edges scanned: walk back up the edge that entered this vertex.
            int via = top.via, from = top.from, at = top.at;
            frames.pop_back();
            if (via != -1) {
                position[via ^ 1] = nodeCount;
                root = join(root, newNode(initWeight[via >> 1], min(from, at)));
            }
            continue;
        }
        int iter = top.iter;
        top.iter = nextEdge[iter];
        int v = to[iter];
        if (v != top.from) {
            int at = top.at;
            position[iter] = nodeCount;
            root = join(root, newNode(initWeight[iter >> 1], min(at, v)));
            DfsFrame child = {at, v, firstEdge[v], iter};
            frames.push_back(child);  // invalidates `top`; it is not used again
        }
    }
}

int getRank(int x) { // 1-based
    splay(x);
    return subtreeSize[children[x][0]] + 1;
}

// Debug helper: bracketed in-order dump of node indices.
void print(int root) {
    if (root != EMPTY) {
        printf("[ ");
        print(children[root][0]);
        printf(" %d ", root);
        print(children[root][1]);
        printf(" ]");
    }
}

int main() {
    subtreeSize[EMPTY] = 0;
    minimum[EMPTY] = INT_MAX;
    parent[EMPTY] = 2;  // NOTE(review): parent[EMPTY] is never read; type[EMPTY] may have been meant
    scanf("%d", &n);
    edgeCount = 0;
    memset(firstEdge, -1, sizeof(firstEdge));
    for (int i = 0; i < n - 1; ++ i) {
        int a, b;
        scanf("%d%d%d", &a, &b, initWeight + i);
        a --;
        b --;
        addEdge(a, b);
        addEdge(b, a);
    }
    nodeCount = 0;
    root = EMPTY;
    dfs(-1, 0);
    for (int i = 0; i < n - 1; ++ i) {
        int id;
        scanf("%d", &id);
        id --;

        // The two tour nodes of edge id, ordered so that a comes first.
        int a = position[id << 1];
        int b = position[(id << 1) ^ 1];
        if (getRank(a) > getRank(b)) {
            swap(a, b);
        }
        splay(a);

        int output = weight[a];
        printf("%d\n", output);
        fflush(stdout);

        // x = before a, y = strictly between a and b, z = after b.
        // z followed by x forms one component; y forms the other.
        pair <int, int> ret1 = split(a);
        pair <int, int> ret2 = split(b);
        int x = ret1.first;
        int y = ret2.first;
        int z = ret2.second;
        x = join(z, x);
        splay(x);
        splay(y);
        if (subtreeSize[x] > subtreeSize[y]) {
            swap(x, y);
        }
        if (subtreeSize[x] == subtreeSize[y] && minimum[x] > minimum[y]) {
            swap(x, y);
        }
        modify(x, output, 0);
        modify(y, 1, output);
    }
    return 0;
}
SPOJ SEQ2
Vani 指针实现
View Code
/*
 * SPOJ SEQ2 (Vani, pointer-based splay tree).
 * Commands: INSERT, DELETE, MAKE-SAME, REVERSE, GET-SUM, MAX-SUM.
 */
#include <cstdio>
#include <cctype>
#include <algorithm>
#include <cstring>
#include <vector>

using namespace std;

namespace Solve {
    const int MAXN = 500010;
    // Sentinel value: the shared null node has v == -inf, and same == -inf
    // means "no pending MAKE-SAME".
    const int inf = 500000000;

    // The whole input is read into BUF once; pos is the parse cursor.
    char BUF[50000000], *pos = BUF;

    // Parse the next (possibly negative) decimal integer.
    inline int ScanInt(void) {
        int r = 0, d = 0;
        while (!isdigit((unsigned char)*pos) && *pos != '-') pos++;
        if (*pos != '-') r = *pos - 48; else d = 1;
        pos++;
        while (isdigit((unsigned char)*pos)) r = r * 10 + *pos++ - 48;
        return d ? -r : r;
    }

    // Parse the next command word (upper-case letters and '-') into st.
    inline void ScanStr(char *st) {
        int l = 0;
        while (!(isupper((unsigned char)*pos) || *pos == '-')) pos++;
        st[l++] = *pos++;
        while (isupper((unsigned char)*pos) || *pos == '-') st[l++] = *pos++;
        st[l] = 0;
    }

    struct Node {
        Node *ch[2], *p;
        // v: value. lmax/rmax: best prefix/suffix sum (empty allowed, so >= 0).
        // m: best non-empty subarray sum. same/rev: pending tags.
        // sum, size: subtree aggregates.
        int v, lmax, rmax, m, same, rev, sum, size;
        inline bool dir(void) {return this == p->ch[1];}
        inline void SetC(Node *x, bool d) {ch[d] = x, x->p = this;}
        inline void Update(void) {
            Node *L = ch[0], *R = ch[1];
            size = L->size + R->size + 1;
            m = max(L->m, R->m);
            m = max(m, L->rmax + v + R->lmax);
            lmax = max(L->lmax, L->sum + v + R->lmax);
            rmax = max(R->rmax, R->sum + v + L->rmax);
            sum = L->sum + R->sum + v;
        }
        // Reverse this subtree lazily (no-op on the null node).
        inline void Rev(void) {
            if (v == -inf) return;
            rev ^= 1;
            swap(ch[0], ch[1]);
            swap(lmax, rmax);
        }
        // Set every value in this subtree to u lazily (no-op on the null node).
        inline void Same(int u) {
            if (v == -inf) return;
            same = u;
            sum = u * size;
            if (sum > 0) lmax = rmax = m = sum; else lmax = 0, rmax = 0, m = u;
            v = u;
        }
        // Hand pending tags to the children.
        inline void Down(void) {
            if (rev) {
                ch[0]->Rev(), ch[1]->Rev();
                rev = 0;
            }
            if (same != -inf) {
                ch[0]->Same(same), ch[1]->Same(same);
                same = -inf;
            }
        }
    } Tnull, *null = &Tnull;

    class Splay {public:
        Node *root;
        inline void rotate(Node *x) {
            Node *p = x->p; bool d = x->dir();
            p->Down(); x->Down();
            p->p->SetC(x, p->dir());
            p->SetC(x->ch[!d], d);
            x->SetC(p, !d);
            p->Update();
        }
        // Splay x until its parent is G; G == null makes x the root.
        inline void splay(Node *x, Node *G) {
            if (G == null) root = x;
            while (x->p != G) {
                if (x->p->p == G) {rotate(x); break;}
                else {if (x->dir() == x->p->dir()) rotate(x->p), rotate(x); else rotate(x), rotate(x);}
            }
            x->Update();
        }
        // Splay the node of 1-based rank k (the left sentinel is rank 1) to the root.
        inline Node *Select(int k) {
            Node *t = root;
            while (t->Down(), t->ch[0]->size + 1 != k) {
                if (k > t->ch[0]->size + 1) k -= t->ch[0]->size + 1, t = t->ch[1];
                else t = t->ch[0];
            }
            splay(t, null);
            return t;
        }
        // Isolate elements [l, r] as the left child of the returned node.
        inline Node *getInterval(int l, int r) {
            Node *L = Select(l), *R = Select(r + 2);
            splay(L, null); splay(R, L);
            L->Down(); R->Down();
            return R;
        }
        inline void Insert(int pos, Node *x) {
            Node *now = getInterval(pos + 1, pos);
            now->SetC(x, 0);
            now->Update(); root->Update();
        }
        inline void Delete(int l, int r) {
            Node *now = getInterval(l, r);
            Free(now->ch[0]);  // release the removed nodes instead of leaking them
            now->ch[0] = null;
            now->Update(); root->Update();
        }
        inline void Make(int l, int r, int c) {
            Node *now = getInterval(l, r);
            now->ch[0]->Same(c);
            now->Update(); root->Update();
        }
        inline void Reverse(int l, int r) {
            Node *now = getInterval(l, r);
            now->ch[0]->Rev();
            now->Update(); root->Update();
        }
        inline int Sum(int l, int r) {
            Node *now = getInterval(l, r);
            root->Down(); now->Down();
            return now->ch[0]->sum;
        }
        inline int maxSum(int l, int r) {
            Node *now = getInterval(l, r);
            root->Down(); now->Down();
            return now->ch[0]->m;
        }
        // Allocate a leaf holding c. new Node() value-initializes every field,
        // so rev starts at 0 and Same() never reads an indeterminate v.
        inline Node* Renew(int c) {
            Node *ret = new Node();
            ret->ch[0] = ret->ch[1] = ret->p = null; ret->size = 1;
            ret->Same(c); ret->same = -inf;
            return ret;
        }
        // Build a balanced tree over a[l..r].
        inline Node* Build(int l, int r, int *a) {
            if (l > r) return null;
            int mid = (l + r) >> 1;
            Node *ret = Renew(a[mid]);
            ret->ch[0] = Build(l, mid - 1, a);
            ret->ch[1] = Build(mid + 1, r, a);
            ret->ch[0]->p = ret->ch[1]->p = ret;
            ret->Update();
            return ret;
        }
        // Delete every node of the subtree rooted at t. This is iterative
        // because splay subtrees can be arbitrarily deep. NULL and null are
        // both accepted.
        inline void Free(Node *t) {
            if (!t || t == null) return;
            vector<Node *> stk(1, t);
            while (!stk.empty()) {
                Node *u = stk.back();
                stk.pop_back();
                if (u->ch[0] != null) stk.push_back(u->ch[0]);
                if (u->ch[1] != null) stk.push_back(u->ch[1]);
                delete u;
            }
        }
        // Debug helper: print the values in order.
        inline void P(Node *t) {
            if (t == null) return;
            t->Down(); t->Update();
            P(t->ch[0]);
            printf("%d ", t->v);
            P(t->ch[1]);
        }
    }T;


    int a[MAXN]; char ch[10];

    inline void solve(void) {
        // Leave one byte spare so that BUF stays NUL-terminated even for a maximal input.
        fread(BUF, 1, sizeof(BUF) - 1, stdin);
        null->same = null->m = null->v = -inf;
        int kase = ScanInt();
        while (kase--) {
            int n = ScanInt(), m = ScanInt();
            for (int i = 1; i <= n; i++) a[i] = ScanInt();
            // a[0] and a[n + 1] become the sentinels. Fix their values instead
            // of reusing stale input; they never enter an answer.
            a[0] = a[n + 1] = 0;
            T.Free(T.root);  // release the previous test case's tree
            T.root = T.Build(0, n + 1, a);
            for (int i = 1; i <= m; i++) {
                ScanStr(ch);
                if (strcmp(ch, "INSERT") == 0) {
                    int pos = ScanInt(), t = ScanInt();
                    for (int j = 1; j <= t; j++) a[j] = ScanInt();
                    Node *tmp = T.Build(1, t, a);
                    T.Insert(pos, tmp);
                }
                if (strcmp(ch, "DELETE") == 0) {
                    int l = ScanInt(), r = ScanInt(); r = l + r - 1;
                    T.Delete(l, r);
                }
                if (strcmp(ch, "MAKE-SAME") == 0) {
                    int l = ScanInt(), r = ScanInt(), c = ScanInt(); r = l + r - 1;
                    T.Make(l, r, c);
                }
                if (strcmp(ch, "REVERSE") == 0) {
                    int l = ScanInt(), r = ScanInt(); r = l + r - 1;
                    T.Reverse(l, r);
                }
                if (strcmp(ch, "GET-SUM") == 0) {
                    int l = ScanInt(), r = ScanInt(); r = l + r - 1;
                    int ret = T.Sum(l, r);
                    printf("%d\n", ret);
                }
                if (strcmp(ch, "MAX-SUM") == 0) {
                    int ret = T.maxSum(1, T.root->size - 2);
                    printf("%d\n", ret);
                }
            }
        }
    }
}

int main(void) {
#ifdef LOCAL
    freopen("in", "r", stdin);  // local testing only; a judge supplies input on stdin
#endif
    Solve::solve();
    return 0;
}