#斜率优化,二分#CF631E Product Sum
题目
有一个数列 \(a\),其权值为 \(\sum_{i=1}^ni*a_i\),
现在可以任意选择其中一个数字扔到任意位置,使权值和最大。
\(n\leq 2*10^5,|a_i|\leq 10^6\)
分析
不妨先将原数列的权值算一遍,那么其实只是让改变的权值尽量大。
设选择的数字为 \(a_i\),选择的位置为 \(j\)。
当 \(j<i\) 时,表示将这个数放在第 \(j\) 个位置,同时 \([j,i)\) 的数往后移。
改变的权值就是 \(s_{i-1}-a_i*i+a_i*j-s_{j-1},j\in [1,i)\)
令 \(j'=j-1,i'=i-1\) 也就是求 \(s_{i'}-a_{i'+1}*i'+a_{i'+1}*j'-s_{j'}\)
当 \(j>i\) 时,表示将这个数放在第 \(j\) 个位置,同时 \((i,j]\) 的数往前移。
改变的权值就是 \(s_i-a_i*i+a_i*j-s_j,j\in (i,n]\)
综上所述,转化为两个式子:
\[\begin{align}\max_{0\leq j<i<n} s_i-a_{i+1}*i+a_{i+1}*j-s_j\\\max_{1\leq i<j\leq n} s_i-a_i*i+a_i*j-s_j\end{align}
\]
以下式为例,若 \(\exists k>j,a_i*k-s_k\geq a_i*j-s_j\),即 \(\frac{s_k-s_j}{k-j}\leq a_i\) 时,将 \(j\) 弹出。
考虑到求的是最大值,那么维护一个上凸壳,理应是斜率单调递减,不过由于倒序实际上具体维护时是单调递增的。
因为 \(a_i\) 不具有单调性,所以在凸壳上面二分即可。
代码
#include <cstdio>
#include <cctype>
#define fz(j,i) (s[i]-s[j])
#define fm(j,i) (i-j)
using namespace std;
const int N=200011; typedef long long lll;
lll a[N],s[N],sum,ans; int q[N],n,head,tail;
int iut(){
int ans=0,f=1; char c=getchar();
while (!isdigit(c)) f=(c=='-')?-f:f,c=getchar();
while (isdigit(c)) ans=ans*10+c-48,c=getchar();
return ans*f;
}
lll max(lll a,lll b){return a>b?a:b;}
int lower(lll x){
int l=head,r=tail;
while (l<r){
int mid=(l+r+1)>>1;
if (fz(q[mid-1],q[mid])<=x*fm(q[mid-1],q[mid])) l=mid;
else r=mid-1;
}
return q[l];
}
int upper(lll x){
int l=head,r=tail;
while (l<r){
int mid=(l+r)>>1;
if (fz(q[mid],q[mid+1])>=x*fm(q[mid],q[mid+1])) r=mid;
else l=mid+1;
}
return q[l];
}
int main(){
n=iut();
for (int i=1;i<=n;++i){
a[i]=iut(),s[i]=s[i-1]+a[i];
sum+=a[i]*i;
}
head=tail=1;
for (int i=1;i<n;++i){
int now=lower(a[i+1]); ans=max(ans,s[i]+(now-i)*a[i+1]-s[now]);
ans=max(ans,s[i]+(q[tail]-i)*a[i+1]-s[q[tail]]);
while (head<tail&&fz(q[tail-1],q[tail])*fm(q[tail],i)>=fz(q[tail],i)*fm(q[tail-1],q[tail])) --tail;
q[++tail]=i;
}
head=tail=1,q[1]=n;
for (int i=n-1;i;--i){
int now=upper(a[i]); ans=max(ans,s[i]+(now-i)*a[i]-s[now]);
while (head<tail&&fz(q[tail-1],q[tail])*fm(q[tail],i)<=fz(q[tail],i)*fm(q[tail-1],q[tail])) --tail;
q[++tail]=i;
}
return !printf("%lld",ans+sum);
}