The Preliminary Contest for ICPC China Nanchang National Invitational I. Max answer (单调栈+线段树)
题目链接:https://nanti.jisuanke.com/t/38228
题目大意:一个区间的值等于该区间的和乘以区间的最小值。给出一个含有n个数的序列(序列的值有正有负),找到该序列的区间最大值。
样例输入:
5
1 2 3 4 5
样例输出:
36
解题思路:如果序列的值全部为正值的话,可以说很简单,用一个单调栈加前缀和就可以了直接a。但是区间中存在负值,这个问题就变得复杂多了。
首先我们可以用两次单调栈,在O(n)的时间内,对于每个a[i]找到一个最大区间[ l[i] , r[i] ],使得a[i]在这个区间内为最小值。
然后我们便可以枚举每一个a[i],如果a[i]大于0,我们要在区间[ l[i] , r[i] ]内找到一个子区间使得这个区间的和最大,因为这个区间的和越大就可以使得区间的值越大,因为a[i]是区间[ l[i] , r[i] ]的最小值,所以该区间所有值均为正,则子区间的最大和即为[ l[i] , r[i] ]全部数的和,用前缀和便可以求出来了。
但是如果a[i]<0的话,我们就要在[ l[i] , r[i] ]内找到一个子区间使得这个子区间的和最小,这样才能使得区间值最大,我们可以建立两颗线段树,分别维护前缀和的最大值和前缀和的最小值,再在区间[ l[i]-1 , i-1 ]用最大值线段树查找到一个点使得这个点的前缀和最大设最大前缀和为x,再在区间[ i , r[i] ]这个区间内用最小值线段树查找一个点使得这个点的前缀和最小设最小前缀和为y,这样y-x就为区间[ l[i] , r[i] ]内区间和最小的子区间和。
接下来枚举每一个a[i],求出区间值,更新ans就好了。
代码:
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int N=5e5+10; int n,m,q,l[N],r[N]; ll sum[N],a[N]; stack<ll> st; ll tr[2][N*4]; void pushup(int rt){ tr[0][rt]=max(tr[0][rt<<1],tr[0][rt<<1|1]); tr[1][rt]=min(tr[1][rt<<1],tr[1][rt<<1|1]); } void build(int l,int r,int rt){ if(l==r){ tr[0][rt]=tr[1][rt]=sum[l]; return; } int mid=(l+r)/2; build(l,mid,rt*2); build(mid+1,r,rt*2+1); pushup(rt); } ll ask0(int L,int R,int l,int r ,int rt){ //查找[L,R]区间内的最大值 if(L<=l&&R>=r) return tr[0][rt]; ll ans=-1e18; int mid=(l+r)/2; if(mid>=L) ans=max(ans,ask0(L,R,l,mid,rt*2)); if(mid<R) ans=max(ans,ask0(L,R,mid+1,r,rt*2+1)); return ans; } ll ask1(int L,int R,int l,int r ,int rt){ //查找[L,R]区间内的最小值 if(L<=l&&R>=r) return tr[1][rt]; ll ans=1e18; int mid=(l+r)/2; if(mid>=L) ans=min(ans,ask1(L,R,l,mid,rt*2)); if(mid<R) ans=min(ans,ask1(L,R,mid+1,r,rt*2+1)); return ans; } int main(){ cin>>n; for(int i=1;i<=n;i++){ cin>>a[i]; sum[i]=sum[i-1]+a[i]; } build(1,n,1); for(int i=1;i<=n;i++){ //单调栈找左边界 while(st.size()&&a[st.top()]>=a[i])st.pop(); if(st.size()) l[i]=st.top()+1; else l[i]=1; st.push(i); } while(st.size()) st.pop(); for(int i=n;i>=1;i--){ //单调栈找右边界 while(st.size()&&a[st.top()]>=a[i])st.pop(); if(st.size()) r[i]=st.top()-1; else r[i]=n; st.push(i); } ll ans=-1e18; for(int i=1;i<=n;i++){ //枚举每一个a[i] int L=l[i],R=r[i]; if(a[i]<0){ ll x=ask0(max(L-1,1),max(i-1,1),1,n,1); if(L==1&&x<0) x=0; //特判L==1的情况 ll y=ask1(i,R,1,n,1); ans=max(ans,(y-x)*a[i]); }else ans=max(ans,(sum[R]-sum[L-1])*a[i]); } cout<<ans<<endl; return 0; }