洛谷 P3994 高速公路(数据结构斜率优化)

https://www.luogu.org/problemnew/show/P3994

 

设dp[i] 表示第i个城市到根节点的最小花费

dp[i]=min{ (dis[i]-dis[j])*P[i]+Q[i]+dp[j] } 

这是O(n^2)的

这个式子可以斜率优化

dp[i]+dis[j]*P[i]=dis[i]*P[i]+Q[i]+dp[j]

就是一条斜率为P[i]的直线,截(dis[j],dp[j])的最小截距

在根往下走的过程中,斜率单调递增

这就体现了 为什么题目中说“i号城市是j号城市的某个祖先,那么一定存在Pi<=Pj”

我们按dfs序dp

现在唯一的问题就是如何得到 一个点到根节点路径上的单调队列

只需要考虑如何去除兄弟节点的子树对单调队列的影响

即在一个节点退出dfs时,将单调队列恢复为这个节点开始dfs的情况

头指针只是不断的+1,没有涉及到单调队列中元素的修改,所以记录下头指针在哪个位置即可

尾指针涉及到元素的替换,但是它只会替换一个元素,所以记录下尾指针的位置,以及被当前点替换的元素是谁

当节点退出dfs时,恢复记录的这三个值即可

这样的话,一个节点多次出队入队,时间复杂度就不是O(n)了

所以二分出队位置,时间复杂度为O(nlogn)

 

朴素的DP:

#include<cstdio>
#include<iostream>
#include<algorithm>

using namespace std;

#define N 1000001

typedef long long LL;

int P[N],Q[N];

int front[N],to[N<<1],nxt[N<<1],val[N<<1],tot;

int fa[N];

LL dis[N];

int t;
LL mi[N];

void read(int &x)
{
    x=0; char c=getchar();
    while(!isdigit(c)) c=getchar();
    while(isdigit(c)) { x=x*10+c-'0'; c=getchar(); }
}

void add(int u,int v,int w)
{
    to[++tot]=v; nxt[tot]=front[u]; front[u]=tot; val[tot]=w;
    to[++tot]=u; nxt[tot]=front[v]; front[v]=tot; val[tot]=w;
}

void dfs(int x,int f)
{
    for(int i=front[x];i;i=nxt[i])
    {
        if(to[i]==f) continue;
        dis[to[i]]=dis[x]+val[i];
        mi[to[i]]=dis[to[i]]*P[to[i]]+Q[to[i]];
        t=fa[to[i]]; 
        while(t!=1)
        {
            mi[to[i]]=min(mi[to[i]],(dis[to[i]]-dis[t])*P[to[i]]+Q[to[i]]+mi[t]);
            t=fa[t];
        } 
        dfs(to[i],x);
    }
}

int main()
{
    int n,s;
    read(n);
    for(int i=1;i<n;++i)
    {
        read(fa[i+1]); read(s); read(P[i+1]); read(Q[i+1]);
        add(fa[i+1],i+1,s);
    }
    dfs(1,0);
    for(int i=2;i<=n;++i) cout<<mi[i]<<'\n';
}
View Code

 

斜率优化,暴力出队:

#include<cstdio>
#include<iostream>

using namespace std;

#define N 1000001

typedef long long LL;

int front[N],nxt[N],to[N],tot,val[N];

int P[N],Q[N];

int q[N],head,tail;

LL dis[N]; 
LL dp[N];

void read(int &x)
{
    x=0; char c=getchar();
    while(!isdigit(c)) c=getchar();
    while(isdigit(c)) { x=x*10+c-'0'; c=getchar(); }
}

void add(int u,int v,int w)
{
    to[++tot]=v; nxt[tot]=front[u]; front[u]=tot; val[tot]=w;
}

inline double X(int i,int j) { return dis[j]-dis[i]; }
inline double Y(int i,int j) { return dp[j]-dp[i]; } 

void dfs(int x)
{
    int now_h=head,now_t=tail;
    while(head<tail-1 && Y(q[head],q[head+1])<P[x]*X(q[head],q[head+1])) head++; 
    int j=q[head];
    dp[x]=(dis[x]-dis[j])*P[x]+dp[j]+Q[x];
    while(head<tail-1 && Y(q[tail-2],q[tail-1])*X(q[tail-1],x)>X(q[tail-2],q[tail-1])*Y(q[tail-1],x)) tail--;
    int rr=q[tail];
    q[tail++]=x;
    for(int i=front[x];i;i=nxt[i]) 
        dis[to[i]]=dis[x]+val[i],dfs(to[i]);
    head=now_h; q[tail-1]=rr; tail=now_t;
}

int main()
{
    int n;
    read(n);
    int fa,d;
    for(int i=2;i<=n;++i)
    {
        read(fa); read(d);
        add(fa,i,d);
        read(P[i]); read(Q[i]); 
    }
    for(int i=front[1];i;i=nxt[i])
    {
        dis[to[i]]=val[i];
        q[head=0]=1; tail=1;
        dfs(to[i]);
    }
    for(int i=2;i<=n;++i) cout<<dp[i]<<'\n';
}
View Code

 

斜率优化,二分出队

#include<cstdio>
#include<iostream>

using namespace std;

#define N 1000001

typedef long long LL;

int front[N],nxt[N],to[N],tot,val[N];

int P[N],Q[N];

int q[N],head,tail;

LL dis[N]; 
LL dp[N];

void read(int &x)
{
    x=0; char c=getchar();
    while(!isdigit(c)) c=getchar();
    while(isdigit(c)) { x=x*10+c-'0'; c=getchar(); }
}

void add(int u,int v,int w)
{
    to[++tot]=v; nxt[tot]=front[u]; front[u]=tot; val[tot]=w;
}

inline double X(int i,int j) { return dis[j]-dis[i]; }
inline double Y(int i,int j) { return dp[j]-dp[i]; } 

void dfs(int x)
{
    int now_h=head,now_t=tail;
    int l=head,r=tail-2,mid,tmp=-1;
    while(l<=r)
    {
        mid=l+r>>1;
        if(Y(q[mid],q[mid+1])>=P[x]*X(q[mid],q[mid+1])) tmp=mid,r=mid-1;
        else l=mid+1;
    }
    if(tmp!=-1) head=tmp;
    else head=tail-1; 
    int j=q[head];
    dp[x]=(dis[x]-dis[j])*P[x]+dp[j]+Q[x];
    l=head,r=tail-2,tmp=-1;
    while(l<=r)
    {
        mid=l+r>>1;
        if(Y(q[mid],q[mid+1])*X(q[mid+1],x)<=X(q[mid],q[mid+1])*Y(q[mid+1],x)) tmp=mid,l=mid+1;
        else r=mid-1;
    }
    if(tmp!=-1) tail=tmp+2;
    else tail=head+1;
    int rr=q[tail];
    q[tail++]=x;
    for(int i=front[x];i;i=nxt[i]) 
        dis[to[i]]=dis[x]+val[i],dfs(to[i]);
    head=now_h; q[tail-1]=rr; tail=now_t;
}

int main()
{
    int n;
    read(n);
    int fa,d;
    for(int i=2;i<=n;++i)
    {
        read(fa); read(d);
        add(fa,i,d);
        read(P[i]); read(Q[i]); 
    }
    for(int i=front[1];i;i=nxt[i])
    {
        dis[to[i]]=val[i];
        q[head=0]=1; tail=1;
        dfs(to[i]);
    }
    for(int i=2;i<=n;++i) cout<<dp[i]<<'\n';
}
     
View Code

 

posted @ 2018-02-25 08:31  TRTTG  阅读(684)  评论(2编辑  收藏  举报