bzoj2288 生日礼物 (线段树)
我当然想选最大的子段和啦 但要选M次 那不一定就是最好的
所以提供一个反悔的选项,我选了一段以后,就把它们乘个-1,然后再选最好的(类似于网络流的思路)
这个可以用线段树来维护,记一个区间包含左端点/右端点的最大值、最小值(因为要乘-1),还有它们的端点位置
然后一直找 直到最大值<=0
1 #include<bits/stdc++.h> 2 #define pa pair<int,int> 3 #define CLR(a,x) memset(a,x,sizeof(a)) 4 #define mp make_pair 5 using namespace std; 6 typedef long long ll; 7 const int maxn=1e5+10,inf=0x3f3f3f3f; 8 9 inline ll rd(){ 10 ll x=0;char c=getchar();int neg=1; 11 while(c<'0'||c>'9'){if(c=='-') neg=-1;c=getchar();} 12 while(c>='0'&&c<='9') x=x*10+c-'0',c=getchar(); 13 return x*neg; 14 } 15 struct Pos{ 16 int v,l,r; 17 Pos (int a=0,int b=0,int c=0){v=a,l=b,r=c;} 18 }; 19 struct Node{ 20 Pos lmas,rmas,lmis,rmis,ma,mi,sum; 21 }tr[maxn<<2]; 22 23 bool laz[maxn<<2]; 24 int N,M,v[maxn]; 25 26 bool operator <(Pos a,Pos b){ 27 return a.v<b.v; 28 } 29 Pos operator +(Pos a,Pos b){ 30 return Pos(a.v+b.v,a.l,b.r); 31 } 32 Pos operator -(Pos x){return Pos(-x.v,x.l,x.r);} 33 34 Node operator +(Node a,Node b){ 35 Node x; 36 x.sum=a.sum+b.sum; 37 x.lmas=max(a.lmas,a.sum+b.lmas); 38 x.rmas=max(b.rmas,a.rmas+b.sum); 39 x.lmis=min(a.lmis,a.sum+b.lmis); 40 x.rmis=min(b.rmis,a.rmis+b.sum); 41 x.ma=max(a.rmas+b.lmas,max(a.ma,b.ma)); 42 x.mi=min(a.rmis+b.lmis,min(a.mi,b.mi)); 43 return x; 44 } 45 46 inline void build(int p,int l,int r){ 47 if(l==r){ 48 tr[p].sum=tr[p].lmas=tr[p].rmas=tr[p].lmis=tr[p].rmis=tr[p].ma=tr[p].mi=Pos(v[l],l,r); 49 }else{ 50 int m=l+r>>1; 51 build(p<<1,l,m),build(p<<1|1,m+1,r); 52 tr[p]=tr[p<<1]+tr[p<<1|1]; 53 } 54 } 55 56 void deal(int p){ 57 laz[p]^=1; 58 Pos t=tr[p].lmas; 59 tr[p].lmas=-tr[p].lmis;tr[p].lmis=-t; 60 t=tr[p].rmas; 61 tr[p].rmas=-tr[p].rmis;tr[p].rmis=-t; 62 t=tr[p].ma; 63 tr[p].ma=-tr[p].mi;tr[p].mi=-t; 64 tr[p].sum=-tr[p].sum; 65 } 66 67 inline void pushdown(int p){ 68 if(!laz[p]) return; 69 deal(p<<1),deal(p<<1|1); 70 laz[p]=0; 71 } 72 73 inline void rever(int p,int l,int r,int x,int y){ 74 if(x<=l&&r<=y) deal(p); 75 else{ 76 pushdown(p); 77 int m=l+r>>1; 78 if(x<=m) rever(p<<1,l,m,x,y); 79 if(y>=m+1) rever(p<<1|1,m+1,r,x,y); 80 tr[p]=tr[p<<1]+tr[p<<1|1]; 81 } 82 } 83 84 int main(){ 85 //freopen("","r",stdin); 86 int i,j,k; 87 N=rd(),M=rd(); 88 for(i=1;i<=N;i++) 89 v[i]=rd(); 90 int ans=0; 91 build(1,1,N); 92 for(i=1;i<=M;i++){ 93 Pos x=tr[1].ma;if(x.v<=0) break; 94 ans+=x.v; 95 rever(1,1,N,x.l,x.r); 96 } 97 printf("%d\n",ans); 98 return 0; 99 }