斜率dp+cdq分治
写在前面
这个东西应该是一个非常重要的套路......所以我觉得必须写点什么记录一下,免得自己忘掉了
一直以来我的斜率dp都掌握的不算很好......也很少主动地在比赛里想到
写这个的契机是noi.ac在今天的考试中考了一道用这玩意儿的原题,被我搞出来了,于是决定总结一下(毕竟见得越来越多)
斜率dp
考虑一个常见的二次复杂度的dp:
$dp[i]=min(dp[j]+c(i)+g(j)+k(i)*f(j))$
其中$c,g,k,f$都是只和括号里的$i,j$有关的一元函数
一个很重要的思想是:看到n方dp的时候先想想能不能搞成这个样子的式子
如果搞出来了,这个东西一定可以在$O(n\log n)$的时间里面做出来——用cdq分治
怎么cdq
我们先给这四个函数名字:
$c(i)$是额外附加的只和$i$有关的常数
$f(i)=x(i)$作为横坐标
$g(i)=y(i)$作为纵坐标
$k(i)$是$i$这一点上的转移斜率
首先把所有点按照斜率排序
对于过程solve(l,r),这样操作:
首先,按照输入编号,把(l,r)分成两半,然后递归处理solve(l,mid)
返回的是一个按照横坐标排好序的原数组(dp值都知道了的)
我们把这一批东西做一个上凸包(或者下凸包,依照要求max还是min变化)
然后对于后面那一半点我们用前面这个凸包更新答案,一个指针遍历右边一半,另一个指针遍历左边的凸包,每次跳到最优位置为止
这之后,我们递归处理右半部分
最后我们再对这两半归并排序,按照横坐标
什么意义?
实际上这一波操作中,有三个中间被我们排了序的元素:输入编号,斜率,横坐标
实际上就是一个三维偏序:因为不像普通的斜率dp那样横纵坐标或者斜率有单调性,所以我们强行cdq
这样,在每一次更新后一半的时候,前一半都是做完的,而且已经横坐标单调了
朴素n方dp很好看出来,然后发现可以直接套到上面式子里面
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cassert>
#define head DEEP_DARK_FANTASY
#define ll long long
using namespace std;
inline int read(){
int re=0,flag=1;char ch=getchar();
while(!isdigit(ch)){
if(ch=='-') flag=-1;
ch=getchar();
}
while(isdigit(ch)) re=(re<<1)+(re<<3)+ch-'0',ch=getchar();
return re*flag;
}
int n;;
struct node{
ll w,h,c,x,y,k,dp,num;
}a[100010],tmp[100010],q[100010];
inline bool cmp1(node l,node r){
return l.k<r.k;
}
void solve(int l,int r){
if(l==r){
a[l].x=a[l].h;
a[l].y=a[l].dp-a[l].w+a[l].h*a[l].h;
return;
}
int mid=(l+r)>>1,tl,tr,head,tail,i;
tl=tr=0;
for(i=l;i<=r;i++){
if(a[i].num<=mid) tmp[++tl]=a[i];
else q[++tr]=a[i];
}
for(i=l;i<=mid;i++) a[i]=tmp[i-l+1];
for(i=mid+1;i<=r;i++) a[i]=q[i-mid];
solve(l,mid);
head=1,tail=0;
for(i=l;i<=mid;i++){
while(tail>head&&(q[tail].y-q[tail-1].y)*(a[i].x-q[tail].x)>=(q[tail].x-q[tail-1].x)*(a[i].y-q[tail].y)) tail--;
q[++tail]=a[i];
}
tl=1;
for(i=mid+1;i<=r;i++){
while(tl<tail&&a[i].k*(q[tl+1].x-q[tl].x)>=(q[tl+1].y-q[tl].y)) tl++;
a[i].dp=min(a[i].dp,-q[tl].x*a[i].k+q[tl].y+a[i].c);
}
solve(mid+1,r);
tl=l;tr=mid+1;head=0;
while(tl<=mid&&tr<=r){
if(a[tl].x==a[tr].x) tmp[++head]=((a[tl].y>a[tr].y)?a[tr++]:a[tl++]);
else tmp[++head]=((a[tl].x>a[tr].x)?a[tr++]:a[tl++]);;
}
while(tl<=mid) tmp[++head]=a[tl++];
while(tr<=r) tmp[++head]=a[tr++];
for(i=l;i<=r;i++) a[i]=tmp[i-l+1];
}
int main(){
n=read();int i;
for(i=1;i<=n;i++){
a[i].h=read();
a[i].dp=1e18;
a[i].num=i;
}
for(i=1;i<=n;i++){
a[i].w=read();
a[i].w+=a[i-1].w;
}
for(i=1;i<=n;i++){
a[i].c=a[i].h*a[i].h+a[i-1].w;
a[i].k=2ll*a[i].h;
}
a[1].dp=0;
sort(a+1,a+n+1,cmp1);
solve(1,n);
for(i=1;i<=n;i++)
if(a[i].num==n) printf("%lld\n",a[i].dp);
}