斜率优化dp学习

 http://acm.hdu.edu.cn/showproblem.php?pid=3507

 

解释:当且仅当 g[j,k] 成立时,j 对 dp[i] 的更新优于 k 对 dp[i] 的更新;

解释:对于第2点,因为单调队列维护的必须是相邻斜率单调上升的结构(下凸壳),所以插入新点之前要做此循环,弹出破坏斜率单调性的队尾;

   对于第3点,因为相邻斜率是单调上升的,所以只要条件持续成立,队首后面的决策就一直更优,队首就一直出队,直到条件不满足为止(就好像在 BST 里找前驱一样)

复制代码
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
typedef long long ll;
/// Reads the next integer from stdin with getchar().
/// Non-digit characters before the number are skipped; a '-' among them makes
/// the result negative. Reading stops at the first non-digit after the number.
/// If EOF is reached, whatever has been read so far (0 if nothing) is returned
/// instead of looping forever, which the previous char-based loop did.
inline ll read(){
    ll sum=0;
    bool neg=false;
    int ch=getchar();              // int, not char: it must be able to hold EOF
    while(ch!=EOF&&(ch<'0'||ch>'9')){
        if(ch=='-')
            neg=true;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        sum=sum*10+(ch-'0');
        ch=getchar();
    }
    return neg?-sum:sum;
}
/// Writes x in decimal to stdout via putchar(), without a trailing newline.
/// The digits are produced from the magnitude as unsigned long long, so
/// LLONG_MIN prints correctly (negating it as a signed value is undefined behaviour).
inline void write(ll x){
    unsigned long long mag=static_cast<unsigned long long>(x);
    if(x<0){
        putchar('-');
        mag=0ull-mag;              // modular negation: well-defined for unsigned
    }
    char digits[20];               // 2^64-1 has 20 decimal digits
    int len=0;
    do{
        digits[len++]=static_cast<char>('0'+mag%10);
        mag/=10;
    }while(mag!=0);
    while(len>0)
        putchar(digits[--len]);
}
// Shared state: a[i] is the prefix sum S[i] of the inputs, dp[i] the best cost
// for the first i items, que[] the monotone queue of decision points.
const int M=5e5+5;
ll dp[M],a[M];
int que[M],n,m;

// Value of taking decision j for state i: dp[j] + (S[i]-S[j])^2 + m.
ll ans(int i,int j){
    const ll seg=a[i]-a[j];
    return dp[j]+seg*seg+m;
}
// Slope numerator between decisions i and j: Y(i)-Y(j) with Y(t)=dp[t]+S[t]^2.
ll zi(int i,int j){
    const ll yi=dp[i]+a[i]*a[i];
    const ll yj=dp[j]+a[j]*a[j];
    return yi-yj;
}
// Slope denominator between decisions i and j: X(i)-X(j) with X(t)=S[t].
ll mu(int i,int j){
    const ll xi=a[i],xj=a[j];
    return xi-xj;
}
int main(){
    while(~scanf("%d%d",&n,&m)){
        for(int i=1;i<=n;i++){
            a[i]=read();
            a[i]+=a[i-1];
        }
        dp[0]=a[0]=0;
        que[1]=0;
        int head=1,tail=1;
        for(int i=1;i<=n;i++){
            while(head+1<=tail&&zi(que[head+1],que[head])<=2ll*a[i]*mu(que[head+1],que[head]))
                head++;
            dp[i]=ans(i,que[head]);
            while(head+1<=tail&&zi(i,que[tail])*mu(que[tail],que[tail-1])<=zi(que[tail],que[tail-1])*mu(i,que[tail]))
                tail--;
            que[++tail]=i;
        }
        write(dp[n]);
        putchar('\n');
    }
    return 0;
}
View Code
复制代码

 

https://www.lydsy.com/JudgeOnline/problem.php?id=1597

复制代码
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cstdio>
using namespace std;
typedef long long ll;
const int M=5e4+4;
// An item with two dimensions x and y (filtered and sorted in main).
struct node{
    ll x,y;
}nod[M];
// Orders items by x ascending, breaking ties by y ascending.
bool cmp(node a,node b){
    if(a.x!=b.x)
        return a.x<b.x;
    return a.y<b.y;
}
ll dp[M];    // dp[i]: best total for the first i filtered items
int que[M];  // monotone queue of decision points
// Slope numerator between decisions i and j: dp[i]-dp[j].
ll zi(int i,int j){
    const ll di=dp[i],dj=dp[j];
    return di-dj;
}
// Slope denominator between decisions i and j: y of the item that would start
// the next group after each decision (nod[i+1] vs nod[j+1]).
ll mu(int i,int j){
    const ll yi=nod[i+1].y,yj=nod[j+1].y;
    return yi-yj;
}
// Cost of grouping items j+1..i after decision j: dp[j] + nod[j+1].y * nod[i].x.
ll ans(int j,int i){
    const ll groupCost=nod[j+1].y*nod[i].x;
    return dp[j]+groupCost;
}

int main(){
    int n;
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
        scanf("%lld%lld",&nod[i].x,&nod[i].y);
    sort(nod+1,nod+1+n,cmp);
    // Remove dominated items. After sorting by (x, y) ascending, the current item's x is
    // >= every kept item's x, so any kept item whose y is <= the current y fits inside it
    // and is popped. Afterwards x is strictly increasing and y strictly decreasing.
    // (The former "skip when y equals the top's y" shortcut dropped the wrong item: with
    //  equal y the current item has the larger-or-equal x, so it is the one to keep.)
    int tot=0;
    for(int i=1;i<=n;i++){
        while(tot&&nod[tot].y<=nod[i].y)
            tot--;
        nod[++tot]=nod[i];
    }
    int head=1,tail=1;
    que[1]=0;  // decision 0: empty prefix
    for(int i=1;i<=tot;i++){
        // Advance the front while the next decision is at least as cheap for x = nod[i].x.
        while(head+1<=tail&&zi(que[head+1],que[head])<=-nod[i].x*mu(que[head+1],que[head]))
            head++;
        dp[i]=ans(que[head],i);
        // Drop back decisions that can no longer be optimal (lower-hull maintenance;
        // mu() is negative here because y decreases, hence the mirrored cross-product).
        while(head+1<=tail&&zi(que[tail],que[tail-1])*mu(i,que[tail])<=zi(i,que[tail])*mu(que[tail],que[tail-1]))
            tail--;
        que[++tail]=i;
    }
    printf("%lld\n",dp[tot]);
    return 0;
}
View Code
复制代码

 

https://www.lydsy.com/JudgeOnline/problem.php?id=1010

复制代码
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
typedef long long ll;
const int M=5e4+4;
ll dp[M],sum[M];  // dp[i]: best cost for the first i items; sum[]: padded prefix sums
ll C;             // constant from input, incremented by one in main before use
int que[M];       // monotone queue of decision points
// Slope numerator between decisions j and k: Y(j)-Y(k) with Y(t)=dp[t]+sum[t]^2.
ll zi(int j,int k){
    const ll yj=dp[j]+sum[j]*sum[j];
    const ll yk=dp[k]+sum[k]*sum[k];
    return yj-yk;
}
// Slope denominator between decisions j and k: sum[j]-sum[k].
ll mu(int j,int k){
    const ll xj=sum[j],xk=sum[k];
    return xj-xk;
}
// Value of decision j for state i: dp[j] + (sum[i]-sum[j]-C)^2.
ll ans(int j,int i){
    const ll d=sum[i]-sum[j]-C;
    return dp[j]+d*d;
}
int main(){
    int n;
    //freopen("testdata.in","r",stdin);
    scanf("%d%lld",&n,&C);
    for(int i=1;i<=n;i++){
        scanf("%lld",&sum[i]);
        sum[i]+=sum[i-1];
    }
    for(int i=1;i<=n;i++)
        sum[i]+=i*1ll;
/*    for(int i=1;i<=n;i++)
        cout<<sum[i]<<" ";*/
    C++;
    memset(dp,0,sizeof(dp));
    int head=1,tail=1;
    for(int i=1;i<=n;i++){
        while(head+1<=tail&&zi(que[head+1],que[head])<=2ll*(sum[i]-C)*mu(que[head+1],que[head]))
            head++;
        dp[i]=ans(que[head],i);
        while(head+1<=tail&&mu(que[tail],que[tail-1])*zi(i,que[tail])<=mu(i,que[tail])*zi(que[tail],que[tail-1]))
            tail--;
        que[++tail]=i;
    }
/*    for(int i=1;i<=n;i++)
        cout<<dp[i]<<" ";*/
        printf("%lld\n",dp[n]);
    return 0;
}
View Code
复制代码

 

$f[i]=\min_{0\le j<i}\left(f[j]+X_i\sum_{k=j+1}^{i}P_k-\sum_{k=j+1}^{i}X_kP_k\right)+C_i$

复制代码
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std; 
typedef long long ll;
const int M=1e6+6;
// One input record: x is the sort key, p a weight, and c a fixed cost that
// ans() adds when a group ends at this record.
struct node{
    ll x,p,c;
}a[M];
// Sort records by x ascending.
bool cmp(node p,node q){
    return q.x>p.x;
}
ll dp[M],sump[M],sumxp[M];  // dp[i]: best total for records 1..i; prefix sums of p and x*p
int que[M];                 // monotone queue of decision points
// Slope numerator between decisions j and k: Y(j)-Y(k) with Y(t)=dp[t]+sumxp[t].
ll zi(int j,int k){
    const ll yj=dp[j]+sumxp[j];
    const ll yk=dp[k]+sumxp[k];
    return yj-yk;
}
// Slope denominator between decisions j and k: sump[j]-sump[k].
ll mu(int j,int k){
    const ll xj=sump[j],xk=sump[k];
    return xj-xk;
}
// Value of decision j for state i: dp[j] + sum_{k=j+1..i} (x_i - x_k)*p_k + c_i.
ll ans(int j,int i){
    const ll moved=a[i].x*(sump[i]-sump[j])-(sumxp[i]-sumxp[j]);
    return dp[j]+moved+a[i].c;
}
int main(){
    int n;
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
        scanf("%lld%lld%lld",&a[i].x,&a[i].p,&a[i].c);
    sort(a+1,a+1+n,cmp);
    for(int i=1;i<=n;i++){
        sump[i]=sump[i-1]+a[i].p;
        sumxp[i]=sumxp[i-1]+a[i].x*a[i].p;
    }
    int head=1,tail=1;
    for(int i=1;i<=n;i++){
        while(head+1<=tail&&zi(que[head+1],que[head])<=a[i].x*mu(que[head+1],que[head]))
            head++;
        dp[i]=ans(que[head],i);
        while(head+1<=tail&&zi(i,que[tail])*mu(que[tail],que[tail-1])<=mu(i,que[tail])*zi(que[tail],que[tail-1]))
            tail--;
        que[++tail]=i;
    }
    printf("%lld\n",dp[n]);
    return 0;
}
View Code
复制代码

 以上都是维护下凸壳的(求最小值),解法大同小异,写出转移方程就好办

下面这题是维护上凸壳的(求最大值),注意区别

https://www.luogu.org/problemnew/show/P3628

复制代码
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cstdio>
using namespace std;
typedef long long ll;
const int M=1e6+5;
ll dp[M],sum[M];  // dp[i]: best (maximum) total for the first i items; sum[]: prefix sums
int que[M];       // monotone queue of decision points
ll a,b,c;         // coefficients of the per-group value a*X^2 + b*X + c
// Slope numerator between decisions j and k: Y(j)-Y(k) with Y(t)=dp[t]+a*sum[t]^2-b*sum[t].
ll zi(int j,int k){
    const ll yj=dp[j]+a*sum[j]*sum[j]-b*sum[j];
    const ll yk=dp[k]+a*sum[k]*sum[k]-b*sum[k];
    return yj-yk;
}
// Slope denominator between decisions j and k: sum[j]-sum[k].
ll mu(int j,int k){
    const ll xj=sum[j],xk=sum[k];
    return xj-xk;
}
// Value of decision j for state i: dp[j] + a*X^2 + b*X + c with X = sum[i]-sum[j].
ll ans(int j,int i){
    const ll X=sum[i]-sum[j];
    return dp[j]+a*X*X+b*X+c;
}
int main(){
    int n;
    scanf("%d%lld%lld%lld",&n,&a,&b,&c);
    for(int i=1;i<=n;i++)
        scanf("%lld",&sum[i]),sum[i]+=sum[i-1];
    int head=1,tail=1;
    for(int i=1;i<=n;i++){
        while(head+1<=tail&&zi(que[head+1],que[head])>=2ll*a*sum[i]*mu(que[head+1],que[head]))
            head++;
        dp[i]=ans(que[head],i);
        while(head+1<=tail&&zi(que[tail],i)*mu(que[tail-1],que[tail])>=mu(que[tail],i)*zi(que[tail-1],que[tail]))
            tail--;
        que[++tail]=i;
    }
    printf("%lld\n",dp[n]);
    return 0;
}
View Code
复制代码

 带次数限制的题目(需要二维状态)!!!

http://acm.hdu.edu.cn/showproblem.php?pid=2829

斜率DP

设dp[i][j]表示前i点,炸掉j条边的最小值。j<i

dp[i][j]=min{dp[k][j-1]+cost[k+1][i]}

又由 cost 的定义可得 cost[1][i]=cost[1][k]+cost[k+1][i]+sum[k]*(sum[i]-sum[k])

cost[k+1][i]=cost[1][i]-cost[1][k]-sum[k]*(sum[i]-sum[k])

代入DP方程

可以得出 y=dp[k][j-1]-cost[1][k]+sum[k]^2

x=sum[k].

斜率为 sum[i]

复制代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
/// Reads the next integer from stdin with getchar().
/// Non-digit characters before the number are skipped; a '-' among them makes
/// the result negative. Reading stops at the first non-digit after the number.
/// If EOF is reached, whatever has been read so far (0 if nothing) is returned
/// instead of looping forever, which the previous char-based loop did.
inline ll read(){
    ll sum=0;
    bool neg=false;
    int ch=getchar();              // int, not char: it must be able to hold EOF
    while(ch!=EOF&&(ch<'0'||ch>'9')){
        if(ch=='-')
            neg=true;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        sum=sum*10+(ch-'0');
        ch=getchar();
    }
    return neg?-sum:sum;
}
/// Writes x in decimal to stdout via putchar(), without a trailing newline.
/// The digits are produced from the magnitude as unsigned long long, so
/// LLONG_MIN prints correctly (negating it as a signed value is undefined behaviour).
inline void write(ll x){
    unsigned long long mag=static_cast<unsigned long long>(x);
    if(x<0){
        putchar('-');
        mag=0ull-mag;              // modular negation: well-defined for unsigned
    }
    char digits[20];               // 2^64-1 has 20 decimal digits
    int len=0;
    do{
        digits[len++]=static_cast<char>('0'+mag%10);
        mag/=10;
    }while(mag!=0);
    while(len>0)
        putchar(digits[--len]);
}
const int M=1e3+3;
// dp[i][j]: best value for the first i points with j edges cut;
// sum[]: prefix sums; cost[i]: sum of pairwise products within points 1..i.
ll dp[M][M],sum[M],cost[M];
int n,m,l;   // l: number of cuts in the layer currently being computed
int que[M];  // monotone queue of decision points
// Slope numerator between decisions j and k in layer l:
// Y(j)-Y(k) with Y(t)=dp[t][l-1]-cost[t]+sum[t]^2.
ll zi(int j,int k){
    const ll yj=dp[j][l-1]-cost[j]+sum[j]*sum[j];
    const ll yk=dp[k][l-1]-cost[k]+sum[k]*sum[k];
    return yj-yk;
}
// Slope denominator between decisions j and k: sum[j]-sum[k].
ll mu(int j,int k){
    const ll xj=sum[j],xk=sum[k];
    return xj-xk;
}
// dp[j][l-1] plus the cost of segment j+1..i,
// using cost[j+1..i] = cost[i]-cost[j]-sum[j]*(sum[i]-sum[j]).
ll ans(int j,int i){
    const ll segment=cost[i]-cost[j]-sum[j]*(sum[i]-sum[j]);
    return dp[j][l-1]+segment;
}
int main(){
    while(~scanf("%d%d",&n,&m)){
        if(!n&&!m)
            break;
        // sum[i]: prefix sums; cost[i]: pairwise-product total of points 1..i,
        // built incrementally as cost[i-1] + sum[i-1]*x_i.
        sum[0]=0,cost[0]=0;
        for(int i=1;i<=n;i++){
            ll x=read();
            sum[i]=sum[i-1]+x;
            cost[i]=cost[i-1]+sum[i-1]*x;
        }
        // Base cases: no cut keeps the prefix as one segment; i-1 cuts isolate every point.
        for(int i=1;i<=n;i++){
            dp[i][0]=cost[i];
            dp[i][i-1]=0;
        }
        for(l=1;l<=m;l++){
            // With l cuts, the last cut can follow point k only if the first k points can
            // hold l-1 cuts, i.e. k >= l. Seed the queue with exactly that first valid
            // decision and evaluate only states i > l. Decisions then enter the queue in
            // increasing sum[] order, as the monotone queue requires. (Previously a stale
            // que[1] sentinel was kept and decisions 1..l-1 were pushed after l.)
            int head=1,tail=1;
            que[1]=l;
            for(int i=l+1;i<=n;i++){
                // Advance the front while the next decision is at least as good for slope sum[i].
                while(head+1<=tail&&zi(que[head+1],que[head])<=sum[i]*mu(que[head+1],que[head]))
                    head++;
                dp[i][l]=ans(que[head],i);
                // Lower-hull maintenance: drop the back point while slope(back,i) <= slope(back-1,back).
                while(head+1<=tail&&zi(i,que[tail])*mu(que[tail],que[tail-1])<=mu(i,que[tail])*zi(que[tail],que[tail-1]))
                    tail--;
                que[++tail]=i;
            }
        }
        write(dp[n][m]);
        putchar('\n');
    }
    return 0;
}
View Code
复制代码

 

posted @   starve_to_death  阅读(129)  评论(0编辑  收藏  举报
努力加载评论中...
点击右上角即可分享
微信分享提示