【bzoj3217】ALOEXT 替罪羊树套Trie树
题目描述
taorunz平时最喜欢的东西就是可移动存储器了……只要看到别人的可移动存储器,他总是用尽一切办法把它里面的东西弄到手。
突然有一天,taorunz来到了一个密室,里面放着一排可移动存储器,存储器里有非常珍贵的OI资料……不过比较特殊的是,每个存储器上都写着一个非负整数。taorunz很高兴,要把所有的存储器都拿走(taorunz的智商高达500,他一旦弄走了这里的所有存储器,在不久到来的AHOI和NOI中……你懂的)。不过这时有一个声音传来:“你只能拿走这里的一个存储器,而且还不能直接拿,你需要指定一段区间[l, r],满足l<r,然后在第l个和第r个存储器之间选一个拿走,你能获得的知识增加量等于区间[l, r]中所有存储器上写的整数的次大值与你拿走的这个存储器上写的整数作按位异或运算的结果。”
问题是,这里的可移动存储器数量太多,而且,它们还在不断地发生变化,有时候天上会掉下来一个新的存储器,并插入到这一排存储器中,有时候某个存储器会不明原因消失,有时候某个存储器上写的整数变化了。taorunz虽然智商很高,但也无法应对如此快的变化,他指定了许多段区间,让你帮他找出如果在这个区间中拿走存储器,他能获得的最多的知识是多少。
输入
第一行两个整数N、M,表示一开始的存储器数和后面发生的事件数。
第二行N个非负整数,表示一开始从左到右每个存储器上写的数字。注意,存储器从0开始编号,也就是最左边的存储器是第0个。
接下来M行,每行描述一个事件,有4种可能的事件。
(1)I x y:表示天上掉下来一个写着数字y的存储器,并插入到原来的第x个存储器之前,如果x等于原来存储器的个数,则插入到末尾;
(2)D x:表示第x个存储器消失;
(3)C x y:表示第x个存储器上写的数字变为y;
(4)F l r:表示taorunz指定区间[l, r],让你告诉他最多能获得多少知识。
注意,本题强制在线,也就是事件中出现的所有数字都进行了加密,数字s表示的真实值是
对于I、D、C事件中的x及F事件中的l、r:(s+last_ans) mod n0;
对于I、C事件中的y:(s+last_ans) mod 1048576。
其中n0为目前存储器个数,last_ans为上一个F事件的结果,如果前面尚未发生F事件,则last_ans=0。
输出
对于每个F事件,输出结果。
样例输入
5 10
2 6 3 8 7
F 1 4
I 2 1048565
I 0 1048566
D 3
F 3 0
I 3 1048569
D 5
C 1 1048570
F 1 2
F 2 1
样例输出
15
7
4
7
题解
替罪羊树套Trie树
题目要求支持定点插入,所以需要使用平衡树;而又要求出最大异或值,显然还要使用Trie树贪心。
考虑把它们套起来,那么平衡树需要放在外层,所以选择替罪羊树。
于是本题就与 bzoj3065 差不多,对替罪羊树的每一个节点,维护它本身和它子树所有权值的01Trie树以及次大值,查询时直接提取所有区间然后按位贪心求最大异或即可。
但是本题与那道题不同的是,本题有删除操作,所以替罪羊树还需要支持删除。
具体方法是:和普通BST是一样的删除方法,如果一个节点存在某子树为空则把这棵子树提到该节点的位置,否则寻找这个点的后继节点(右子树中第一个节点),用这个节点替换原来的节点。由于删除带来的子树大小变化使得势能分析带来的结果不正确,因此删除时不能重构树,否则时间复杂度会出错。解决方法是:对于每个点维护一个maxsize,即子树大小的历史最大值。如果这个历史最大值超过父亲节点的历史最大值的某比例便重构,重构的时候把历史最大值也重新赋值。
时间复杂度应该是$O(20n\log^2n)=O(能过)$。
本题同样需要内存回收,而且数组要开到$2.5*10^7$左右才能够通过。
为了方便内存回收使用了递归版的Trie树插入所以代码跑得慢一些= =
另外解释下代码中pushup中的if:由于我删除的姿势问题,使得在某些特殊情况下会pushup到0节点,导致它的size变为1,然后就全局错误。因此防止这种情况加了判断(于是代码就更慢了= =)
#include <queue> #include <cstdio> #include <utility> #include <algorithm> using namespace std; typedef pair<int , int> pr; struct scg { int ls , rs , si , ms , w , wr , tr; pr mx; }a[200010]; struct trie { int c[2] , si; }b[25000010]; queue<int> q; int n , root , pos[200010] , tot , qr[1000] , qt; char str[5]; void update(int v , int d , int a , int &x) { if(!x) x = q.front() , q.pop(); b[x].si += a; bool t = v & (1 << d); if(~d) update(v , d - 1 , a , b[x].c[t]); if(!b[x].si) q.push(x) , x = 0; } void del(int &x) { if(!x) return; del(b[x].c[0]) , del(b[x].c[1]); b[x].si = 0 , q.push(x) , x = 0; } int query(int v , int d) { if(d == -1) return 0; int i; bool t = v & (1 << d); for(i = 1 ; i <= qt ; i ++ ) if(b[qr[i]].c[t ^ 1]) break; if(i <= qt) { for(i = 1 ; i <= qt ; i ++ ) qr[i] = b[qr[i]].c[t ^ 1]; return query(v , d - 1) + (1 << d); } else { for(i = 1 ; i <= qt ; i ++ ) qr[i] = b[qr[i]].c[t]; return query(v , d - 1); } } inline pr getmx(pr a , pr b) { if(a.first > b.first) return pr(a.first , max(b.first , max(a.second , b.second))); else return pr(b.first , max(a.first , max(a.second , b.second))); } inline void pushup(int x) { if(!x) return; a[x].si = a[a[x].ls].si + a[a[x].rs].si + 1 , a[x].ms = max(a[x].ms , a[x].si); a[x].mx = getmx(pr(a[x].w , -1 << 30) , getmx(a[a[x].ls].mx , a[a[x].rs].mx)); } int build(int l , int r) { if(l > r) return 0; int mid = (l + r) >> 1 , i; for(i = l ; i <= r ; i ++ ) update(a[pos[i]].w , 20 , 1 , a[pos[mid]].tr); a[pos[mid]].ls = build(l , mid - 1) , a[pos[mid]].rs = build(mid + 1 , r); pushup(pos[mid]); return pos[mid]; } void dfs(int &x) { if(!x) return; dfs(a[x].ls) , pos[++tot] = x , dfs(a[x].rs); a[x].si = a[x].ms = 0 , del(a[x].tr) , x = 0; } void insert(int p , int v , int &x , bool flag) { if(!x) { x = ++n , a[x].w = v , update(v , 20 , 1 , a[x].wr) , update(v , 20 , 1 , a[x].tr) , pushup(x); return; } bool tag; update(v , 20 , 1 , a[x].tr); if(p <= a[a[x].ls].si) tag = (max(a[a[x].ls].ms , a[a[x].ls].si + 1) * 4 > max(a[x].ms , a[x].si + 1) * 3) , insert(p , v , a[x].ls , tag || flag); else tag = (max(a[a[x].rs].ms , a[a[x].rs].si + 1) * 4 > max(a[x].ms , a[x].si + 1) * 3) , insert(p - a[a[x].ls].si - 1 , v , a[x].rs , tag || flag); pushup(x); if(tag && !flag) tot = 0 , dfs(x) , x = build(1 , tot); } int find(int p , int x) { if(p <= a[a[x].ls].si) return find(p , a[x].ls); else if(p > a[a[x].ls].si + 1) return find(p - a[a[x].ls].si - 1 , a[x].rs); else return a[x].w; } void erase(int p , int v , int &x) { update(v , 20 , -1 , a[x].tr); if(p <= a[a[x].ls].si) erase(p , v , a[x].ls); else if(p > a[a[x].ls].si + 1) erase(p - a[a[x].ls].si - 1 , v , a[x].rs); else { if(!a[x].ls || !a[x].rs) x = a[x].ls + a[x].rs; else { int t = a[x].rs; while(a[t].ls) t = a[t].ls; erase(1 , a[t].w , a[x].rs); a[t].ls = a[x].ls , a[t].rs = a[x].rs , a[t].tr = a[x].tr , x = t; } } pushup(x); } void modify(int p , int v1 , int v2 , int x) { update(v1 , 20 , -1 , a[x].tr) , update(v2 , 20 , 1 , a[x].tr); if(p <= a[a[x].ls].si) modify(p , v1 , v2 , a[x].ls); else if(p > a[a[x].ls].si + 1) modify(p - a[a[x].ls].si - 1 , v1 , v2 , a[x].rs); else del(a[x].wr) , update(v2 , 20 , 1 , a[x].wr) , a[x].w = v2; pushup(x); } pr solve(int b , int e , int x) { if(b <= 1 && e >= a[x].si) return a[x].mx; pr ans = pr(-1 << 30 , -1 << 30); if(b <= a[a[x].ls].si + 1 && e >= a[a[x].ls].si + 1) ans = getmx(ans , pr(a[x].w , -1 << 30)); if(b <= a[a[x].ls].si) ans = getmx(ans , solve(b , e , a[x].ls)); if(e > a[a[x].ls].si + 1) ans = getmx(ans , solve(b - a[a[x].ls].si - 1 , e - a[a[x].ls].si - 1 , a[x].rs)); return ans; } void split(int b , int e , int x) { if(!x) return; if(b <= 1 && e >= a[x].si) { qr[++qt] = a[x].tr; return; } if(b <= a[a[x].ls].si + 1 && e >= a[a[x].ls].si + 1) qr[++qt] = a[x].wr; if(b <= a[a[x].ls].si) split(b , e , a[x].ls); if(e > a[a[x].ls].si + 1) split(b - a[a[x].ls].si - 1 , e - a[a[x].ls].si - 1 , a[x].rs); } int main() { int m , i , last = 0 , x , y; for(i = 1 ; i <= 25000000 ; i ++ ) q.push(i); scanf("%d%d" , &n , &m); for(i = 1 ; i <= n ; i ++ ) scanf("%d" , &a[i].w) , update(a[i].w , 20 , 1 , a[i].wr) , pos[i] = i; root = build(1 , n); while(m -- ) { scanf("%s" , str); switch(str[0]) { case 'I': scanf("%d%d" , &x , &y) , x = (x + last) % a[root].si + 1 , y = (y + last) % 1048576 , insert(x - 1 , y , root , 0); break; case 'D': scanf("%d" , &x) , x = (x + last) % a[root].si + 1 , erase(x , find(x , root) , root); break; case 'C': scanf("%d%d" , &x , &y) , x = (x + last) % a[root].si + 1 , y = (y + last) % 1048576 , modify(x , find(x , root) , y , root); break; default: scanf("%d%d" , &x , &y) , x = (x + last) % a[root].si + 1 , y = (y + last) % a[root].si + 1 , qt = 0 , split(x , y , root) , printf("%d\n" , last = query(solve(x , y , root).second , 20)); } } return 0; }