线段维护最大连续子序列和(4836: Can you answer on these queries III)
什么是最大子段和?顾名思义,在一个序列中,找到一段,将这一段元素相加使得结果最大.
其实可以通过递推来在O(n)的时间复杂度内求出结果,也就是判断一个数前面一个数能否对最大子段和作贡献.如果前面那个数>0,那么就将它加到这一个数里,每次正在操作的数进行取max.但是这样不能在线查询(虽然这个题也不需要在线),所以我们用线段树的方法来维护每个区间的最大子段和.
线段树中记录几个变量:ls记录从区间左端点开始向右延伸能得到的最大子段和,rs记录从右端点开始向左延伸能得到的最大子段和,ss记录区间的最大子段和(不管是从区间中哪个位置开始),sum记录区间和.
我们将正在合并的区间节点编号叫root,它的左端点为l,右端点为r
那在合并ls的时候只存在这样几种情况:
1. root左端点包含的最大子段的右端点延伸到了右儿子
1. root左端点包含的最大子段的右端点仍然在左儿子的范围内
合并rs也是同理.
然后考虑如何合并ss,root的包含的最大子段的左端点叫x,右端点叫y,那么只有这样几种情况:
1. x==l,mid+1<=y<r
1. x==l,y==r
1. l<x<=mid,mid+1<=y<r
1. l<x<=mid,y==r
1. l<x<y<=mid
1. mid+1<=x<y<=r
整理一下式子也就是这样:
ls[l,r]=max(ls[l,mid],sum[l,mid]+ls[mid+1,r]) ls[l,r]=max(ls[l,mid],sum[l,mid]+ls[mid+1,r])
rs[l,r]=max(rs[mid+1,r],sum[mid+1,r]+rs[l,mid]) rs[l,r]=max(rs[mid+1,r],sum[mid+1,r]+rs[l,mid])
ss[l,r]=max(ss[l,mid],ss[mid+1,r],ls[mid+1,r]+rs[l,mid+1]) ss[l,r]=max(ss[l,mid],ss[mid+1,r],ls[mid+1,r]+rs[l,mid+1])
那么我们直接对这些情况进行讨论,下面看代码注释
#include<cstdio> #include<algorithm> #include<cstring> using namespace std; const int nil=-0x3fffffff; const int maxn=(50000+10)<<2; int a[50000+10]; struct my{ int lmax,rmax,maxx,sum; }tree[maxn]; void get(int rt){ tree[rt].sum=tree[rt<<1].sum+tree[rt<<1|1].sum; tree[rt].lmax=max(tree[rt<<1].lmax,tree[rt<<1].sum+tree[rt<<1|1].lmax); tree[rt].rmax=max(tree[rt<<1|1].rmax,tree[rt<<1|1].sum+tree[rt<<1].rmax); tree[rt].maxx=max(max(tree[rt<<1].maxx,tree[rt<<1|1].maxx),tree[rt<<1|1].lmax+tree[rt<<1].rmax); } void build(int l,int r,int rt){ if(l==r){ tree[rt].sum=tree[rt].lmax=tree[rt].maxx=tree[rt].rmax=a[l]; return ; } int mid=(l+r)>>1; build(l,mid,rt<<1); build(mid+1,r,rt<<1|1); get(rt); } void change(int L,int l,int r,int rt,int c){ if(l==r){ tree[rt].sum=tree[rt].lmax=tree[rt].maxx=tree[rt].rmax=c; return ; } int mid=(l+r)>>1; if(L<=mid) change(L,l,mid,rt<<1,c); else change(L,mid+1,r,rt<<1|1,c); get(rt); } my ask(int l,int r,int rt,int L,int R){ if(l>=L&&r<=R){ return tree[rt]; } my a,b,c; a.lmax=a.maxx=a.rmax=a.sum=nil; b.lmax=b.maxx=b.rmax=b.sum=nil; c.sum=0; int mid=(l+r)>>1; if(L<=mid) { a=ask(l,mid,rt<<1,L,R); c.sum+=a.sum; } if(R>mid) { b=ask(mid+1,r,rt<<1|1,L,R); c.sum+=b.sum; } c.maxx=max(max(a.maxx,b.maxx),a.rmax+b.lmax); c.lmax=max(a.lmax,b.lmax+a.sum); c.rmax=max(b.rmax,b.sum+a.rmax); return c; } int main(){ int n,m; scanf("%d",&n); for(int i=1;i<=n;i++){ scanf("%d",&a[i]); } build(1,n,1); scanf("%d",&m); int op,l,r; for (int i=1;i<=m;i++){ scanf("%d%d%d",&op,&l,&r); if(op==1) printf("%d\n",ask(1,n,1,l,r).maxx); else change(l,1,n,1,r); } return 0; }