[2018.12.4]斜率优化(以[Apio2010]特别行动队为例)
斜率优化学会以后好像也不是那么难嘛。。。
以BZOJ1911为例
->在洛谷上查看
设\(s_i\)为前\(i\)个元素的前缀和,\(f_i\)为dp数组。
\(f_i=max\{f_j+a(s_i-s_j)^2+b(s_i-s_j)+c\}\)
\(f_i=max\{f_j+a(s_i^2-2s_is_j+s_j^2)+b(s_i-s_j)+c\}\)
\(f_i=max\{f_j+as_i^2-2as_is_j+as_j^2+bs_i-bs_j+c\}\)
\(=max\{f_j-2as_js_i+as_j^2-bs_j\}+as_i^2+bs_i+c\)
将\(max\)内部变为直线解析式。
令\(k=-2as_j\),\(b=f_j+as_j^2-bs_j\),自变量\(x\)。
原式\(=kx+b\)
由于\(-5\le a\le-1\),\(s_i\)单调递增
所以\(k=-2as_i>0\)且单调递增,即斜率递增。
我们的目标是找到当\(x=s_i\)时,\(y\)最大的一条直线。
如图是目前加入队列的直线,绿色直线为们查找的位置。
紫色部分的边缘便是每个位置的最大值,是一个凸壳。
由于斜率是递增的,所以蓝色直线必然在队尾。
发现我们找到的最大值在红线上,由于红线的斜率大于蓝线的斜率,且\(s_i\)递增(即接下来访问的位置都在绿线的右边),所以实际上蓝线代表的j已经不可能更新任何之后的状态了。
于是我们把它弹掉,然后用目前的队尾(就是之前最大值所在的线)更新答案,再把更新好的线加入队列。
但是还需要考虑一种情况。
就是一条直线上没有任何一段在凸壳上,即被原凸壳上的任意两条直线覆盖。
比如我们又加入了一根绿线。
我们发现红线被蓝线和绿线覆盖了,而这条刚刚被覆盖的直线总是加入新的直线之前的队首,并且当它被队首之后的那条直线和新加入的线覆盖,它也被任意两条直线覆盖。
想想为什么。这里不给出证明。
(其实就是懒得想)
那么如何判断两条线覆盖另一条线?
假设\(l_1:k_1x+b_1\),\(l_2:k_2x+b_2\),\(l_3:k_3x+b_3\),并且\(l_1\),\(l_2\)覆盖\(l_3\),\(\color{red}k_1>k_2\),\(\color{red}k_1>k_3\)。
我们可以找到\(l_1\),\(l_2\)的交点,过这个交点作\(l_3\)的平行线\(l_4\),显然\(l_4\)与当前凸壳相切与该点,如果\(l_4\)在\(l_3\)之上说明\(l_3\)不在凸壳上。
其实就是\(l_1\),\(l_2\)交点在\(l_3\)之上。
所以只需要比较这个交点的\(y\)坐标和\(l_3\)在\(x\)坐标相同时的\(y\)值即可。
具体如下:
先求\(l_1\),\(l_2\)交点。
\(k_1x+b_1=k_2x+b_2\)
解得\(x=\frac{b_2-b_1}{k_1-k_2}\),此时\(k_1x+b_1=k_2x+b_2=k_1\frac{b_2-b_1}{k_1-k_2}+b_1\)
\(l_3\)在\(x\)相等时的\(y=k_3\frac{b_2-b_1}{k_1-k_2}+b_3\)
如果交点在\(l_3\)之上,
\(k_1\frac{b_2-b_1}{k_1-k_2}+b_1\ge k_3\frac{b_2-b_1}{k_1-k_2}+b_3\)
\(\frac{b_2-b_1}{k_1-k_2}+\frac{b_1}{k_1}\ge \frac{k_3}{k_1}\frac{b_2-b_1}{k_1-k_2}+\frac{b_3}{k_1}\)
\(\frac{1}{k_3}\frac{b_2-b_1}{k_1-k_2}+\frac{b_1}{k_1k_3}\ge \frac{1}{k_1}\frac{b_2-b_1}{k_1-k_2}+\frac{b_3}{k_1k_3}\)
\(\frac{1}{k_3}\frac{b_2-b_1}{k_1-k_2}-\frac{1}{k_1}\frac{b_2-b_1}{k_1-k_2}\ge \frac{b_3}{k_1k_3}-\frac{b_1}{k_1k_3}\)
\(\frac{k_1-k_3}{k_1k_3}\frac{b_2-b_1}{k_1-k_2}\ge \frac{b_3-b_1}{k_1k_3}\)
\(\frac{b_2-b_1}{k_1-k_2}\ge \frac{b_3-b_1}{k_1-k_3}\)
因为\(k_1>k_2\),\(k_1>k_3\)
\((b_2-b_1)(k_1-k_3)\ge (b_3-b_1)(k_1-k_2)\)
所以满足上式时,\(l_3\)不在凸壳中。
于是就可以写代码了。
code:
#include<bits/stdc++.h>
using namespace std;
int n,v[1000010],s[1000010],q[1000010],l,r;
long long a,b,c,dp[1000010],lk[1000010],lb[1000010],tv;
void scan(int &x){
x=0;
char c=getchar();
while('0'>c||c>'9')c=getchar();
while('0'<=c&&c<='9')x=x*10+c-'0',c=getchar();
}
long long val(int i,int x){
return lk[i]*x+lb[i];
}
bool cov(int l1,int l2,int l3){//l1,l2 cover l3
return (lb[l2]-lb[l1])*(lk[l1]-lk[l3])>=(lb[l3]-lb[l1])*(lk[l1]-lk[l2]);
}
int main(){
scanf("%d%lld%lld%lld",&n,&a,&b,&c);
for(int i=1;i<=n;i++){
scan(v[i]);
s[i]=s[i-1]+v[i];
}
l=r=1;
q[1]=0;
dp[0]=0;
for(int i=1;i<=n;i++){
while(l<r&&val(q[l],s[i])<=val(q[l+1],s[i]))l++;
dp[i]=val(q[l],s[i])+a*s[i]*s[i]+b*s[i]+c;
lk[i]=-2*a*s[i];
lb[i]=dp[i]+a*s[i]*s[i]-b*s[i];
while(l<r&&cov(i,q[r-1],q[r]))r--;//不能写cov(q[r-1],i,q[r])
q[++r]=i;
}
printf("%lld",dp[n]);
return 0;
}