POJ-2796 & 2019南昌邀请赛网络赛 I. 区间最大min*sum
http://poj.org/problem?id=2796
https://nanti.jisuanke.com/t/38228
背景
给定一个序列,对于任意区间,min表示区间中最小的数,sum表示区间和,求使得min*sum最大的区间或区间值。
POJ-2796中,序列的值非负,而在网络赛I题中,序列存在负值。
解法分析
直觉上,后者是前者的拓展,我们先考虑序列非负的情况。
非负情况
设序列存储在数组a中。当我们考虑数值a[i]作为区间最小值时,显然我们应该向i的左右两侧扩展并终止于遇到更小的值或者数组越界之前,这样得到的区间可以保证a[i]为最小值且区间和最大。
但是每次都如此求解,复杂度为$O\left ( n^2 \right )$。因此我们尝试用单调栈将其降低到$O\left ( n \right )$。
从我们刚刚的分析中可以看出,最小值是求解区间的关键,同时由于我们是顺序对数组进行处理(以从左往右为例),我们可以利用单调栈,在当遇到a[i]时,很快地找到左侧最近的比它更小的数值(具体而言,构建一个约束为单调递增的栈,当遇到a[i]时,逐个pop掉比它大的数),那么我们可以进一步拓展,借助单调栈维护以a[i]为右端点且为最小值的区间和,我们将栈中节点表示为Node{min,sum}(原单调栈中只存储数值,现在我们要额外存储这个区间的区间和),也就是上文中i向左侧扩展的区间情况,我们将被pop的node.sum求和再加上a[i]就得到了node[i]的左侧区间和了。但是还有i右侧的情况呢,我们先把node[i]压入栈继续处理之后的数字,注意到在随后的处理中当且仅当遇到了第一个比a[i]要小的数字时,node[i]会被pop出来,那么此轮pop中比node[i]先pop出来的若干Node合并在一起显然就是i右侧的区间了,两者进一步合并就得到了以a[i]为最小值,min*sum最大的区间了。需要额外注意的是,一个Node可能右侧没有比它更小的数值,那么在算法最后,需要将栈逐个pop出来做如上操作。
此外POJ-2796还需要计算区间端点位置(如有多个可选区间,任意输出),那么在Node中额外记录一下即可,不做过多说明。
由于每一个元素最多入栈一次,一旦被pop就不会再被查询到,显然复杂度为$O\left ( n \right )$。
代码的main函数,为了迁移到下文的题目中,做了额外修改,本节可以忽略。
#include<cstdio> #include<cstdlib> #include<cstring> #include<string> #include<algorithm> #include<iostream> #include<queue> #include<map> #include<cmath> #include<set> #include<stack> #define LL long long using namespace std; const int N = 100005; LL n, m; LL a[N]; struct node { LL minv; LL sum; LL lef; node(LL mv,LL s,LL l) { minv=mv,sum=s,lef=l; } }; //10 4 8 3 2 6 8 4 9 3 7 int al,ar; LL glo_ans; void cal(int l,int r) { stack<node> s; LL pop_sum,pop_lef; for(int i=l; i<=r; i++) { pop_sum=0; pop_lef=i; while(!s.empty()&&a[i]<=s.top().minv) { node p=s.top(); s.pop(); pop_lef=p.lef; pop_sum+=p.sum; if(pop_sum*p.minv>glo_ans) glo_ans=pop_sum*p.minv,al=p.lef,ar=i-1; //ans=max(ans,pop_sum*p.minv); } s.push(node(a[i],a[i]+pop_sum,pop_lef)); } pop_sum=0; while(!s.empty()) { node p=s.top(); s.pop(); pop_sum+=p.sum; if(pop_sum*p.minv>glo_ans) glo_ans=pop_sum*p.minv,al=p.lef,ar=r; } } int main() { //freopen("in.txt","r",stdin); //freopen("out.txt","w",stdout); //cin.sync_with_stdio(false); int n; while(scanf("%d",&n)!=EOF) { glo_ans=0,al=ar=1; for(int i=1; i<=n; i++) scanf("%lld",&a[i]); int pl=-1,pr=0; for(int i=1; i<=n; i++) { if(a[i]>0) { if(pl==-1) pl=i; pr=i; if(i==n||a[i+1]<=0) { cal(pl,pr); } } else pl=-1; } //cal(1,n); LL tx=0,mx=a[al]; for(int i=al;i<=ar;i++) tx+=a[i],mx=min(mx,a[i]); printf("%lld\n",glo_ans); //cout<<glo_ans<<endl; printf("%d %d\n",al,ar); //cout<<al<<' '<<ar<<endl; } return 0; }
存在负数的情况
存在负数的序列则不能简单地认为可以向左右随意扩张,因为sum不一定随着区间扩张而增长。但经过分析,我们可以发现,无论序列中的数值是什么,最大的min*sum一定是一个非负数,那么只可能有正数乘以正数或负数乘以负数的情况(答案为零的情况无需额外求解,初始值设置为0即可)。
前者,我们可以将序列分成若干非负子区间,用上文的方法求解。
后者,对于任意负数,随着区间的扩张,最小值只会减小不会增大,这样我们只要枚举每一个负数并假设a[i]为最小值,找到区间和最小的区间,求出min*sum,在枚举中保留最大答案即可(与正数的情况不同,我们扩张数组不必担心会使最小值变大,只用考虑区间和的问题),不必担心找到的区间具有更小的最小值,因为我们枚举了每一个负数,那么更小的最小值必然也会被枚举,不会漏掉答案。那么如何找到区间和最小的区间呢?我们可以计算数组的前缀和以及后缀和,i左侧后缀和最小以及i右侧前缀和最小的位置构成的区间既是区间和最小区间,为了快速索引,可以再引入两个数组用于记录i左侧后缀/右侧前缀和最小的位置。
#include <iostream> #include <vector> #include <stack> #define LL long long using namespace std; const int N=500005; LL inf=500000000010; LL pre[N],las[N]; LL idp[N],idl[N]; LL a[N]; int n; struct node { LL minv; LL sum; node(LL mv,LL s) { minv=mv,sum=s; } }; LL glo_ans; void cal(int l,int r) { //cout<<l<<' '<<r<<endl; stack<node> s; LL pop_sum; for(int i=l; i<=r; i++) { pop_sum=0; while(!s.empty()&&a[i]<=s.top().minv) { node p=s.top(); s.pop(); pop_sum+=p.sum; if(pop_sum*p.minv>glo_ans) glo_ans=pop_sum*p.minv; //ans=max(ans,pop_sum*p.minv); } s.push(node(a[i],a[i]+pop_sum)); } pop_sum=0; while(!s.empty()) { node p=s.top(); s.pop(); pop_sum+=p.sum; if(pop_sum*p.minv>glo_ans) glo_ans=pop_sum*p.minv; } } int main() { while(scanf("%d",&n)!=EOF) { LL ans = 0; glo_ans=-inf; for(int i=1; i<=n; i++) { scanf("%lld",&a[i]); idp[i]=idl[i]=i; glo_ans=max(glo_ans,a[i]); } int pl=-1,pr=0; for(int i=1; i<=n; i++) { if(a[i]>0) { if(pl==-1) pl=i; pr=i; if(i==n||a[i+1]<=0) { cal(pl,pr); } } else pl=-1; } fill(las,las+n+1,0); fill(pre,pre+n+1,0); for(int i=1; i<=n; i++) pre[i]=pre[i-1]+a[i]; for(int i=n; i>=1; i--) las[i]=las[i+1]+a[i]; for(int i=2; i<=n; i++) if(las[i]>las[idl[i-1]]) idl[i]=idl[i-1]; for(int i=n-1; i>=1; i--) if(pre[i]>pre[idp[i+1]]) idp[i]=idp[i+1]; for(int i=1; i<=n; i++) { if(a[i]<0) { LL lef=las[idl[i]]-las[i+1]; LL rig=pre[idp[i]]-pre[i-1]; LL sm=lef+rig-a[i]; ans=max(ans,sm*a[i]); } } printf("%lld\n",max(ans,glo_ans)); } return 0; }