poj 2750 经典线段树
题意:给出一数组,数组首尾是可以相接的,要求求出最大连续序列值,并且不可以包括所有元素值。
遇到求这种最大连续序列值,看了网上大牛的思想才知道一般分为两种情况:1、不同时包括两端点的情况,这时直接求整个序列中的最大连续序列值;2、同时包括两端点元素,这时求整个序列中的最小连续序列值,然后用总和减去它,就是所要求的值。
具体怎么求法呢?
假设我们将整个序列分成两个连续的序列a,b;与整个序列设成A。假如我们知道a,b序列各个的从左向右最大连续序列值lmax,从右往左的最大连续序列值rmax,和从左向右最小连续序列值lmin,从右往左的最小连续序列值rmin,和每个序列的最大连续序列值nmax和最小连续序列值nmin,和最大元素值max,和最小元素值min;
那么,就有:
A.nmax=max(a.nmax,b.max,a.rmax+b.lmax,);
A.lmax=max(a.sum+b.lmax,a.max);
A.rmax=max(b.sum+a.rmax,b.rmax);
A.lmin=min(a.sum+b.lmin,a.lmin);
A.rmin=min(b.sum+a.rmin,b.rmin);
A.min=min(a.min,b.min) ;
A.max=max(a.max,b.max) ;
现在的A.nmax并不是最终结果,因为还没考虑最大连续序列值存在两端的情况,
A.nmax=max(A.nmax,A.sum-(a.rmin+b.lmin))
这是不是最终的答案呢? 不是,因为没有考虑全为正数和全为负数的情况,所以这种情况下
A.nmax=A.namx-A.min 或 A.nmax=A.max ;
#include<iostream> #include<cstdio> using namespace std; #define Max(a,b) a>b ? a:b #define Min(a,b) a<b ? a:b #define MAX_INT 100000 struct node { int sum; int nmax,nmin; int min,max; int lmax,lmin; int rmax,rmin; int left; int right; }; node interval[10*MAX_INT]; int data[MAX_INT]; int init(int k,int i) { interval[i].left=interval[i].right=k; interval[i].sum=data[k]; interval[i].nmax=interval[i].nmin=data[k]; interval[i].max=interval[i].min=data[k]; interval[i].lmax=interval[i].rmax=data[k]; interval[i].lmin=interval[i].rmin=data[k]; return 0; } int modify(int i) { int k=i<<1; interval[i].left=interval[k].left; interval[i].right=interval[k+1].right; interval[i].sum=interval[k].sum+interval[k+1].sum; interval[i].lmax = Max(interval[k].sum+interval[k+1].lmax , interval[k].lmax); interval[i].rmax= Max(interval[k+1].sum+interval[k].rmax , interval[k+1].rmax); interval[i].lmin= Min(interval[k].sum+interval[k+1].lmin,interval[k].lmin); interval[i].rmin= Min(interval[k+1].sum+interval[k].rmin,interval[k+1].rmin); interval[i].nmax= Max(interval[k].nmax , interval[k+1].nmax); interval[i].nmax= Max(interval[i].nmax , interval[k].rmax+interval[k+1].lmax); interval[i].nmin= Min(interval[k].nmin,interval[k+1].nmin); interval[i].nmin= Min(interval[i].nmin,interval[k].rmin+interval[k+1].lmin); interval[i].min=Min(interval[k].min , interval[k+1].min); interval[i].max=Max(interval[k].max , interval[k+1].max); return 0; } int create(int left,int right,int i) { int mid; if(left==right) { init(left,i); return 0; } mid=(left+right)>>1; create(left,mid,i<<1); create(mid+1,right,(i<<1)+1); modify(i); return 0; } int update(int root,int k,int w) { int i,mid; i=root; while(interval[i].left!=interval[i].right) { mid=(interval[i].left+interval[i].right)>>1; if(mid>=k) i=i<<1; else i=(i<<1)+1; } interval[i].lmax=interval[i].rmax=w; interval[i].lmin=interval[i].rmin=w; interval[i].nmax=interval[i].nmin=w; interval[i].sum=w; interval[i].min=interval[i].max=w; while(i!=root) { i=i>>1; modify(i); } if(interval[root].nmax<interval[root].sum-interval[root].nmin) interval[root].nmax= interval[root].sum-interval[root].nmin; if(interval[root].nmax==interval[root].sum) return interval[root].sum-interval[root].min; if(interval[root].nmin==interval[root].sum) return interval[root].max; return interval[root].nmax; } int main() { int i,k,m,n,w; while(scanf("%d",&n)!=EOF) { for(i=1;i<=n;i++) scanf("%d",&data[i]); create(1,n,1); scanf("%d",&m); for(i=0;i<m;i++) { scanf("%d%d",&k,&w); printf("%d\n",update(1,k,w)); } } return 0; }