[BZOJ1500][NOI2005]维修数列
[BZOJ1500][NOI2005]维修数列
试题描述
输入
输入的第1 行包含两个数N 和M(M ≤20 000),N 表示初始时数列中数的个数,M表示要进行的操作数目。
第2行包含N个数字,描述初始时的数列。
以下M行,每行一条命令,格式参见问题描述中的表格。
任何时刻数列中最多含有500 000个数,数列中任何一个数字均在[-1 000, 1 000]内。
插入的数字总数不超过4 000 000个,输入文件大小不超过20MBytes。
输出
对于输入数据中的GET-SUM和MAX-SUM操作,向输出文件依次打印结果,每个答案(数字)占一行。
输入示例
9 8 2 -6 3 5 1 -5 -3 6 3 GET-SUM 5 4 MAX-SUM INSERT 8 3 -5 7 2 DELETE 12 1 MAKE-SAME 3 3 2 REVERSE 3 6 GET-SUM 5 4 MAX-SUM
输出示例
-1 10 1 10
数据规模及约定
见“输入”
题解
又一道裸题。但是这题维护的东西比较恶心:需要维护节点的值 v、子树大小 siz、子树权值和 sum、子树最大前缀和 ml、子树最大后缀和 mr、子树最大连续和 ms、权值懒标记 setv 以及反转标记 rev。然后注意如果一个子树 rev = 1(即打了反转标记),则它的 ml 和 mr 需要互换。
#include <iostream> #include <cstdio> #include <cstdlib> #include <cstring> #include <cctype> #include <algorithm> using namespace std; int read() { int x = 0, f = 1; char c = getchar(); while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); } while(isdigit(c)){ x = x * 10 + c - '0'; c = getchar(); } return x * f; } #define maxn 500010 #define LL long long #define NoSign -2333 #define oo (1ll << 60) struct Node { int v, siz; LL sum, ml, mr, ms, setv; bool rev; Node() {} Node(int _): v(_), setv(NoSign), rev(0) {} } ns[maxn]; int rt, ToT, fa[maxn], ch[2][maxn], rec[maxn], cc; bool hs(int o) { return ns[o].setv != NoSign; } void maintain(int o) { ns[o].siz = 1; ns[o].sum = ns[o].v; ns[o].ms = ns[o].v; int l = ch[0][o], r = ch[1][o]; for(int i = 0; i < 2; i++) if(ch[i][o]) ns[o].siz += ns[ch[i][o]].siz, ns[o].sum += !hs(ch[i][o]) ? ns[ch[i][o]].sum : ns[ch[i][o]].setv * ns[ch[i][o]].siz, ns[o].ms = max(ns[o].ms, !hs(ch[i][o]) ? ns[ch[i][o]].ms : max(ns[ch[i][o]].setv, ns[ch[i][o]].setv * ns[ch[i][o]].siz)); int tll = ns[l].ml, tlr = ns[l].mr, trl = ns[r].ml, trr = ns[r].mr, tls = ns[l].sum, trs = ns[r].sum; if(hs(l)) ns[l].ml = ns[l].mr = max(ns[l].setv, ns[l].setv * ns[l].siz), ns[l].sum = ns[l].setv * ns[l].siz; if(hs(r)) ns[r].ml = ns[r].mr = max(ns[r].setv, ns[r].setv * ns[r].siz), ns[r].sum = ns[r].setv * ns[r].siz; if(ns[l].rev) swap(ns[l].ml, ns[l].mr); if(ns[r].rev) swap(ns[r].ml, ns[r].mr); if(l) ns[o].ml = ns[l].ml; else ns[o].ml = -oo; ns[o].ml = max(ns[o].ml, (l ? ns[l].sum : 0) + ns[o].v); if(r) ns[o].ml = max(ns[o].ml, (l ? ns[l].sum : 0) + ns[o].v + ns[r].ml); if(r) ns[o].mr = ns[r].mr; else ns[o].mr = -oo; ns[o].mr = max(ns[o].mr, (r ? ns[r].sum : 0) + ns[o].v); if(l) ns[o].mr = max(ns[o].mr, (r ? ns[r].sum : 0) + ns[o].v + ns[l].mr); if(l) ns[o].ms = max(ns[o].ms, ns[l].mr + max(ns[o].v, 0)); if(r) ns[o].ms = max(ns[o].ms, ns[r].ml + max(ns[o].v, 0)); if(l && r) ns[o].ms = max(ns[o].ms, ns[l].mr + ns[o].v + ns[r].ml); ns[l].ml = tll; ns[l].mr = tlr; ns[r].ml = trl; ns[r].mr = trr; ns[l].sum = tls; ns[r].sum = trs; return ; } int getnode() { if(cc) return rec[cc--]; return ++ToT; } int val[maxn], cv; void build(int& o, int l, int r) { if(l > r) return ; int mid = l + r >> 1; ns[o = getnode()] = Node(val[mid]); build(ch[0][o], l, mid - 1); build(ch[1][o], mid + 1, r); if(ch[0][o]) fa[ch[0][o]] = o; if(ch[1][o]) fa[ch[1][o]] = o; maintain(o); return ; } void pushdown(int o) { if(hs(o)) { LL& st = ns[o].setv; ns[o].v = st; for(int i = 0; i < 2; i++) if(ch[i][o]) ns[ch[i][o]].setv = st; st = NoSign; } if(ns[o].rev) { bool& rv = ns[o].rev; for(int i = 0; i < 2; i++) if(ch[i][o]) ns[ch[i][o]].rev ^= rv; swap(ch[0][o], ch[1][o]); rv = 0; } return maintain(o); } void rotate(int u) { int y = fa[u], z = fa[y], l = 0, r = 1; if(z) ch[ch[1][z]==y][z] = u; if(ch[1][y] == u) swap(l, r); fa[u] = z; fa[y] = u; fa[ch[r][u]] = y; ch[l][y] = ch[r][u]; ch[r][u] = y; maintain(y); maintain(u); return ; } int S[maxn], top; void splay(int u) { int t = u; S[++top] = t; while(fa[t]) t = fa[t], S[++top] = t; while(top) pushdown(S[top--]); while(fa[u]) { int y = fa[u], z = fa[y]; if(z) { if(ch[0][y] == u ^ ch[0][z] == y) rotate(u); else rotate(y); } rotate(u); } return ; } int split(int u) { if(!u) return 0; splay(u); int tmp = ch[1][u]; fa[tmp] = 0; ch[1][u] = 0; maintain(u); return tmp; } int merge(int a, int b) { if(!a) return maintain(b), b; if(!b) return maintain(a), a; pushdown(a); while(ch[1][a]) a = ch[1][a], pushdown(a); splay(a); ch[1][a] = b; fa[b] = a; return maintain(a), a; } int qkth(int o, int k) { if(!o) return 0; pushdown(o); int ls = ch[0][o] ? ns[ch[0][o]].siz : 0; if(k == ls + 1) return o; if(k > ls + 1) return qkth(ch[1][o], k - ls - 1); return qkth(ch[0][o], k); } int nsize; int Find(int k) { if(!nsize) return 0; while(fa[rt]) rt = fa[rt]; return qkth(rt, k); } void Split(int ql, int qr, int& lrt, int& mrt, int& rrt) { lrt = Find(ql - 1); mrt = Find(qr); split(lrt); rrt = split(mrt); return ; } void Merge(int lrt, int mrt, int rrt) { mrt = merge(lrt, mrt); merge(mrt, rrt); return ; } void recycle(int& o) { if(!o) return ; recycle(ch[0][o]); recycle(ch[1][o]); fa[o] = 0; rec[++cc] = o; o = 0; return ; } void Ins(int pos) { int lrt, mrt = 0, rrt; lrt = Find(pos); if(lrt) rrt = split(lrt); else rrt = Find(1), splay(rrt); build(mrt, 1, cv); Merge(lrt, mrt, rrt); return ; } void Del(int ql, int qr) { int lrt, mrt, rrt; Split(ql, qr, lrt, mrt, rrt); recycle(mrt); Merge(lrt, mrt, rrt); rt = max(lrt, rrt); return ; } void Setv(int ql, int qr, int v) { int lrt, mrt, rrt; Split(ql, qr, lrt, mrt, rrt); ns[mrt].setv = v; Merge(lrt, mrt, rrt); return ; } void Rev(int ql, int qr) { int lrt, mrt, rrt; Split(ql, qr, lrt, mrt, rrt); ns[mrt].rev ^= 1; Merge(lrt, mrt, rrt); return ; } LL Sum(int ql, int qr) { if(ql > qr) return 0; int lrt, mrt, rrt; Split(ql, qr, lrt, mrt, rrt); LL ans = !hs(mrt) ? ns[mrt].sum : ns[mrt].setv * ns[mrt].siz; Merge(lrt, mrt, rrt); return ans; } LL MxSum() { while(fa[rt]) rt = fa[rt]; LL ans = !hs(rt) ? ns[rt].ms : max(ns[rt].setv, ns[rt].setv * ns[rt].siz); return ans; } int main() { int n = read(), q = read(); for(int i = 1; i <= n; i++) val[i] = read(); build(rt, 1, n); nsize = n; int cnt = 0, tq = q; while(q--) { char cmd[20]; scanf("%s", cmd); if(cmd[0] == 'I') { int pos = read(); cv = read(); for(int i = 1; i <= cv; i++) val[i] = read(); Ins(pos); nsize += cv; } if(cmd[0] == 'D') { int ql = read(), qr = min(nsize, ql + read() - 1); Del(ql, qr); nsize -= qr - ql + 1; } if(cmd[0] == 'M' && cmd[2] == 'K') { int ql = read(), qr = min(nsize, ql + read() - 1), v = read(); Setv(ql, qr, v); } if(cmd[0] == 'R') { int ql = read(), qr = min(nsize, ql + read() - 1); Rev(ql, qr); } if(cmd[0] == 'G') { int ql = read(), qr = min(nsize, ql + read() - 1); printf("%lld\n", Sum(ql, qr)); cnt++; } if(cmd[0] == 'M' && cmd[2] == 'X') printf("%lld\n", MxSum()), cnt++; } return 0; }