题解 花瓶
令 \(dp[i][j]\) 为最后一个区间是 \((j,i]\)(即当前区间右端点为 \(i\),上一个区间右端点为 \(j\))时,前缀 \([1,i]\) 的最优答案
- 二维意义下维护凸包
转移方程 \(dp[i][j] = \max\limits_{0\leqslant k<j} \{dp[j][k]+(s_i-s_j)(s_j-s_k)\}\),其中 \(s\) 为前缀和;初值 \(dp[i][0]=0\)(整段只有一个区间),答案为 \(\max\limits_{j} dp[n][j]\)
考虑对每个 \(j\) 维护一个凸包,用 \(j\) 去更新可行的 \(i\)
设 \(s_a>s_b\),则 \(k=a\) 比 \(k=b\) 更优当且仅当 \(s_i-s_j < \frac{dp[j][a]-dp[j][b]}{s_a-s_b}\)
若 \(a_i\) 可能为负,前缀和 \(s\) 就是无序的;但对固定的 \(j\),各个 \(i\) 的转移互不影响,可以打乱顺序处理。因此先按 \(s_i\) 排序(凸包上的点也按 \(s_k\) 排序后插入),就可以正常地做斜率优化。
- 关于凸包上相邻两点斜率的计算:若两点横坐标相同,则第二个点的纵坐标不低于第一个点时斜率取 \(+\infty\),更低时取 \(-\infty\);否则若两点纵坐标相同,斜率为 \(0\)
Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 5010
#define ll long long
#define fir first
#define sec second
#define make make_pair
#define int long long
// 2 MiB input buffer behind a faster getchar(): it is refilled from stdin with
// fread whenever it has been fully consumed. Bytes are returned as unsigned char
// values (0..255), so EOF (-1) is unambiguous even for a 0xFF byte and the result
// is always a valid argument for the <cctype> classification functions.
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)
// Reads one signed decimal integer via getchar().
// Characters before the first digit are skipped, and every '-' among them flips
// the sign. The first non-digit after the number is consumed as the terminator.
// Returns 0 if the input ends before any digit is seen, instead of looping
// forever on EOF.
inline int read() {
	int ans=0, f=1;
	int c=getchar();  // int, not char: EOF must stay distinguishable from data bytes
	while (c!=EOF && !(c>='0' && c<='9')) {
		if (c=='-') f=-f;
		c=getchar();
	}
	while (c>='0' && c<='9') {
		ans=ans*10+(c-'0');
		c=getchar();
	}
	return ans*f;
}
// Number of input values; read in main.
int n;
// Input values, stored 1-indexed (a[1..n]) by main; force::solve moves them down to a[0..n-1].
int a[N];
namespace force{
ll ans;
void solve() {
int lim=1<<n;
for (int i=0; i<n; ++i) a[i]=a[i+1];
for (int s=1; s<lim; ++s) if (s&1) {
ll sum=0, lst=0, now=0;
for (int i=n-1; ~i; --i) {
if (s&(1<<i)) {
now+=a[i];
sum+=now*lst;
lst=now;
now=0;
}
else now+=a[i];
}
ans=max(ans, sum);
}
printf("%lld\n", ans);
exit(0);
}
}
namespace task{
// Slope-optimised DP:
//   dp[i][j] = best score for prefix [1, i] when the last segment is (j, i];
//   dp[i][j] = max_{0 <= k < j} dp[j][k] + (sum[i] - sum[j]) * (sum[j] - sum[k]).
// For a fixed j every candidate k is a point (sum[k], dp[j][k]), and every target i
// asks for max(y - x * X) with x = sum[i] - sum[j]. This is answered with an upper
// convex hull built by increasing X and queried by decreasing x.
int sum[N], p[N], dp[N][N], q[N], ans;
// True infinity for vertical hull edges. A finite stand-in such as 0x3f3f3f3f
// (~1.06e9) is smaller than real slopes or query values once they exceed it. The
// hull would then keep dominated points and the query pointer would stop early.
const long double kInf=std::numeric_limits<long double>::infinity();
// Slope of the hull edge s -> t in the (sum, dp) plane; callers pass sum[s] <= sum[t].
// Same X: +inf when t is at least as high as s (t dominates s), -inf when t is lower.
// long double keeps more precision than double when dp values are large.
inline long double calc(int* dp, int s, int t) {
	if (sum[s]==sum[t]) return dp[s]>dp[t] ? -kInf : kInf;
	if (dp[s]==dp[t]) return 0.0L;
	return ((long double)dp[s]-dp[t])/((long double)sum[s]-sum[t]);
}
// Runs the DP over a[1..n], prints the best total and terminates the program.
void solve() {
	memset(dp, 128, sizeof(dp));  // every byte 0x80: a very negative "unset" value
	for (int i=1; i<=n; ++i) sum[i]=sum[i-1]+a[i], p[i]=i;
	for (int i=1; i<=n; ++i) dp[i][0]=0;  // a single segment [1, i] scores 0
	// p[0..n]: prefix indices ordered by prefix sum, largest first. Prefix sums need
	// not be monotone, so hull insertion and queries both follow this order.
	std::sort(p, p+n+1, [](int i, int j){return sum[i]>sum[j];});
	for (int j=1; j<=n; ++j) {
		int l=1, r=0;
		// Upper hull of (sum[k], dp[j][k]) for all k < j, inserted by increasing sum.
		// k = 0 always qualifies, so the hull is never empty. Each dp[j][k] used here
		// is final: dp[j][0] was set above, and dp[j][k] for k >= 1 was written while
		// processing j = k.
		for (int k=n; k>=0; --k) if (p[k]<j) {
			while (l<r && calc(dp[j], q[r-1], q[r])<calc(dp[j], q[r], p[k])) --r;
			q[++r]=p[k];
		}
		// Targets i > j by decreasing x = sum[i] - sum[j]. The optimal hull point only
		// moves right, so advance l while the next point is strictly better.
		for (int i=0; i<=n; ++i) if (p[i]>j) {
			const int x=sum[p[i]]-sum[j];
			while (l<r && calc(dp[j], q[l], q[l+1])>x) ++l;
			dp[p[i]][j]=dp[j][q[l]]+x*(sum[j]-sum[q[l]]);
		}
	}
	for (int i=0; i<=n; ++i) ans=std::max(ans, dp[n][i]);
	printf("%lld\n", ans);
	exit(0);
}
}
// Entry point: redirects I/O to d.in / d.out, reads n followed by a[1..n], then
// runs the DP solution, which prints the answer and exits.
// force::solve() is a brute-force alternative for small n.
signed main()
{
	freopen("d.in", "r", stdin);
	freopen("d.out", "w", stdout);
	n=read();
	for (int *cur=a+1; cur<=a+n; ++cur) *cur=read();
	// force::solve();
	task::solve();
	return 0;
}