Splay大法好
新手推荐阅读:splay详解(一),Splay入门解析【保证让你看不懂(滑稽)】
打算记点关于 Splay 的笔记
splay嘛,本质上是一棵BST(即二叉查找树)。这棵树上的每一个节点的左孩子都比它小,右孩子都比它大,也就是说这棵树需要维护中序遍历。
【核心操作】
splay(x,y):把点 x 旋转到点 y 下面。
注意:当 x, y, z 在一条直线上时,先转 y 再转 x,否则先转 x 然后再转一遍 x。
void splay(int x,int o) {
if(!o) root=x;//更新根节点
while(fa(x)!=o) {
int y=fa(x),z=fa(y);
if(z!=o) {
if(chk(x)^chk(y)) rotate(x);//x,y,z不在一条直线上
else rotate(y); //x,y,z在一条直线上
}
rotate(x);
}
}
rotate (x):改变三对节点的父子关系,具体看图
void rotate(int x) {
int k=chk(x),y=fa(x),z=fa(y),w=a[x].ch[k^1];
a[y].ch[k]=w;fa(w)=y;
a[z].ch[chk(y)]=x;fa(x)=z;
a[x].ch[k^1]=y;fa(y)=x;//顺序不可以随意改变!
pushup(y);pushup(x);
}
树上的每一个节点不仅代表了原序列的一个值,还记录了一段序列(即它的子树)的相关信息,因此操作中要维护每个节点的 size,tag 等等。
struct hh{
int ch[2],fa,rev,val,siz;//用ch[0]表示左孩子,ch[1]表示右孩子
}a[N];
建树前先在首尾加两个值分别为 inf(极大) 和 -inf(极小)的节点以免出现莫名错误。
b[1]=-inf;b[n+2]=inf;
for(R i=1;i<=n;++i) b[i+1]=i;
root=build(1,n+2,0);
chk (x):查询 x 是 a[x].fa 的左孩子还是右孩子
int chk(int x) {return x==rs(fa(x));}//如果是右孩子就返回1,否则返回0
向上和向下维护信息
void pushup(int x) {a[x].siz=a[ls(x)].siz+a[rs(x)].siz+1;}
void pushdown(int x) {
if(!a[x].rev) return ;
swap(ls(x),rs(x));
a[ls(x)].rev^=1;a[rs(x)].rev^=1;
a[x].rev=0;
}
建树
int build(int l,int r,int f) {
if(l>r) return 0;
int mid=l+r>>1,id=++num;
fa(id)=f;a[id].siz=1;a[id].val=b[mid];
ls(id)=build(l,mid-1,id);
rs(id)=build(mid+1,r,id);
pushup(id);
return id;
}
//a[id]不仅代表原序列中的b[mid],还代表[l,r]这段区间
找区间对应的节点
int find(int k,int x) {
pushdown(k);
int cnt=a[ls(k)].siz;
if(cnt+1==x) return k;
if(cnt>=x) return find(ls(k),x);
else return find(rs(k),x-cnt-1);
}
void work(int l,int r) {
int x=find(root,l-1),y=find(root,r+1);
splay(x,0);splay(y,x);
a[ls(y)].rev^=1;
}//work函数是翻转[l,r]区间,该区间对应节点即ls(y)
中序遍历输出整段序列
void print(int x) {
pushdown(x);
if(ls(x)) print(ls(x));
if(a[x].val!=inf&&a[x].val!=-inf) printf("%d ",a[x].val);
if(rs(x)) print(rs(x));
}
以洛谷P3391【模板】文艺平衡树为例:
这道题只有翻转操作,每个节点只要再维护一个 rev(代表是否翻转)
1 #include<bits/stdc++.h>
2 #define fa(x) a[x].fa
3 #define ls(x) a[x].ch[0]
4 #define rs(x) a[x].ch[1]
5 #define R register int
6
7 using namespace std;
8 const int mod=10000,N=1e5+5,inf=0x3f3f3f3f;
9
10 int read() {
11 int f=1;char ch;
12 while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
13 int res=ch-'0';
14 while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
15 return f*res;
16 }
17
18 struct hh{
19 int ch[2],fa,rev,val,siz;
20 }a[N];
21 int b[N],n,m,num,root;
22
23 int chk(int x) {return x==rs(fa(x));}
24 void pushup(int x) {a[x].siz=a[ls(x)].siz+a[rs(x)].siz+1;}
25 void pushdown(int x) {
26 if(!a[x].rev) return ;
27 swap(ls(x),rs(x));
28 a[ls(x)].rev^=1;a[rs(x)].rev^=1;
29 a[x].rev=0;
30 }
31
32 int build(int l,int r,int f) {
33 if(l>r) return 0;
34 int mid=l+r>>1,id=++num;
35 fa(id)=f;a[id].siz=1;a[id].val=b[mid];
36 ls(id)=build(l,mid-1,id);
37 rs(id)=build(mid+1,r,id);
38 pushup(id);
39 return id;
40 }
41
42 int find(int k,int x) {
43 pushdown(k);
44 int cnt=a[ls(k)].siz;
45 if(cnt+1==x) return k;
46 if(cnt>=x) return find(ls(k),x);
47 else return find(rs(k),x-cnt-1);
48 }
49
50 void rotate(int x) {
51 int k=chk(x),y=fa(x),z=fa(y),w=a[x].ch[k^1];
52 a[y].ch[k]=w;fa(w)=y;
53 a[z].ch[chk(y)]=x;fa(x)=z;
54 a[x].ch[k^1]=y;fa(y)=x;
55 pushup(y);pushup(x);
56 }
57
58 void splay(int x,int o) {
59 if(!o) root=x;
60 while(fa(x)!=o) {
61 int y=fa(x),z=fa(y);
62 if(z!=o) {
63 if(chk(x)^chk(y)) rotate(x);
64 else rotate(y);
65 }
66 rotate(x);
67 }
68 }
69
70 void work(int l,int r) {
71 int x=find(root,l-1),y=find(root,r+1);
72 splay(x,0);splay(y,x);
73 a[ls(y)].rev^=1;
74 }
75
76 void print(int x) {
77 pushdown(x);
78 if(ls(x)) print(ls(x));
79 if(a[x].val!=inf&&a[x].val!=-inf) printf("%d ",a[x].val);
80 if(rs(x)) print(rs(x));
81 }
82
83 int main() {
84 n=read();m=read();
85 b[1]=-inf;b[n+2]=inf;
86 for(R i=1;i<=n;++i) b[i+1]=i;
87 root=build(1,n+2,0);
88 while(m--) {
89 int l=read()+1,r=read()+1;
90 work(l,r);
91 }
92 print(root);
93 return 0;
94 }
(去年 noip 之后码风变了很多,个人觉得更美观了,还加上了宏定义什么的方便嵌套
接下来我尝试了一道相当毒瘤的题:洛谷P2042 [NOI2005] 维护数列
操作一:在序列的第 pos 和 pos+1 个数字之间插入 tot 个数字
注意,题目中给出的 pos 在我们的操作中对应的其实是 pos+1,因为序列首端多加入了一个 -inf
根据 splay 的性质,我们先把 pos 移至根节点,再把 pos+1 移到 pos 下面,这样 pos+1 的左孩子就是我们插入序列的位置(因为这个位置是 pos+1 的左孩子,同时也是 pos 的右孩子的子树中的一个,即它比 pos+1 小,但比 pos 大)
新插入的序列也先建成一棵 splay 再插入
void insert(int pos,int tot) {
for(R i=1;i<=tot;++i) b[i]=read();
int id=build(1,tot,0);
int x=find(root,pos),y=find(root,pos+1);
splay(x,0);splay(y,x);
ls(y)=id;fa(id)=y;
pushup(y);pushup(x);
}
操作二:删除 [pos, pos+tot-1] 这个区间
如果我们把 pos-1 移至根节点,再把 pos+tot 移到 pos 下面,这样 pos+1 的左孩子就是序列中的 [pos, pos+tot-1],那么要将这个序列删除,把 pos+1 的左孩子记为空即可。但因为本题直接这么做会爆空间,那就把已经删除了的节点先存储起来,以备后用,避免开太多空间。我在代码里采用的是压栈的方法。
void recycle(int x) {
if(!x) return ;
st[++top]=x;
recycle(ls(x)),recycle(rs(x));
}//回收节点x
void del(int pos,int tot) {
int x=find(root,pos-1),y=find(root,pos+tot);
splay(x,0);splay(y,x);
recycle(ls(y));ls(y)=0;
pushup(y);pushup(x);
}
操作三:将 [pos, pos+tot-1] 这个区间的值全部改为 c
像上面那样先找到 [pos, pos+tot-1],然后给代表这个区间的节点都加上和赋值有关的懒标记即可(pushup 别忘了!
void Tag(int x,int c) {
if(!x) return;
a[x].val=c;
a[x].sum=c*a[x].siz;
a[x].mx=max(a[x].sum,c);
a[x].lm=a[x].rm=max(0,a[x].mx);//a[x].lm和a[x].rm可以为0,因为a[x].mx不一定要包含ls(x)或rs(x)
a[x].tag=1;
}
void make_same(int pos,int tot,int c) {
int x=find(root,pos-1),y=find(root,pos+tot);
splay(x,0);splay(y,x);
Tag(ls(y),c);
pushup(y);pushup(x);
}
操作四:翻转 [pos, pos+tot-1] 这个区间
同样先找到该区间,给代表节点打上和翻转有关的懒标记
void Rev(int x) {
if(!x) return ;
swap(a[x].lm,a[x].rm);//注意此处交换lm和rm!
swap(ls(x),rs(x));
a[x].rev^=1;
}
void rev(int pos,int tot) {
int x=find(root,pos-1),y=find(root,pos+tot);
splay(x,0);splay(y,x);
Rev(ls(y));
pushup(y);pushup(x);
}
操作五:求 [pos, pos+tot-1] 这个区间的和
建树时让每个节点维护所对应区间的和,查询时只要找到 [pos, pos+tot-1] 这个区间的代表节点就可以 O(1) 输出
int query(int pos,int tot) {
int x=find(root,pos-1),y=find(root,pos+tot);
splay(x,0);splay(y,x);
return a[ls(y)].sum;
}
操作六:求整个序列的最大子段和
这个子问题的做法类似于洛谷P4513 小白逛公园。建树时让每个节点维护所对应区间的最大子段和,推出转移方程即可。
int max_sum() {
int x=find(root,1),y=find(root,n+2);
splay(x,0);splay(y,x);
return a[ls(y)].mx;
}
void pushup(int x) {
a[x].siz=a[ls(x)].siz+a[rs(x)].siz+1;
a[x].sum=a[ls(x)].sum+a[rs(x)].sum+a[x].val;
a[x].lm=max(a[ls(x)].lm,a[ls(x)].sum+a[x].val+a[rs(x)].lm);
a[x].rm=max(a[rs(x)].rm,a[rs(x)].sum+a[x].val+a[ls(x)].rm);
a[x].mx=max(max(a[ls(x)].mx,a[rs(x)].mx),a[ls(x)].rm+a[x].val+a[rs(x)].lm);
}
int build(int l,int r,int f) {
if(l>r) return 0;
int mid=l+r>>1,id=get();
a[id].val=b[mid];fa(id)=f;a[id].siz=1;
a[id].lm=a[id].rm=max(0,a[id].val);
a[id].mx=a[id].sum=a[id].val;
a[id].rev=a[id].tag=0;
ls(id)=build(l,mid-1,id);
rs(id)=build(mid+1,r,id);
pushup(id);
return id;
}
这道题毒瘤之处在于有很多坑点,洛谷讨论帖一堆“告诫后人”,我也调了好久,感谢喻队帮忙看代码orz
详情见代码吧
1 #include<bits/stdc++.h>
2 #define R register int
3 #define ls(x) a[x].ch[0]
4 #define rs(x) a[x].ch[1]
5 #define fa(x) a[x].fa
6
7 using namespace std;
8 const int N=5e5+5,inf=1e9;
9
10 int read() {
11 int f=1;char ch;
12 while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
13 int res=ch-'0';
14 while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
15 return res*f;
16 }
17
18 int n,m,st[N],top,root,b[N];
19 struct hh{
20 int ch[2],lm,rm,mx,sum,siz,val,fa,tag,rev;
21 }a[N];
22
23 int chk(int x) {return x==rs(fa(x));}
24 int get() {return st[top--];}
25
26 void Rev(int x) {
27 if(!x) return ;
28 swap(a[x].lm,a[x].rm);//注意此处交换lm和rm!
29 swap(ls(x),rs(x));
30 a[x].rev^=1;
31 }
32
33 void Tag(int x,int c) {
34 if(!x) return;
35 a[x].val=c;
36 a[x].sum=c*a[x].siz;
37 a[x].mx=max(a[x].sum,c);
38 a[x].lm=a[x].rm=max(0,a[x].mx);//a[x].lm和a[x].rm可以为0,因为a[x].mx不一定要包含ls(x)或rs(x)
39 a[x].tag=1;
40 }
41
42 void pushup(int x) {
43 a[x].siz=a[ls(x)].siz+a[rs(x)].siz+1;
44 a[x].sum=a[ls(x)].sum+a[rs(x)].sum+a[x].val;
45 a[x].lm=max(a[ls(x)].lm,a[ls(x)].sum+a[x].val+a[rs(x)].lm);
46 a[x].rm=max(a[rs(x)].rm,a[rs(x)].sum+a[x].val+a[ls(x)].rm);
47 a[x].mx=max(max(a[ls(x)].mx,a[rs(x)].mx),a[ls(x)].rm+a[x].val+a[rs(x)].lm);
48 }
49
50 void pushdown(int x) {
51 if(!x) return ;
52 if(a[x].tag) {
53 Tag(ls(x),a[x].val),Tag(rs(x),a[x].val);//也可以在这里先判断一下ls(x)、rs(x)是否为空
54 a[x].tag=a[x].rev=0;
55 }
56 if(a[x].rev) {
57 Rev(ls(x)),Rev(rs(x));//同上
58 a[x].rev=0;
59 }
60 }
61
62 void rotate(int x) {
63 int y=fa(x),z=fa(y),k=chk(x),w=a[x].ch[k^1];
64 a[y].ch[k]=w;fa(w)=y;
65 a[z].ch[chk(y)]=x;fa(x)=z;
66 a[x].ch[k^1]=y;fa(y)=x;
67 pushup(y);pushup(x);
68 }
69
70 void splay(int x,int o) {
71 if(!o) root=x;
72 while(fa(x)!=o) {
73 int y=fa(x),z=fa(y);
74 if(z!=o) {
75 if(chk(x)^chk(y)) rotate(x);
76 else rotate(y);
77 }
78 rotate(x);
79 }
80 }
81
82 int find(int x,int k) {
83 pushdown(x);//记得标记下传!
84 int cnt=a[ls(x)].siz;
85 if(cnt==k-1) return x;
86 if(cnt>=k) return find(ls(x),k);
87 else return find(rs(x),k-cnt-1);
88 }
89
90 int build(int l,int r,int f) {
91 if(l>r) return 0;
92 int mid=l+r>>1,id=get();
93 a[id].val=b[mid];fa(id)=f;a[id].siz=1;
94 a[id].lm=a[id].rm=max(0,a[id].val);
95 a[id].mx=a[id].sum=a[id].val;
96 a[id].rev=a[id].tag=0;
97 ls(id)=build(l,mid-1,id);
98 rs(id)=build(mid+1,r,id);
99 pushup(id);
100 return id;
101 }
102
103 int max_sum() {
104 int x=find(root,1),y=find(root,n+2);
105 splay(x,0);splay(y,x);
106 return a[ls(y)].mx;
107 }
108
109 void make_same(int pos,int tot,int c) {
110 int x=find(root,pos-1),y=find(root,pos+tot);
111 splay(x,0);splay(y,x);
112 Tag(ls(y),c);
113 pushup(y);pushup(x);
114 }
115
116 void insert(int pos,int tot) {
117 for(R i=1;i<=tot;++i) b[i]=read();
118 int id=build(1,tot,0);
119 int x=find(root,pos),y=find(root,pos+1);
120 splay(x,0);splay(y,x);
121 ls(y)=id;fa(id)=y;
122 pushup(y);pushup(x);
123 }
124
125 void recycle(int x) {
126 if(!x) return ;
127 st[++top]=x;
128 recycle(ls(x)),recycle(rs(x));
129 }
130
131 void del(int pos,int tot) {
132 int x=find(root,pos-1),y=find(root,pos+tot);
133 splay(x,0);splay(y,x);
134 recycle(ls(y));ls(y)=0;
135 pushup(y);pushup(x);
136 }
137
138 void rev(int pos,int tot) {
139 int x=find(root,pos-1),y=find(root,pos+tot);
140 splay(x,0);splay(y,x);
141 Rev(ls(y));
142 pushup(y);pushup(x);
143 }
144
145 int query(int pos,int tot) {
146 int x=find(root,pos-1),y=find(root,pos+tot);
147 splay(x,0);splay(y,x);
148 return a[ls(y)].sum;
149 }
150
151 int main() {
152 a[0].mx=-inf;//attention!
153 n=read(),m=read();
154 for(R i=1;i<N;i++) st[i]=i;top=N-1;//st用来回收节点
155
156 for(R i=1;i<=n;++i) b[i+1]=read();
157 b[1]=-inf,b[n+2]=inf;//attention!
158 root=build(1,n+2,0);
159 while(m--) {
160 char s[12];scanf("%s",s);
161 if(s[0]=='M') {
162 if(s[2]=='X') printf("%d\n",max_sum());
163 else {
164 int pos=read()+1,tot=read(),c=read();//pos需要+1
165 make_same(pos,tot,c);
166 }
167 }
168 else {
169 int pos=read()+1,tot=read();//pos需要+1
170 if(s[0]=='I') n+=tot,insert(pos,tot);//n需要更新,因为max_sum函数中会用到新的n
171 else if(s[0]=='D') n-=tot,del(pos,tot);//同上
172 else if(s[0]=='R') rev(pos,tot);
173 else if(s[0]=='G') printf("%d\n",query(pos,tot));
174 }
175 }
176 return 0;
177 }
1 #include<bits/stdc++.h>
2 #define R register int
3 #define ls(x) a[x].ch[0]
4 #define rs(x) a[x].ch[1]
5 #define fa(x) a[x].fa
6
7 using namespace std;
8 const int N=2100000,inf=0x3f3f3f3f;
9
10 int read() {
11 int f=1;char ch;
12 while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
13 int res=ch-'0';
14 while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
15 return res*f;
16 }
17
18 int root,pos=1,num,n;
19 char b[N];
20 struct hh{
21 int ch[2],fa,siz;
22 char s;
23 }a[N];
24
25 int chk(int x) {return x==rs(fa(x));}
26 void pushup(int x) {a[x].siz=a[ls(x)].siz+a[rs(x)].siz+1;}
27
28 void print(int x) {
29 if(ls(x)) print(ls(x));
30 printf("%c",a[x].s);
31 if(rs(x)) print(rs(x));
32 pushup(x);
33 }
34
35 int build(int l,int r,int f) {
36 if(l>r) return 0;
37 int mid=l+r>>1,id=++num;
38 a[id].s=b[mid];a[id].fa=f;a[id].siz=1;
39 ls(id)=build(l,mid-1,id);
40 rs(id)=build(mid+1,r,id);
41 pushup(id);
42 return id;
43 }
44
45 int find(int x,int k) {
46 int cnt=a[ls(x)].siz;
47 if(cnt+1==k) return x;
48 if(cnt>=k) return find(ls(x),k);
49 else return find(rs(x),k-cnt-1);
50 }
51
52 void rotate(int x) {
53 int y=fa(x),z=fa(y),k=chk(x),w=a[x].ch[k^1];
54 a[z].ch[chk(y)]=x;fa(x)=z;
55 a[y].ch[k]=w;fa(w)=y;
56 a[x].ch[k^1]=y;fa(y)=x;
57 pushup(y);pushup(x);
58 }
59
60 void splay(int x,int o) {
61 if(!o) root=x;
62 while(fa(x)!=o) {
63 int y=fa(x),z=fa(y);
64 if(z!=o) {
65 if(chk(x)^chk(y)) rotate(x);
66 else rotate(y);
67 }
68 rotate(x);
69 }
70 }
71
72 void insert(int cnt) {
73 for(R i=1;i<=cnt;++i) {
74 b[i]=getchar();
75 if(b[i]<32||b[i]>126) i--;
76 }
77 int id=build(1,cnt,0);
78 int x=find(root,pos),y=find(root,pos+1);
79 splay(x,0);splay(y,x);
80 ls(y)=id;fa(id)=y;
81 pushup(y);pushup(x);
82 }
83
84 void del(int cnt) {
85 int x=find(root,pos),y=find(root,pos+cnt+1);
86 splay(x,0);splay(y,x);
87 ls(y)=0;
88 pushup(y);pushup(x);
89 }
90
91 void get(int cnt) {
92 int x=find(root,pos),y=find(root,pos+cnt+1);
93 splay(x,0);splay(y,x);
94 print(ls(y));putchar('\n');
95 }
96
97 int main() {
98 b[0]=b[1]=b[2]=' ';
99 root=build(1,2,0);
100 n=2;
101 int t=read();
102 while(t--) {
103 char s[10];scanf("%s",s);
104 if(s[0]=='P') {if(pos) pos--;}
105 else if(s[0]=='N') pos++;
106 else {
107 int cnt=read();
108 if(s[0]=='M') pos=cnt+1;
109 else if(s[0]=='I') n+=cnt,insert(cnt);
110 else if(s[0]=='D') cnt=min(n-pos,cnt),n-=cnt,del(cnt);
111 else if(s[0]=='G') cnt=min(n-pos,cnt),get(cnt);
112 }
113 }
114 return 0;
115 }
1 #include<bits/stdc++.h>
2 #define R register int
3 #define ls(x) a[x].ch[0]
4 #define rs(x) a[x].ch[1]
5 #define fa(x) a[x].fa
6
7 using namespace std;
8 const int N=2100000,inf=0x3f3f3f3f;
9
10 int read() {
11 int f=1;char ch;
12 while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
13 int res=ch-'0';
14 while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
15 return res*f;
16 }
17
18 int root,pos=1,num,n;
19 char b[N];
20 struct hh{
21 int ch[2],fa,siz,rev;
22 char s;
23 }a[N];
24
25 int chk(int x) {return x==rs(fa(x));}
26 void pushup(int x) {a[x].siz=a[ls(x)].siz+a[rs(x)].siz+1;}
27
28 void Rev(int x) {
29 a[x].rev^=1,swap(ls(x),rs(x));
30 }
31
32 void pushdown(int x) {
33 if(!a[x].rev) return;
34 if(ls(x)) Rev(ls(x));
35 if(rs(x)) Rev(rs(x));
36 a[x].rev=0;
37 }
38
39 void print(int x) {
40 pushdown(x);
41 if(ls(x)) print(ls(x));
42 printf("%c",a[x].s);
43 if(rs(x)) print(rs(x));
44 pushup(x);
45 }
46
47 int build(int l,int r,int f) {
48 if(l>r) return 0;
49 int mid=l+r>>1,id=++num;
50 a[id].s=b[mid];a[id].fa=f;a[id].siz=1,a[id].rev=0;
51 ls(id)=build(l,mid-1,id);
52 rs(id)=build(mid+1,r,id);
53 pushup(id);
54 return id;
55 }
56
57 int find(int x,int k) {
58 pushdown(x);
59 int cnt=a[ls(x)].siz;
60 if(cnt+1==k) return x;
61 if(cnt>=k) return find(ls(x),k);
62 else return find(rs(x),k-cnt-1);
63 }
64
65 void rotate(int x) {
66 int y=fa(x),z=fa(y),k=chk(x),w=a[x].ch[k^1];
67 a[z].ch[chk(y)]=x;fa(x)=z;
68 a[y].ch[k]=w;fa(w)=y;
69 a[x].ch[k^1]=y;fa(y)=x;
70 pushup(y);pushup(x);
71 }
72
73 void splay(int x,int o) {
74 if(!o) root=x;
75 while(fa(x)!=o) {
76 int y=fa(x),z=fa(y);
77 if(z!=o) {
78 if(chk(x)^chk(y)) rotate(x);
79 else rotate(y);
80 }
81 rotate(x);
82 }
83 }
84
85 void insert(int cnt) {
86 for(R i=1;i<=cnt;++i) {
87 b[i]=getchar();
88 if(b[i]<32||b[i]>126) {
89 if(i==cnt) b[i]=' ';
90 else i--;
91 }
92 }
93 int id=build(1,cnt,0);
94 int x=find(root,pos),y=find(root,pos+1);
95 splay(x,0);splay(y,x);
96 ls(y)=id;fa(id)=y;
97 pushup(y);pushup(x);
98 }
99
100 void del(int cnt) {
101 int x=find(root,pos),y=find(root,pos+cnt+1);
102 splay(x,0);splay(y,x);
103 ls(y)=0;
104 pushup(y);pushup(x);
105 }
106
107 void get(int cnt) {
108 int x=find(root,pos),y=find(root,pos+cnt+1);
109 splay(x,0);splay(y,x);
110 print(ls(y));putchar('\n');
111 }
112
113 void reverse(int cnt) {
114 int x=find(root,pos),y=find(root,pos+cnt+1);
115 splay(x,0);splay(y,x);
116 Rev(ls(y));
117 pushup(y);pushup(x);
118 }
119
120 int main() {
121 b[1]=b[2]=' ';
122 root=build(1,2,0);
123 n=2;
124 int t=read();
125 while(t--) {
126 char s[10];scanf("%s",s);
127 if(s[0]=='P') {if(pos) pos--;}
128 else if(s[0]=='N') pos++;
129 else if(s[0]=='G') get(1);
130 else {
131 int cnt=read();
132 if(s[0]=='M') pos=cnt+1;
133 else if(s[0]=='I') n+=cnt,insert(cnt);
134 else if(s[0]=='D') cnt=min(n-pos,cnt),n-=cnt,del(cnt);
135 else if(s[0]=='R') reverse(cnt);
136 }
137 }
138 return 0;
139 }
洛谷P3215 [HNOI2011]括号修复 / [JSOI2011]括号序列
这道题调了好几天。。几个明显的错误都没看出来/扶额
最大的问题是区间赋值时要直接用所赋的值c更新a[x].tag,而不能直接在懒标记下传时用节点值a[x].val来更新x的左右儿子的值,因为之前的Inv操作会影响到a[x].val,进而影响下传的标记。
1 #include<bits/stdc++.h>
2 #define IL inline
3 #define R register int
4 #define ls(x) a[x].ch[0]
5 #define rs(x) a[x].ch[1]
6 #define fa(x) a[x].fa
7
8 using namespace std;
9 const int N=5e5+5,inf=0x3f3f3f3f;
10
11 int read() {
12 int f=1;char ch;
13 while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
14 int res=ch-'0';
15 while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
16 return res*f;
17 }
18
19 char cha[N];
20 int n,q,root,num,b[N];
21 struct hh {
22 int qma,hma,qmi,hmi,fa,ch[2],siz,sum,val,rev,tag,inv;
23 }a[N];
24
25 IL int min(int x,int y){return x<y?x:y;}
26 IL int max(int x,int y){return x>y?x:y;}
27 int chk(int x) {return x==rs(fa(x));}
28
29 void Rev(int x) {
30 a[x].rev^=1;
31 swap(a[x].qma,a[x].hma);
32 swap(a[x].qmi,a[x].hmi);
33 swap(ls(x),rs(x));
34 }
35
36 void Tag(int x,int c) {
37 a[x].val=c;
38 a[x].sum=c*a[x].siz;
39 a[x].tag=c;a[x].rev=a[x].inv=0;
40 a[x].qma=a[x].hma=max(0,a[x].sum);
41 a[x].qmi=a[x].hmi=min(0,a[x].sum);
42 }
43
44 void Inv(int x) {
45 a[x].inv^=1;
46 a[x].sum=-a[x].sum;a[x].val=-a[x].val;
47 swap(a[x].qmi,a[x].qma);
48 a[x].qmi=-a[x].qmi;a[x].qma=-a[x].qma;
49 swap(a[x].hma,a[x].hmi);
50 a[x].hma=-a[x].hma;a[x].hmi=-a[x].hmi;
51 }
52
53 void pushup(int x) {
54 a[x].siz=1+a[ls(x)].siz+a[rs(x)].siz;
55 a[x].sum=a[x].val+a[ls(x)].sum+a[rs(x)].sum;
56 a[x].qma=max(a[ls(x)].qma,a[ls(x)].sum+a[x].val+a[rs(x)].qma);
57 a[x].qmi=min(a[ls(x)].qmi,a[ls(x)].sum+a[x].val+a[rs(x)].qmi);
58 a[x].hma=max(a[rs(x)].hma,a[rs(x)].sum+a[x].val+a[ls(x)].hma);
59 a[x].hmi=min(a[rs(x)].hmi,a[rs(x)].sum+a[x].val+a[ls(x)].hmi);
60 }
61
62 void pushdown(int x) {
63 if(a[x].tag) {
64 if(ls(x)) Tag(ls(x),a[x].tag);
65 if(rs(x)) Tag(rs(x),a[x].tag);
66 a[x].tag=0;
67 }
68 if(a[x].rev) {
69 if(ls(x)) Rev(ls(x));
70 if(rs(x)) Rev(rs(x));
71 a[x].rev=0;
72 }
73 if(a[x].inv) {
74 if(ls(x)) Inv(ls(x));
75 if(rs(x)) Inv(rs(x));
76 a[x].inv=0;
77 }
78 }
79
80 int build(int l,int r,int f) {
81 if(l>r) return 0;
82 int mid=l+r>>1,id=++num;
83 a[id].val=b[mid];
84 a[id].siz=1;fa(id)=f;a[id].sum=a[id].val;
85 a[id].rev=0;a[id].tag=0;a[id].inv=0;
86 a[id].qma=a[id].hma=max(0,a[id].sum);
87 a[id].qmi=a[id].hmi=min(0,a[id].sum);
88 ls(id)=build(l,mid-1,id);
89 rs(id)=build(mid+1,r,id);
90 pushup(id);
91 return id;
92 }
93
94 int find(int x,int k) {
95 pushdown(x);
96 int cnt=a[ls(x)].siz;
97 if(cnt==k-1) return x;
98 if(cnt>=k) return find(ls(x),k);
99 else return find(rs(x),k-cnt-1);
100 }
101
102 void rotate(int x) {
103 int y=fa(x),z=fa(y),k=chk(x),w=a[x].ch[k^1];
104 a[z].ch[chk(y)]=x;fa(x)=z;
105 a[x].ch[k^1]=y;fa(y)=x;
106 a[y].ch[k]=w;fa(w)=y;
107 pushup(y);pushup(x);
108 }
109
110 void splay(int x,int o) {
111 if(!o) root=x;
112 while(fa(x)!=o) {
113 int y=fa(x),z=fa(y);
114 if(z!=o) {
115 if(chk(x)^chk(y)) rotate(x);
116 else rotate(y);
117 }
118 rotate(x);
119 }
120 }
121
122 void replace(int l,int r,int c) {
123 int x=find(root,l-1),y=find(root,r+1);
124 splay(x,0);splay(y,x);
125 Tag(ls(y),c);
126 pushup(y);pushup(x);
127 }
128
129 void reverse(int l,int r) {
130 int x=find(root,l-1),y=find(root,r+1);
131 splay(x,0);splay(y,x);
132 Rev(ls(y));
133 pushup(y);pushup(x);
134 }
135
136 void invert(int l,int r) {
137 int x=find(root,l-1),y=find(root,r+1);
138 splay(x,0);splay(y,x);
139 Inv(ls(y));
140 pushup(y);pushup(x);
141 }
142
143 void query(int l,int r) {
144 int x=find(root,l-1),y=find(root,r+1);
145 splay(x,0);splay(y,x);
146 int ans=((-a[ls(y)].qmi+1)>>1)+((a[ls(y)].hma+1)>>1);
147 printf("%d\n",ans);
148 }
149
150 void print(int x) {
151 pushdown(x);
152 if(ls(x)) print(ls(x));
153 if(a[x].val==1) cout<<"(";
154 else if(a[x].val==-1) cout<<")";
155 if(rs(x)) print(rs(x));
156 }
157
158 //'(':1,')':-1
159 //hma/2+qmi/2
160 int main() {
161 n=read();q=read();
162 scanf("%s",cha+2);
163 for(R i=2;i<=n+1;++i)
164 if(cha[i]=='(') b[i]=1;
165 else b[i]=-1;
166 root=build(1,n+2,0);
167 while(q--) {
168 char s[8];int l,r;
169 scanf("%s",s);
170 l=read()+1,r=read()+1;
171 if(s[0]=='R') {
172 char c[2];scanf("%s",c);
173 int val=(c[0]=='('?1:-1);
174 replace(l,r,val);
175 }
176 else if(s[0]=='S') reverse(l,r);
177 else if(s[0]=='I') invert(l,r);
178 else query(l,r);
179 }
180 return 0;
181 }
看了一圈代码好像都没有和我一样的码风/doge
终于依靠自己把revolve函数写出来并且调出来了!
#include<bits/stdc++.h>
//#define int long long
#define IL inline
#define R register int
#define fa(x) a[x].fa
#define ls(x) a[x].ch[0]
#define rs(x) a[x].ch[1]
using namespace std;
const int N=1e6+5,inf=0x3f3f3f3f;
IL int read() {
int f=1;
char ch;
while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
int res=ch-'0';
while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
return res*f;
}
int n,m,b[N],root,num;
struct hh {
int fa,rev,tag,mi,siz,ch[2],val,ad;
}a[N];
int chk(int x) {return x==rs(fa(x));}
void Rev(int x) {a[x].rev^=1;swap(ls(x),rs(x));}
void Add(int x,int d) {a[x].val+=d;a[x].ad+=d;a[x].mi+=d;}
void pushup(int x) {
a[x].siz=a[ls(x)].siz+a[rs(x)].siz+1;
a[x].mi=a[x].val;
if(ls(x)) a[x].mi=min(a[ls(x)].mi,a[x].mi);//注意判断左右孩子是否存在
if(rs(x)) a[x].mi=min(a[rs(x)].mi,a[x].mi);
}
void pushdown(int x) {
if(a[x].rev) {
if(ls(x)) Rev(ls(x));
if(rs(x)) Rev(rs(x));
a[x].rev=0;
}
if(a[x].ad) {
if(ls(x)) Add(ls(x),a[x].ad);
if(rs(x)) Add(rs(x),a[x].ad);
a[x].ad=0;
}
}
int find(int x,int k) {
pushdown(x);
int cnt=a[ls(x)].siz;
if(cnt==k-1) return x;
if(cnt>=k) return find(ls(x),k);
else return find(rs(x),k-cnt-1);
}
void rotate(int x) {
int y=fa(x),z=fa(y),k=chk(x),w=a[x].ch[k^1];
a[z].ch[chk(y)]=x;fa(x)=z;
a[y].ch[k]=w;fa(w)=y;
a[x].ch[k^1]=y;fa(y)=x;
pushup(y);pushup(x);
}
void splay(int x,int o) {
if(!o) root=x;
while(fa(x)!=o) {
int y=fa(x),z=fa(y);
if(z!=o) {
if(chk(x)^chk(y)) rotate(x);
else rotate(y);
}
rotate(x);
}
}
int build(int l,int r,int f) {
if(l>r) return 0;
int mid=l+r>>1,id=++num;
a[id].val=a[id].mi=b[mid];
a[id].siz=1;fa(id)=f;
a[id].rev=a[id].tag=0;
ls(id)=build(l,mid-1,id);
rs(id)=build(mid+1,r,id);
pushup(id);
return id;
}
void del(int k) {
int x=find(root,k-1),y=find(root,k+1);
splay(x,0);splay(y,x);
ls(y)=0;
pushup(y);pushup(x);
}
void insert(int k,int c) {
int x=find(root,k),y=find(root,k+1);
splay(x,0);splay(y,x);
int id=++num;
ls(y)=id;fa(id)=y;
a[id].val=a[id].mi=c;
a[id].siz=1;a[id].rev=a[id].tag=0;
pushup(y);pushup(x);
}
void reverse(int l,int r) {
int x=find(root,l-1),y=find(root,r+1);
splay(x,0);splay(y,x);
Rev(ls(y));
pushup(y);pushup(x);
}
void revolve(int l,int r,int t) {
t%=r-l+1;//优化
if(!t) return ;
int x=find(root,l-1),y=find(root,r-t+1);
splay(x,0);splay(y,x);
pushdown(x);pushdown(y);
int id1=ls(y);//id1即区间[l,r-t]的编号
ls(y)=0;//先删去这个区间
pushup(y);pushup(x);
x=find(root,r-t-a[id1].siz),y=find(root,r-a[id1].siz+1);//a[id1].siz写成r-t-l+1亦可
splay(x,0);splay(y,x);
int id2=ls(y);//id2即原区间[r-t+1,r-t]的编号,因为上面已经删去了编号为id1的区间[l,r-t],所以找区间编号的时候要减去id1区间的大小a[id1].siz
pushdown(x);pushdown(y);pushdown(id2);
while(rs(id2)) id2=rs(id2),pushdown(id2);//令id2为区间[r-t+1,r-t]最右端的点的编号
splay(id2,y);//将这个点移为y的左孩子
fa(id1)=id2;rs(id2)=id1;//把编号为id1的区间[l,r-t]插到区间[r-t+1,r-t]后面,即成为id2的右孩子
pushup(id2);pushup(y);pushup(x);
}
void mi(int l,int r) {
int x=find(root,l-1),y=find(root,r+1);
splay(x,0);splay(y,x);
printf("%d\n",a[ls(y)].mi);
}
void add(int l,int r,int d) {
int x=find(root,l-1),y=find(root,r+1);
splay(x,0);splay(y,x);
Add(ls(y),d);
pushup(y);pushup(x);
}
int main() {
n=read();
b[1]=-inf;b[n+2]=inf;
for(R i=2;i<=n+1;++i) b[i]=read();
root=build(1,n+2,0);
m=read();
while(m--) {
char s[8];scanf("%s",s);
int x=read()+1;
if(s[0]=='D') del(x);
else if(s[0]=='I') {
int p=read();
insert(x,p);
}
else {
int y=read()+1;
if(s[0]=='R') {
if(s[3]=='E') reverse(x,y);
else {
int t=read();
revolve(x,y,t);
}
}
else if(s[0]=='M') mi(x,y);
else {
int d=read();
add(x,y,d);
}
}
}
return 0;
}