[日常训练]最大M子段和
Description
在长度为\(n\)的序列\(x\)中选出\(m\)段互不相交的子段,求最大子段和.
Input
第一行两个整数\(n,m\).
第二行\(n\)个整数\(x_1-x_n\).
Output
一行一个整数表示最大值.
Sample Input
5 2
10 -1 10 -1 10
Sample Output
29
HINT
\(1\;\leq\;m\;\leq\;n\;\leq\;10^5,|x_i|\;\leq\;10^9\).
Solution
如果序列中正整数个数\(\leq{m}\),直接取最大的\(m\)个数的和即可.
将序列合并成若干个交错的正负段和,如\(-1,-2,3,4,-5,-6\)可以合并成\(-1-2,3+4,-5-6\).
记录所有正数段之和\(sum\),设正数段个数为\(k\),则需要把\(k\)段正数段合并成\(\leq{m}\)段.
对于每一段正数段有两种操作:
\(1.\)舍弃;
\(2.\)与相邻串合并.
可以用堆来实现:将所有段的绝对值扔入堆中,每次取最小的元素\(k\)(如果在当前序列中是左或右没有元素的负数段就跳过),\(sum-k\),将\(k\)与在当前序列中的左右元素合并即可.
#include<cmath>
#include<ctime>
#include<stack>
#include<queue>
#include<cstdio>
#include<vector>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#define N 100005
#define M -100000000000001LL
using namespace std;
typedef long long ll;
struct heap{
ll k;int n;
}h[N],tmp;
ll a[N],num[N],sum,ans;
int lst[N],nxt[N],n,m,cnt;
inline bool cmp(ll x,ll y){
return x>y;
}
inline bool chk(heap x,heap y){
if(x.k!=y.k) return x.k<y.k;
return x.n<y.n;
}
inline void swim(int i){
heap a=h[i];int j=i>>1;
while(j&&chk(a,h[j])){
h[i]=h[j];num[h[i].n]=i;i=j;j>>=1;
}
h[i]=a;num[h[i].n]=i;
}
inline void sink(int i){
heap a=h[i];int j=i<<1;
while(j<=cnt){
if(j<cnt&&chk(h[j+1],h[j])) ++j;
if(chk(a,h[j])) break;
h[i]=h[j];num[h[i].n]=i;i=j;j<<=1;
}
h[i]=a;num[h[i].n]=i;
}
inline void del(int u){
if(!u||u>n) return;
if(num[u]){
if(num[u]==cnt) --cnt;
else{
h[num[u]]=h[cnt--];sink(num[u]);swim(num[u]);
}
}
lst[nxt[u]]=lst[u];nxt[lst[u]]=nxt[u];
}
inline void init(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i){
scanf("%lld",&a[i]);
if(a[i]>=0LL){
++cnt;ans+=a[i];
}
}
/*正数个数<=m*/
if(cnt<=m){
sort(a+1,a+1+n,cmp);
for(int i=1;i<=m;++i)
sum+=a[i];
printf("%lld\n",sum);
return;
}
cnt=0;
for(int i=1;i<=n;++i){
if(sum*a[i]<0LL){
a[++cnt]=sum;sum=a[i];
}
else sum+=a[i];
}
a[++cnt]=sum;
n=cnt;cnt=0;
for(int i=1;i<=n;++i)
if(a[i]>0) ++cnt;
/*正数区间数<=m*/
if(cnt<=m){
printf("%lld\n",ans);return;
}
for(int i=1;i<=n;++i){
lst[i]=i-1;nxt[i]=i+1;
}
lst[n+1]=n;nxt[0]=1;
m=cnt-m;cnt=0;
for(int i=1;i<=n;++i){
h[++cnt].k=abs(a[i]);h[cnt].n=i;num[i]=cnt;swim(cnt);
}
while(m){
tmp=h[1];h[1]=h[cnt--];sink(1);
if(((!lst[tmp.n]||nxt[tmp.n]>n)&&a[tmp.n]<0)||!a[tmp.n]) continue;
ans-=tmp.k;a[tmp.n]+=a[lst[tmp.n]]+a[nxt[tmp.n]];
del(lst[tmp.n]);del(nxt[tmp.n]);
h[++cnt].k=abs(a[tmp.n]);h[cnt].n=tmp.n;swim(cnt);
--m;
}
printf("%lld\n",ans);
}
int main(){
freopen("sequence.in","r",stdin);
freopen("sequence.out","w",stdout);
init();
fclose(stdin);
fclose(stdout);
return 0;
}