LG P2389 电脑班的裁员
Description
ZZY有独特的裁员技巧:每个同学都有一个考试得分$a_i(-1000 \leq a_i \leq 1000)$,在$n$个同学$(n \leq 500)$中选出不大于$k$段$(k \leq n)$相邻的同学留下,裁掉未被选中的同学,使剩下同学的得分和最大。要特别注意的是,这次考试答错要扣分【不要问我为什么】,所以得分有可能为负。
Solution
对于$n^3$复杂度:
设$f(i,j)$表示当前选到第$i$个数,选取的段数$\leq j$的最大价值,$s_i$表示前缀和,可以分类讨论:
- $i$不选,则$f(i,j)=f(i-1,j)$
- $i$选,则$f(i, j)=\max \left\{f(k, j-1)+s_{i}-s_{k}\right\}(k < i)$
对于$n^2$复杂度:
设$g(i,j-1)=f(k,j-1)-s_k$,在更新$f$的时候顺便更新$g$,可以做到$n^2$的复杂度
对于$n^2$复杂度:
还有别的思路:
设$f(i,j)$表示当前选到第$i$个数,选取的段数$\leq j$,且$i$不一定选择的最大价值,$g(i,j)$表示当前选到第$i$个数,选取的段数$\leq j$,且$i$一定选择的最大价值
更新$g$时,需要讨论$i-1$是否选择,更新$f$时,需要讨论$i$是否选择:
$$g(i, j)=\max \{g(i-1, j), f(i-1, j-1)\}+a_{i}$$
$$f(i, j)=\max \{g(i, j), f(i-1, j)\}$$
实现时可以滚动数组省去一维,空间复杂度$O(n)$
对于$n\log n$复杂度:
可以贪心
对于一段连续的符号相同的数,只有全部选择或全部不选择两种情况:
- 如果一段负数只选择了一半,就可以找到更大的价值(不选这段负数)
- 如果一段正数只选择了一半,也可以找到更大的价值(同时选择另一段)
所以可以将所有连续的、符号相同的数融合成一个数
设当前序列中正数的个数为$cnt$
- $cnt \geq k$时,选择最大的$k$个正数
- $cnt < k$时,需要做出一些放弃
要么选择一个负数,将两个正数连接在一起;要么放弃一个正数
不论哪一种选择,选择对数$x$进行操作,损失的价值都为$|x|$
使用优先队列可以实现
对于$n$复杂度:
贪心还可以改进
假设当前还需要合并$m$个数,那么:
- 大小小于第$m$大的必定要被合并
- 大小大于第$3m$大的必定不会被合并
- 每一轮合并至少合并了$\frac m3$个数
每次合并将三个数合并为一个,所以合并$m$个数最多会影响$3m$个数,假如每次影响的三个数大小都为小于第$m$大,就会合并$\frac m3$个数
总时间复杂度:$n+\frac{2}{3} n+\left(\frac{2}{3}\right)^{2} n+\ldots$
#include<algorithm> #include<iostream> #include<utility> #include<cstdio> #include<queue> #include<cmath> using namespace std; int n,k,l[1000050],r[1000050],l1[1000050],r1[1000050],tot,cnt; long long ans,a[1000050]; bool del[1000050],del1[1000050],in[1000050]; queue<long long>q; pair<long long,int>t[1000050],minn,maxx; inline int read() { int f=1,w=0; char ch=0; while(ch<'0'||ch>'9') { if(ch=='-') f=-1; ch=getchar(); } while(ch>='0'&&ch<='9') { w=(w<<1)+(w<<3)+ch-'0'; ch=getchar(); } return f*w; } void delet1(int x) { if(x) { int L=l1[x],R=r1[x]; l1[R]=L; l1[L]=R; del1[x]=true; } } void delet(int x) { if(x) { --tot; int L=l[x],R=r[x]; l[R]=L; r[L]=R; del[x]=true; delet1(x); } } void add(int x) { if(x&&pair<long long,int>(abs(a[x]),x)<=minn) { in[x]=true; q.push(x); } } void merge(int x) { if(del[x]) { return; } int L=l[x],R=r[x]; if(L&&abs(a[L])<abs(a[x])) { return; } if(R&&abs(a[R])<abs(a[x])) { return; } delet(L); delet(R); a[x]+=a[L]+a[R]; if(L&&R) { add(x); } else { delet(x); } } int main() { n=read(); k=read(); for(int i=1;i<=n;i++) { int x=read(); if(!x) { continue; } if(x<0&&!tot) { continue; } if(tot&&a[tot]<0==x<0) { a[tot]+=x; } else { a[++tot]=x; } } if(tot&&a[tot]<0) { --tot; } for(int i=0;i<=tot;i++) { l[i]=l1[i]=(i+tot)%(tot+1); r[i]=r1[i]=(i+1)%(tot+1); } while(tot>k*2-1) { cnt=0; for(int i=r1[0];i;i=r1[i]) { if(!del1[i]) { t[++cnt]=pair<long long,int>(abs(a[i]),i); } } int temp=(tot-(k*2-1))>>1; nth_element(t+1,t+min(cnt,temp),t+cnt+1); minn=t[min(cnt,temp)]; nth_element(t+1,t+min(cnt,3*temp),t+cnt+1); maxx=t[min(cnt,3*temp)]; for(int i=r1[0];i;i=r1[i]) { pair<long long,int> p(abs(a[i]),i); if(p>maxx) { delet1(i); } else { add(i); } } while(q.size()) { int u=q.front(); q.pop(); in[u]=false; merge(u); } } for(int i=r[0];i;i=r[i]) { if(a[i]>0) { ans+=a[i]; } } printf("%lld\n",ans); return 0; }