【BZOJ】1500: [NOI2005]维修数列(splay+变态题)
http://www.lydsy.com/JudgeOnline/problem.php?id=1500
模板不打熟你确定考场上调试得出来?
首先有非常多的坑点。。。我遇到的第一个就是,如何pushup。。。。。。。。。。。。sad。。
写了一大串。。。可是感觉。。。写不下去了。。。看别人怎么写吧。。。
orz
首先这个节点代表的这个区间我们维护mxl和mxr表示整个区间从左向右和从右向左能得到的最大序列和。。
然后我无脑不思考没有用好我们的定义,写了一大串的转移。。。
其实就是几个字。。
void pushup() { s=1+ch[0]->s+ch[1]->s; sum=ch[0]->sum+ch[1]->sum+w; int l=max(0, ch[1]->mxl), r=max(0, ch[0]->mxr); //这里一定要注意 mxl=max(ch[0]->mxl, ch[0]->sum+w+l); //不单独考虑ch[0]->sum的原因是我们假设了ch[0]->mxl是最优了 mxr=max(ch[1]->mxr, ch[1]->sum+w+r); mx=max(l+w+r, max(ch[0]->mx, ch[1]->mx)); }
555
然后有很多坑点我在注释里写。首先我的区间翻转那里我老是写错,我的getrange(l, r)是将区间l~r旋转到keytree那里,所以root是l-1,root->ch[1]是r+1,然后如果是要在某个点后边插入是getrange(pos+1, pos),一定要注意!
然后是pushdown的坑点,原来一直是写先tag后操作,现在变成了先操作后tag的形式,这点很重要!!!因为我们维护的size是整棵树的,当只有旋转到keytree下时,size不会统计到边界端点,所以整棵子树的size都是正确的!所以必须要在旋转回去前维护好所有的信息,而不是先打tag!一定要注意!!!之前一直调试不出来就是因为一直在考虑边界点怎么怎么影响。。。。。。然后各种搞不出来。。于是就这么简单的解决了orz
然后有一个地方我忘记了,,,就是每一个从树中获取信息的并向子树走的一定要pushdown!!我被select那里没有pushdown坑了好久
然后注意每一次更新子树后要pushup或者splay!!!一定要注意!!
还有询问前一定要将区间翻转后在区间进行查询!!!!一定要注意!!!
#include <cstdio> #include <cstring> #include <cmath> #include <string> #include <iostream> #include <algorithm> #include <queue> #include <set> #include <map> using namespace std; typedef long long ll; #define rep(i, n) for(int i=0; i<(n); ++i) #define for1(i,a,n) for(int i=(a);i<=(n);++i) #define for2(i,a,n) for(int i=(a);i<(n);++i) #define for3(i,a,n) for(int i=(a);i>=(n);--i) #define for4(i,a,n) for(int i=(a);i>(n);--i) #define CC(i,a) memset(i,a,sizeof(i)) #define read(a) a=getint() #define print(a) printf("%d", a) #define dbg(x) cout << (#x) << " = " << (x) << endl #define error(x) (!(x)?puts("error"):0) inline const int getint() { int r=0, k=1; char c=getchar(); for(; c<'0'||c>'9'; c=getchar()) if(c=='-') k=-1; for(; c>='0'&&c<='9'; c=getchar()) r=r*10+c-'0'; return k*r; } const int oo=~0u>>2; int n, m; struct node *null; struct node { node *ch[2], *fa; bool rev, tag; int sum, s, mx, mxl, mxr, w; void pushup() { s=1+ch[0]->s+ch[1]->s; sum=ch[0]->sum+ch[1]->sum+w; int l=max(0, ch[1]->mxl), r=max(0, ch[0]->mxr); mxl=max(ch[0]->mxl, ch[0]->sum+w+l); mxr=max(ch[1]->mxr, ch[1]->sum+w+r); mx=max(l+w+r, max(ch[0]->mx, ch[1]->mx)); } void upd(const int &k) { if(this==null) return; w=k; sum=s*k; mx=mxl=mxr=max(k, k*s); tag=true; } void upd2() { if(this==null) return; rev=!rev; swap(ch[0], ch[1]); swap(mxl, mxr); } void pushdown() { if(rev) { rev=false; ch[0]->upd2(); ch[1]->upd2(); } if(tag) { tag=false; ch[0]->upd(w); ch[1]->upd(w); } } bool d() { return fa->ch[1]==this; } void setc(node *c, bool d) { ch[d]=c; c->fa=this; } void init(const int &k) { sum=mx=mxl=mxr=w=k; } void set(int c=0) { ch[0]=ch[1]=fa=null; rev=tag=false; sum=w=mx=mxl=mxr=c; s=1; } }*root, *s[1000005]; void rot(node *x) { node *fa=x->fa; fa->pushdown(); x->pushdown(); bool d=x->d(); fa->fa->setc(x, fa->d()); fa->setc(x->ch[!d], d); x->setc(fa, !d); fa->pushup(); if(fa==root) root=x; } void splay(node *x, node *fa=null) { x->pushdown(); while(x->fa!=fa) if(x->fa->fa==fa) rot(x); else x->d()==x->fa->d()?(rot(x->fa), rot(x)):(rot(x), rot(x)); x->pushup(); } node *sel(node *x, const int &k) { x->pushdown(); int s=x->ch[0]->s; if(s==k) return x; if(k>s) return sel(x->ch[1], k-s-1); return sel(x->ch[0], k); } node *getrange(const int &l, const int &r) { splay(sel(root, l-1)); splay(sel(root, r+1), root); return root->ch[1]; } int top; node *newnode(int c=0) { node *ret; if(top) ret=s[top--]; else ret=new node; ret->set(c); return ret; } void build(node *&x, const int &l, const int &r) { if(l>r) return; x=newnode(); if(l==r) { x->init(getint()); return; } int m=(l+r)>>1; build(x->ch[0], l, m-1); x->init(getint()); build(x->ch[1], m+1, r); if(l<=m-1) x->ch[0]->fa=x; if(m+1<=r) x->ch[1]->fa=x; x->pushup(); } void insert() { int pos=getint(), tot=getint(); node *x; build(x, 1, tot); node *fa=getrange(pos+1, pos); //注意getrange的地方 fa->setc(x, 0); root->ch[1]->pushup(); root->pushup(); } void cln(node *x) { if(x==null) return; cln(x->ch[0]); cln(x->ch[1]); s[++top]=x; } void dele() { int l=getint(), r=l+getint()-1; node *fa=getrange(l, r); cln(fa->ch[0]); fa->ch[0]=null; root->ch[1]->pushup(); root->pushup(); } void ask1() { int l=getint(), r=l+getint()-1; node *x=getrange(l, r)->ch[0]; printf("%d\n", x->sum); } void ask2() { node *x=getrange(1, root->s-2)->ch[0]; printf("%d\n", x->mx); } void fix1() { int l=getint(), r=l+getint()-1, c=getint(); node *x=getrange(l, r)->ch[0]; x->upd(c); root->ch[1]->pushup(); root->pushup(); } void fix2() { int l=getint(), r=l+getint()-1; node *x=getrange(l, r)->ch[0]; x->upd2(); root->ch[1]->pushup(); root->pushup(); } void init() { null=newnode(-oo); null->ch[0]=null->ch[1]=null->fa=null; null->s=null->sum=null->w=0; root=newnode(); node *c=newnode(); root->setc(c, 1); node *t; build(t, 1, n); root->ch[1]->setc(t, 0); root->ch[1]->pushup(); root->pushup(); } int main() { read(n); read(m); init(); char s[15]; for1(i, 1, m) { scanf("%s", s+1); char ch=s[3]; if(ch=='S') insert(); //INSERT else if(ch=='L') dele(); //DELETE else if(ch=='T') ask1(); //GET-SUM else if(ch=='X') ask2(); //MAX-SUM else if(ch=='K') fix1(); //MAKE-SAME else if(ch=='V') fix2(); //REVERSE //P(root, 1); } return 0; }
Description
Input
输入文件的第1行包含两个数N和M,N表示初始时数列中数的个数,M表示要进行的操作数目。第2行包含N个数字,描述初始时的数列。以下M行,每行一条命令,格式参见问题描述中的表格。
Output
对于输入数据中的GET-SUM和MAX-SUM操作,向输出文件依次打印结果,每个答案(数字)占一行。
Sample Input
2 -6 3 5 1 -5 -3 6 3
GET-SUM 5 4
MAX-SUM
INSERT 8 3 -5 7 2
DELETE 12 1
MAKE-SAME 3 3 2
REVERSE 3 6
GET-SUM 5 4
MAX-SUM
Sample Output
10
1
10
HINT
Source