DP斜率优化学习笔记

斜率优化

首先,可以进行斜率优化的DP方程式一般式为$dp[i]=\max_{j=1}^{i-1}/\min_{j=1}^{i-1}\{a(i)*x(j)+b(i)*y(j)\}$

其中$a(j)$和$b(j)$都是关于$j$的函数,在$O(1)$时间内可以计算得出

将方程式进行变形

$$dp[i]=a(i)*x(j)+b(i)*y(j)$$

$$dp[i]-a(i)*x(j)=b(i)*y(j)$$

$$y(j)=-\frac{a(i)}{b(i)}x(j)+\frac{dp[i]}{b(i)}$$

我们可以称$y=-\frac{a(i)}{b(i)}x+\frac{dp[i]}{b(i)}$为$i$的特征直线

那么对于$i$这个决策点来说,在$i$之前所有决策点$j$($j<i$)可以看作一个二维平面上的点,横坐标为$x(j)$,纵坐标为$y(j)$,那么i在寻找最优决策点的过程就是用i的特征直线去截平面上的每一个点,求出截距,找到最大/最小的$dp[i]$

但如果直接去做复杂度为$O(n^{2})$

那么可以通过维护在平面上维护凸包(根据具体的斜率正负和$b(i)$的正负决定维护上凸包还是下凸包),直线在平移过程中切到凸包的第一个点就是当前i的最优决策点,那么时间复杂度均摊$O(n)$

以$-\frac{a(i)}{b(i)}>0$,$b(i)>0$,$dp$取最小值为例

 

 

 

直线$EF$不断从截距无限小向上平移,直到截到凸包上的点

那么需要做的是维护凸包

1.若$x(i)$单调,斜率$-\frac{a(i)}{b(i)}$单调

所以在平面上的点是按顺序依次排列,去截的特征直线的斜率不断增加或减少

可以发现特征直线截到凸包上的第一个点记为$p$,那么$p$左边的点和$p$的直线的斜率$k_{1}$,$p$右边的点和$p$的直线的斜率$k_{2}$,特征直线的斜率一定介于$k_{1}$和$k_{2}$之间

那么可以利用单调队列来维护平面上的点的凸包,然后在单调队列队首不断维护当前特征直线的斜率,找到第一个大于或小于(根据凸包是上凸包还是下凸包决定)的决策点

时间复杂度$O(n)$

2.若$x(i)$单调,斜率不单调

仍然维护凸包,但此时特征直线的斜率没有规律,利用上面的结论可以二分凸包上两点的斜率,来找到第一个切到的点

时间复杂度$O(nlogn)$

3.若$x(i)$不单调,斜率不单调

对无规律的三个维度进行CDQ分治

具体见货币兑换的题解

 

#include <bits/stdc++.h>
#define min(a,b) (((a)<(b))?(a):(b))
#define max(a,b) (((a)>(b))?(a):(b))
#define eps 1e-9
using namespace std;
const int N=100100;
int n,h,t,q[N];
double s,dp[N];
struct node
{
    double a,b,r,k,x,y;
    int id;
}sh[N];
node p[N];
bool cmp(node a,node b)
{
    return a.k>b.k;
}
double slope(int i,int j)
{
    if (sh[j].x==sh[i].x) return 1e9;
    return (sh[j].y-sh[i].y)/(sh[j].x-sh[i].x);
}
void cdq(int l,int r)
{
    if (l==r)
    {
        dp[l]=max(dp[l],dp[l-1]);
        sh[l].x=sh[l].r*dp[l]/(sh[l].a*sh[l].r+sh[l].b);
        sh[l].y=dp[l]/(sh[l].a*sh[l].r+sh[l].b);
        return;
    }
    int mid=(l+r)>>1,tl=l-1,tr=mid;
    for (int i=l;i<=r;++i)
    {
        if (sh[i].id<=mid) p[++tl]=sh[i];
        else p[++tr]=sh[i];
    }
    for (int i=l;i<=r;i++) sh[i]=p[i];
    cdq(l,mid);
    h=0;t=-1;
    for (int i=l;i<=mid;++i)
    {
        while (h<t && slope(q[t],q[t-1])<slope(q[t],i)) t--;
        q[++t]=i;
    }
    for (int i=mid+1;i<=r;++i)
    {
        while (h<t && sh[i].k<slope(q[h],q[h+1])) h++;
        int j=q[h];
        dp[sh[i].id]=max(dp[sh[i].id],sh[i].a*sh[j].x+sh[i].b*sh[j].y);
    }
    cdq(mid+1,r);
    tl=l;tr=mid+1;
    int cnt=l-1;
    while (tl<=mid && tr<=r)
    {
        if (sh[tl].x-sh[tr].x<eps) p[++cnt]=sh[tl],tl++;
        else p[++cnt]=sh[tr],tr++;
    }
    for (int i=tl;i<=mid;++i) p[++cnt]=sh[i];
    for (int i=tr;i<=r;++i) p[++cnt]=sh[i];
    for (int i=l;i<=r;++i) sh[i]=p[i];
}
int main()
{
    scanf("%d%lf",&n,&s);
    for (int i=1;i<=n;i++)
    {
        scanf("%lf%lf%lf",&sh[i].a,&sh[i].b,&sh[i].r);
        sh[i].k=-sh[i].a/sh[i].b;sh[i].id=i;
    }
    sort(sh+1,sh+1+n,cmp);
    dp[0]=s;
    cdq(1,n);
    printf("%.3lf\n",dp[n]);
}
View Code

 

时间复杂度$O(nlogn)$

若在将序列上的问题转化到树上则利用点分治见购票

 

#include <bits/stdc++.h>
#define inf 1e18
#define int long long
#define re register
using namespace std;
const int N=2*1e5+100;
int n,T,sum[N],dp[N];
int h,t,q[N],sz[N],root,vi[N];
int tot,first[N],nxt[N*2],point[N*2];
int son[N],w;
struct node
{
    int fa,s,p,q,l;
}sh[N];
void add_edge(int x,int y)
{
    tot++;
    nxt[tot]=first[x];
    first[x]=tot;
    point[tot]=y;
}
bool cmp(int a,int b)
{
    return sum[a]-sh[a].l>sum[b]-sh[b].l;
}
double slope(int i,int j)
{
    return 1.0*(dp[j]-dp[i])/(1.0*(sum[j]-sum[i]));
}
int find(int l,int r,double k)
{
    if (r<l) return q[h];
    if (slope(q[l],q[l+1])>=k) return q[l];
    while (l<r)
    {
        int mid=l+((r-l+1)>>1);
        if (slope(q[mid],q[mid+1])<k) l=mid;
        else r=mid-1;
    }
    return q[l+1];
}
void dfs(int x)
{
    sz[x]=1;
    for (re int i=first[x];i!=-1;i=nxt[i])
    {
        int u=point[i];
        if (u==sh[x].fa) continue;
        sum[u]=sum[x]+sh[u].s;
        dfs(u);
        sz[x]+=sz[u];
    }
}
void dfs_sz(int x,int fa)
{
    sz[x]=1;
    for (re int i=first[x];i!=-1;i=nxt[i])
    {
        int u=point[i];
        if (u==fa || vi[u]) continue;
        dfs_sz(u,x);
        sz[x]+=sz[u];
    }
}
void dfs_rt(int x,int fa,int tot)
{
    bool bl=1;
    for (re int i=first[x];i!=-1;i=nxt[i])
    {
        int u=point[i];
        if (u==fa || vi[u]) continue;
        dfs_rt(u,x,tot);
        if (sz[u]>tot/2) bl=0;
    }
    if (tot-sz[x]>tot/2) bl=0;
    if (bl) root=x;
}
void dfs_insert(int x,int fa)
{
    son[++w]=x;
    for (re int i=first[x];i!=-1;i=nxt[i])
    {
        int u=point[i];
        if (u==fa || vi[u]) continue;
        dfs_insert(u,x);
    }
}
void divide(int x)
{
    vi[x]=1;
    vector <int> father;
    father.push_back(x);
    if (sh[x].fa && !vi[sh[x].fa])
    {
        for (int i=sh[x].fa;i!=0 && !vi[i];i=sh[i].fa)
          father.push_back(i);
        dfs_sz(sh[x].fa,x);
        dfs_rt(sh[x].fa,x,sz[sh[x].fa]);
        divide(root);
        for (re int i=1;i<(int)father.size();++i)
          if (sum[father[i]]>=sum[x]-sh[x].l)
            dp[x]=min(dp[x],dp[father[i]]+(sum[x]-sum[father[i]])*sh[x].p+sh[x].q);
    }
    w=0;
    for (re int i=first[x];i!=-1;i=nxt[i])
    {
        int u=point[i];
        if (u==sh[x].fa || vi[u]) continue;
        dfs_insert(u,x);
    }
    sort(son+1,son+1+w,cmp);
    h=n+1;t=n;
    for (re int i=1,j=0;i<=w;++i)
    {
        int u=son[i];
        while (j<(int)father.size() && sum[father[j]]>=sum[u]-sh[u].l)
        {
            while (h<t && slope(q[h],q[h+1])<slope(q[h],father[j])) h++;
            q[--h]=father[j];
            j++;
        }
        if (h>t) continue;
        int pos=find(h,t-1,1.0*sh[u].p);
        dp[u]=min(dp[u],dp[pos]+(sum[u]-sum[pos])*sh[u].p+sh[u].q);
    }
    for (re int i=first[x];i!=-1;i=nxt[i])
    {
        int u=point[i];
        if (u==sh[x].fa || vi[u]) continue;
        dfs_sz(u,x);
        dfs_rt(u,x,sz[u]);
        divide(root);
    }
}
signed main()
{
    tot=-1;
    memset(first,-1,sizeof(first));
    memset(nxt,-1,sizeof(nxt));
    scanf("%lld%lld",&n,&T);
    for (re int i=2;i<=n;++i)
    {
        scanf("%lld%lld%lld%lld%lld",&sh[i].fa,&sh[i].s,&sh[i].p,&sh[i].q,&sh[i].l);
        add_edge(sh[i].fa,i);
        add_edge(i,sh[i].fa);
    }
    sh[1].fa=0;dp[1]=0;
    for (re int i=2;i<=n;++i) dp[i]=inf;
    dfs(1);
    dfs_rt(1,0,sz[1]);
    divide(root);
    for (re int i=2;i<=n;++i) printf("%lld\n",dp[i]);
}
View Code

 

posted @ 2020-05-27 08:25  SevenDawns  阅读(250)  评论(3编辑  收藏  举报
浏览器标题切换
浏览器标题切换end