poj 2750 经典线段树

题意:给出一数组,数组首尾是可以相接的,要求求出最大连续序列值,并且不可以包括所有元素值。

遇到求这种最大连续序列值,看了网上大牛的思想才知道一般分为两种情况:1、不同时包括两端点的情况,这时直接求整个序列中的最大连续序列值;2、同时包括两端点元素,这时求整个序列中的最小连续序列值,然后用总和减去它,就是所要求的值。

具体怎么求法呢?

假设我们将整个序列分成两个连续的序列a,b;与整个序列设成A。假如我们知道a,b序列各个的从左向右最大连续序列值lmax,从右往左的最大连续序列值rmax,和从左向右最小连续序列值lmin,从右往左的最小连续序列值rmin,和每个序列的最大连续序列值nmax和最小连续序列值nmin,和最大元素值max,和最小元素值min;

那么,就有:

A.nmax=max(a.nmax,b.max,a.rmax+b.lmax,);

A.lmax=max(a.sum+b.lmax,a.max);

A.rmax=max(b.sum+a.rmax,b.rmax);

A.lmin=min(a.sum+b.lmin,a.lmin);

A.rmin=min(b.sum+a.rmin,b.rmin);

A.min=min(a.min,b.min) ;

A.max=max(a.max,b.max) ;

现在的A.nmax并不是最终结果,因为还没考虑最大连续序列值存在两端的情况,

A.nmax=max(A.nmax,A.sum-(a.rmin+b.lmin))

这是不是最终的答案呢? 不是,因为没有考虑全为正数和全为负数的情况,所以这种情况下

A.nmax=A.namx-A.min 或 A.nmax=A.max ;

 

#include<iostream>
#include<cstdio>
using namespace std;
#define Max(a,b) a>b ? a:b 
#define Min(a,b) a<b ? a:b 
#define MAX_INT 100000
struct node
{
	int sum;
	int nmax,nmin;
	int min,max;
	int lmax,lmin;
	int rmax,rmin;
	int left;
	int right;
};
node interval[10*MAX_INT];
int data[MAX_INT];
int init(int k,int i)
{
	interval[i].left=interval[i].right=k;
	interval[i].sum=data[k];
	interval[i].nmax=interval[i].nmin=data[k];
	interval[i].max=interval[i].min=data[k];
	interval[i].lmax=interval[i].rmax=data[k];
	interval[i].lmin=interval[i].rmin=data[k];
	return 0;
}
int modify(int i)
{
	int k=i<<1;
	interval[i].left=interval[k].left;
	interval[i].right=interval[k+1].right;

	interval[i].sum=interval[k].sum+interval[k+1].sum;
    
	interval[i].lmax = Max(interval[k].sum+interval[k+1].lmax , interval[k].lmax);
	interval[i].rmax= Max(interval[k+1].sum+interval[k].rmax , interval[k+1].rmax);

	interval[i].lmin= Min(interval[k].sum+interval[k+1].lmin,interval[k].lmin);
	interval[i].rmin= Min(interval[k+1].sum+interval[k].rmin,interval[k+1].rmin);

	interval[i].nmax= Max(interval[k].nmax , interval[k+1].nmax);
    interval[i].nmax= Max(interval[i].nmax , interval[k].rmax+interval[k+1].lmax);
    
	interval[i].nmin= Min(interval[k].nmin,interval[k+1].nmin);
	interval[i].nmin= Min(interval[i].nmin,interval[k].rmin+interval[k+1].lmin);

	interval[i].min=Min(interval[k].min , interval[k+1].min);
	interval[i].max=Max(interval[k].max , interval[k+1].max);
	return 0;
}

int create(int left,int right,int i)
{
	int mid;
	if(left==right)
	{
		init(left,i);
		return 0;
	}
	mid=(left+right)>>1;
	create(left,mid,i<<1);
	create(mid+1,right,(i<<1)+1);
	modify(i);
	return 0;
}
int update(int root,int k,int w)
{
    int i,mid;
	i=root;
	while(interval[i].left!=interval[i].right)
	{
		mid=(interval[i].left+interval[i].right)>>1;
		if(mid>=k)
			i=i<<1;
		else
			i=(i<<1)+1;
	}
	interval[i].lmax=interval[i].rmax=w;
	interval[i].lmin=interval[i].rmin=w;
	interval[i].nmax=interval[i].nmin=w;
	interval[i].sum=w;
	interval[i].min=interval[i].max=w;
	while(i!=root)
	{
		i=i>>1;   modify(i);
	}
	if(interval[root].nmax<interval[root].sum-interval[root].nmin)
		interval[root].nmax= interval[root].sum-interval[root].nmin;
	if(interval[root].nmax==interval[root].sum)
		return interval[root].sum-interval[root].min;
	if(interval[root].nmin==interval[root].sum)
		return interval[root].max;
	return interval[root].nmax;
}
int main()
{
	int i,k,m,n,w;
	while(scanf("%d",&n)!=EOF)
	{
         for(i=1;i<=n;i++)
			 scanf("%d",&data[i]);
		 create(1,n,1);
		 scanf("%d",&m);
		 for(i=0;i<m;i++)
		 {
			 scanf("%d%d",&k,&w);
			 printf("%d\n",update(1,k,w));
		 }
	}
	return 0;
}

 

posted @ 2011-10-29 19:57  书山有路,学海无涯  阅读(1068)  评论(0编辑  收藏  举报