HYSBZ 1500 维修数列(伸展树模板)
题意:
题解:典型伸展树的题,比较全面。
我理解的伸展树:
1 伸展操作:就是旋转,因为我们只需保证二叉树中序遍历的结果不变,所以我们可以旋转来保持树的平衡,且旋转有左旋与右旋。通过这种方式保证不会让树一直退化从而超时。虽然一次旋转的代价比较高,但是可以证明:每次操作都旋转(关键),则时间复杂度为O(n*log2 n)
2 更新:每个节点都可以存一些信息,并模拟线段树进行区间操作。父节点的信息是两个孩子节点加当前父节点的信息的总和。因为是可旋转的搜索二叉树,所以每次处理都需要注意上更新或下更新
3 注意:一般需要先开两个哨兵节点,一个作为开头,有个作为结尾,这样可以避免一些边界的讨论问题
#include<queue> #include<cstdio> #include<cstring> #include<algorithm> using namespace std; #define dir(a,b) (a>>b) #define ssplay(rt,x) (splay[rt].chd[x]) typedef long long ll; const int Max=6e5+7; const int Inf=1<<28; int num[Max],memp[Max],tot2,tot,root;//存值 内存池 内存池的值 总值 根节点 struct node { int file,sam;//翻转 是否修改 int chd[2],fat; int sizee,sum,val;//总个数 总大小 值 int lmax,rmax,mmax;//区间合并三变量 } splay[Max]; void Treaval(int rt) { if(rt) { Treaval(splay[rt].chd[0]); printf("rt=%2d lchd=%2d rchd=%2d sum=%2d size=%2d val=%2d lmax=%2d rmax=%2d mmax=%2d\n",rt,splay[rt].chd[0],splay[rt].chd[1], splay[rt].sum,splay[rt].sizee,splay[rt].val,splay[rt].lmax,splay[rt].rmax,splay[rt].mmax); Treaval(splay[rt].chd[1]); } return; } void debug() { printf("root=%d\n",root); Treaval(root); return; } inline void NewNode(int &rt,int fa,int va)//建立新节点 { if(tot2)//删除后的内存池可以再次利用 rt=memp[tot2--]; else rt=++tot; splay[rt].file=splay[rt].sam=0; splay[rt].chd[0]=splay[rt].chd[1]=0; splay[rt].val=splay[rt].lmax=splay[rt].rmax=splay[rt].mmax=splay[rt].sum=va; splay[rt].fat=fa; return; } inline int nmax(int a,int b) { return a>b?a:b; } inline void PushUp(int rt)//上更新(类似区间合并) { int lchd=splay[rt].chd[0],rchd=splay[rt].chd[1]; splay[rt].sizee=splay[lchd].sizee+splay[rchd].sizee+1; splay[rt].sum=splay[lchd].sum+splay[rchd].sum+splay[rt].val; splay[rt].mmax=nmax(nmax(splay[rt].val,splay[lchd].mmax),splay[rchd].mmax);//处理区间最大值 if(splay[rt].mmax>0) splay[rt].mmax=nmax(splay[rt].mmax,nmax(splay[lchd].rmax,0)+nmax(splay[rchd].lmax,0)+splay[rt].val);//val为关键 splay[rt].lmax=nmax(splay[lchd].lmax,splay[lchd].sum+splay[rt].val); splay[rt].lmax=nmax(splay[rt].lmax,splay[lchd].sum+splay[rt].val+splay[rchd].lmax); splay[rt].rmax=nmax(splay[rchd].rmax,splay[rchd].sum+splay[rt].val); splay[rt].rmax=nmax(splay[rt].rmax,splay[rchd].sum+splay[rt].val+splay[lchd].rmax); return; } inline void Swap(int &a,int &b) { int t=a; a=b; b=t; return; } inline void fson(int rt)//翻转 { if(!rt) return; Swap(splay[rt].chd[0],splay[rt].chd[1]);//孩子交换就好 Swap(splay[rt].lmax,splay[rt].rmax);//此位置的左右max需交换 splay[rt].file^=1;//此处修改与否只与父节点flie有关,与此处的file无关 } inline void sson(int rt,int va)//修改成va { if(!rt) return; splay[rt].val=va; splay[rt].sum=va*splay[rt].sizee; splay[rt].lmax=splay[rt].rmax=splay[rt].mmax=nmax(splay[rt].sum,va); splay[rt].sam=1; } inline void PushDown(int rt)//下更新(处理翻转与改变值) { if(splay[rt].file) { fson(splay[rt].chd[0]); fson(splay[rt].chd[1]); splay[rt].file=0; } if(splay[rt].sam) { sson(splay[rt].chd[0],splay[rt].val); sson(splay[rt].chd[1],splay[rt].val); splay[rt].sam=0; } return; } inline void Rotate(int rt,int kind)//**zig或者zag** { int y=splay[rt].fat; PushDown(y); PushDown(rt); splay[y].chd[kind^1]=splay[rt].chd[kind]; splay[ssplay(rt,kind)].fat=y; if(splay[y].fat)//不是一个zig后者zag splay[splay[y].fat].chd[ssplay(splay[y].fat,1)==y]=rt;//y父节点的(y的左右)孩子 splay[rt].fat=splay[y].fat; splay[rt].chd[kind]=y; splay[y].fat=rt; PushUp(y); return; } inline void Splay(int rt,int goal)//**关键的伸展操作(双旋)** { PushDown(rt); while(splay[rt].fat!=goal) { int y=splay[rt].fat; if(splay[y].fat==goal)//一次zig/zag { Rotate(rt,splay[y].chd[0]==rt);//rt是否为左孩子 } else { int kind=(splay[splay[y].fat].chd[0]==y?1:0);//y是否为左孩子 if(splay[y].chd[kind]==rt)//左孩子的右孩子或者右孩子的左孩子 { Rotate(rt,kind^1); Rotate(rt,kind); } else { Rotate(y,kind); Rotate(rt,kind); } } } PushUp(rt); if(!goal) root=rt;//更新根节点 return; } inline void Rotateto(int pos,int goal)//**得到第pos个数,并且进行伸展** { int rt=root; PushDown(rt); while(splay[ssplay(rt,0)].sizee!=pos) { if(splay[ssplay(rt,0)].sizee>pos) rt=splay[rt].chd[0]; else { pos-=(splay[ssplay(rt,0)].sizee+1); rt=splay[rt].chd[1]; } PushDown(rt); } Splay(rt,goal); return; } void Create(int sta,int enn,int &rt,int fa)//建树与添树 { if(sta>enn) return; int mid=dir(sta+enn,1); NewNode(rt,fa,num[mid]); Create(sta,mid-1,splay[rt].chd[0],rt); Create(mid+1,enn,splay[rt].chd[1],rt); PushUp(rt);//建树与添树时上更新 return; } void Init(int n)//初始化 { for(int i=0; i<n; ++i) scanf("%d",&num[i]); splay[0].lmax=splay[0].rmax=splay[0].mmax=-Inf;//可能全为负数 splay[0].chd[0]=splay[0].chd[1]=splay[0].fat=0;//建立哨兵,避免特判 splay[0].val=splay[0].sizee=splay[0].sum=tot2=tot=root=0; splay[0].sam=splay[0].file=0; NewNode(root,0,0);//建立两个哨兵 NewNode(splay[root].chd[1],root,0); Create(0,n-1,splay[ssplay(root,1)].chd[0],splay[root].chd[1]); PushUp(splay[root].chd[1]);//与建树一起的上更新 PushUp(root); return; } void Insert(int pos,int dig)//在pos与pos+1之间添加 { Rotateto(pos,0);//旋转pos位置的值成0的孩子 Rotateto(pos+1,root); Create(0,dig-1,splay[ssplay(root,1)].chd[0],splay[root].chd[1]);//建立 PushUp(splay[root].chd[1]);//更新 PushUp(root);//下面很多函数都是五行中仅仅修改第三行 return; } void Erase(int rt)//回收空间 { if(!rt) return; memp[++tot2]=rt; Erase(splay[rt].chd[0]); Erase(splay[rt].chd[1]); return; } void Delete(int pos,int dig)//删除pos后面dig个 { Rotateto(pos-1,0); Rotateto(pos+dig,root); Erase(splay[ssplay(root,1)].chd[0]);//关键位置 splay[ssplay(root,1)].chd[0]=0; PushUp(splay[root].chd[1]); PushUp(root); return; } void Make_same(int pos,int dig,int fix)//修改pos后面dig和为fix { Rotateto(pos-1,0); Rotateto(pos+dig,root); sson(splay[ssplay(root,1)].chd[0],fix); PushUp(splay[root].chd[1]); PushUp(root); return; } void Reverse(int pos,int dig)//翻转pos后面dig个 { Rotateto(pos-1,0); Rotateto(pos+dig,root); fson(splay[ssplay(root,1)].chd[0]); PushUp(splay[root].chd[1]); PushUp(root); return; } int GetSum(int pos,int dig)//计算pos后面dig个数的和 { Rotateto(pos-1,0); Rotateto(pos+dig,root); return splay[ssplay(ssplay(root,1),0)].sum; } int GetMaxsum(int pos,int dig)//区间最值 { Rotateto(pos-1,0); Rotateto(pos+dig,root); return splay[ssplay(ssplay(root,1),0)].mmax;//注意每次都需要splay一下,保证时间复杂度 } int main() { int n,m; char str[15]; int pos,dig,fix; while(~scanf("%d %d",&n,&m)) { Init(n); while(m--) { //debug(); scanf("%s",str); if(!strcmp(str,"INSERT"))//添加一段数 { scanf("%d %d",&pos,&dig); for(int i=0; i<dig; ++i) scanf("%d",&num[i]); Insert(pos,dig); } else if(!strcmp(str,"DELETE"))//删除一段数 { scanf("%d %d",&pos,&dig); Delete(pos,dig); } else if(!strcmp(str,"MAKE-SAME"))//修改 { scanf("%d %d %d",&pos,&dig,&fix); Make_same(pos,dig,fix); } else if(!strcmp(str,"REVERSE"))//翻转 { scanf("%d %d",&pos,&dig); Reverse(pos,dig); } else if(!strcmp(str,"GET-SUM"))//求和 { scanf("%d %d",&pos,&dig); printf("%d\n",GetSum(pos,dig)); } else//最大子序列 { printf("%d\n",GetMaxsum(1,splay[root].sizee-2));//注意有两个哨兵 } } } return 0; }