Description
Input
Output
Sample Input
4
-1 10 -20
2 2 3 4
-1 10 -20
2 2 3 4
Sample Output
9
HINT
很容易想到方程是f[i]=max(f[j]+a*(s[i]-s[j])^2+b*(s[i]-s[j])+c),(0<=j<i);
但是这样n^2会超时,而且方程是1d1d的dp,所以可以考虑斜率优化
假设状态j比状态k优,则有
f[j] + a*(s[i]-s[j])^2 + b*(s[i]-s[j]) + c > f[k] + a*(s[i]-s[k])^2 + b*(s[i]-s[k]) + c
整理得
f[j] - f[k] + a*s[j]^2 - a*s[k]^2 + b*s[k] - b*s[j] - 2*a*s[i]*s[j] + 2*a*s[i]*s[k] > 0
因为2*a*s[i]是常数,所以移到右边去
f[j] - f[k] + a*s[j]^2 - a*s[k]^2 + b*s[k] - b*s[j] > 2*a*(s[j]-s[k])
令g[i]=f[i] + a*s[i]^2 - b*s[i]
则有g[j]-g[k]>2*a*(s[j]-s[k])
移项得(g[j]-g[k])/(s[j]-s[k])>2*a
显然令g[i]为纵坐标,s[i]为横坐标,右边就是j、k对应的点之间连线的斜率。
然后斜率优化搞掉
#include<cstdio> inline int read() { int x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x*f; } int n,a,b,c,x,t=1,w=1; long long s[1000020]; long long f[1000020]; long long g[1000020]; int q[1000020]; inline long long calc(int x,int y) {return a*(s[x]-s[y])*(s[x]-s[y])+b*(s[x]-s[y])+c;} inline double cal(int x,int y) { double a=(double)g[y]-g[x],b=s[y]-s[x]; return a/b; } int main() { n=read();a=read();b=read();c=read(); for (int i=1;i<=n;i++) { long long x=read(); s[i]=s[i-1]+x; } for (int i=1;i<=n;i++) { int save=2*a*s[i]; while (t<w && cal(q[t],q[t+1])>=save)t++; int now=q[t]; f[i]=(long long) f[now]+a*(s[i]-s[now])*(s[i]-s[now])+b*(s[i]-s[now])+c; g[i]=(long long) f[i]+a*s[i]*s[i]-b*s[i]; while (t+1<=w && cal(q[w],i)>cal(q[w-1],q[w])) w--; q[++w]=i; } printf("%lld",f[n]); }
——by zhber,转载请注明来源