浅谈斜率优化

如果一个 DP 的转移方程可以写成 fi=min/maxj<i{fj+ai×bj+ci+dj}+C 的形式,那么可以运用斜率优化。

不妨设转移是 min,忽略那个常数 C,设 gi,j=fj+ai×bj+ci+dj,即 fi=minj<igi,j,式子可以化为 fj+dj=ai×bj+gi,jci,设 yj=fj+djk=aixj=bjtj=gi,jci,原式化为 yj=kxj+tj(),这是一个一次函数的形式。

假设 fi 是由 p 转移来的,即 fi=gi,p=minj<igi,j,因为 tj=gi,jci,所以 tp=minj<itj。 注意到 () 式中 k 是一个定值,这说明,如果过每个点 (xj,yj) 画斜率为 k 的直线 lj,则 lpy 轴的截距是最小的,直观地说就是“在最下面”的。

(如图,假设有这些点,我们要画一条斜率为 1 的直线(k=1),则图中那条是最优的,其 y 轴截距是 5,最小)

现在考虑如何快速找到这条最优直线:维护这些点的下凸壳,则与这个凸壳相切的直线是最优的。

(如图,两种相切)

下凸壳的斜率单调递增,切点就是满足切线斜率 左边的斜率 且 < 右边的斜率的点。

维护这个东西需要动态凸包,但是多数情况下并不需要:

  • 如果 x 单调,k 也单调,则决策点 p 只会单向移动,单调队列维护即可。推荐构造 x 递增,因为这样可能比较直观。
  • 如果 x 单调,用单调栈维护凸壳,然后二分即可。
  • 否则才需要动态凸包 / 李超线段树。

事实上推式子的时候一般不需要把常数项写出来,只要搞清楚 x,y,k 就可以了。

注意特判斜率不存在的情况


例题

  1. [ZJOI2007] 仓库建设

fi=minj<i{fj+ci+k=j+1ipk×(xixk)}=minj<i{fj+ci+k=j+1ipk×xipk×xk}

求出 pi 的前缀和 tpi×xi 的前缀和 s,原式化为:

fi=minj<i{fj+ci+xi×(titj)(sisj)}gi,j+sj=xi×tj+ficixi×ti

x=t,k=x​ 均单调递增,所以决策点只会后移,单调队列维护凸壳,时间复杂度 O(n)

注意不要用 double 算斜率,容易因为精度 WA,要用 long double 或者把斜率不等式化成乘法形式。

本题 p 可能 =0,所以不一定在最后一个地方建仓库,并且斜率可能不存在。

#include<bits/stdc++.h>
#define endl '\n'
#define rep(i, s, e) for(int i = s, i##E = e; i <= i##E; ++i)
#define per(i, s, e) for(int i = s, i##E = e; i >= i##E; --i)
#define F first
#define S second
#define int ll
#define gmin(x, y) (x = min(x, y))
#define gmax(x, y) (x = max(x, y))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double f128;
typedef pair<int, int> pii;
constexpr int N = 1e6 + 5;
int n, x[N], p[N], c[N], t[N], s[N], f[N];
int q[N], l = 1, r;
inline int Y(int i) { return f[i] + s[i]; }
inline int X(int i) { return t[i]; }
inline int K(int i) { return x[i]; }
signed main() {
#ifdef ONLINE_JUDGE
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
#endif
    cin >> n;
    rep(i, 1, n) {
        cin >> x[i] >> p[i] >> c[i];
        t[i] = t[i - 1] + p[i];
        s[i] = s[i - 1] + p[i] * x[i];
    }
    q[++r] = 0;
    rep(i, 1, n) {
        while(l < r && Y(q[l + 1]) - Y(q[l]) <= K(i) * (X(q[l + 1]) - X(q[l])))
            ++l;
        int p = q[l];
        f[i] = f[p] + c[i] + x[i] * (t[i] - t[p]) - (s[i] - s[p]);
        while(l < r)
            if(X(q[r]) - X(q[r - 1]) == 0) {
                if(Y(q[r]) - Y(q[r - 1]) > 0) --r;
                else break;
            }
            else if(X(i) - X(q[r]) == 0) {
                if(Y(i) - Y(q[r]) < 0) --r;
                else break;
            }
            else if((Y(q[r]) - Y(q[r - 1])) * (X(i) - X(q[r])) >= (Y(i) - Y(q[r])) * (X(q[r]) - X(q[r - 1]))) 
                --r;
            else break;
        q[++r] = i;
    }
    int tmp = n;
    while(!p[tmp]) --tmp;
    cout << *max_element(f + tmp, f + n + 1) << endl;
    return 0;
}
  1. [SDOI2012] 任务安排

其实我觉得这个 n2 DP 挺难想到的。。。

fi=minj<i{fj+ti×(cicj)+s×(cncj)}gi,js×cj=ti×cj+fiti×cis×cn

其中 tc 是原题中 TC 的前缀和。提前计算了启动机器的代价。

本题中 x=c 单增,但 k=t 不单调,所以需要单调栈 + 二分,时间复杂度 O(nlogn)

注意不等式变号问题

#include<bits/stdc++.h>
#define endl '\n'
#define rep(i, s, e) for(int i = s, i##E = e; i <= i##E; ++i)
#define per(i, s, e) for(int i = s, i##E = e; i >= i##E; --i)
#define F first
#define S second
#define int ll
#define gmin(x, y) (x = min(x, y))
#define gmax(x, y) (x = max(x, y))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double f128;
typedef pair<int, int> pii;
constexpr int N = 3e5 + 5;
int n, s, c[N], t[N], f[N];
int stk[N], tp;
inline int Y(int i) { return f[i] - s * c[i]; }
inline int X(int i) { return c[i]; }
inline int K(int i) { return t[i]; }
// 下凸,斜率单增
inline int find(int k) {
    int l = 1, r = tp;
    while(l < r) {
        int mid = (l + r) / 2;
        if(Y(stk[mid]) - Y(stk[mid + 1]) <= k * (X(stk[mid]) - X(stk[mid + 1]))) 
            // X 的差是负的,挪过来要变号(这里是我的写法问题)
            r = mid;
        else l = mid + 1;
    }
    return stk[l];
}
signed main() {
#ifdef ONLINE_JUDGE
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
#endif
    cin >> n >> s;
    rep(i, 1, n) {
        cin >> t[i] >> c[i];
        t[i] += t[i - 1], c[i] += c[i - 1];
    }
    stk[++tp] = 0;
    rep(i, 1, n) {
        int p = find(K(i));
        f[i] = f[p] + t[i] * (c[i] - c[p]) + s * (c[n] - c[p]);
        while(tp > 1 && (Y(stk[tp]) - Y(stk[tp - 1])) * (X(i) - X(stk[tp])) >=
            (Y(i) - Y(stk[tp])) * (X(stk[tp]) - X(stk[tp - 1])))
            --tp;
        stk[++tp] = i;
    }
    cout << f[n] << endl;
    return 0;
}
  1. [USACO2008Mar] Land Acquisition

注意到当 wi<wjli<ljij 放一组一定不劣,所以从小到大排序后保留有用的值,然后容易得出 DP:

fi=minj<ifj+wi×lj+1gi,j=wi×(lj+1)+fi

单调队列维护即可,时间复杂度 O(nlogn),瓶颈在排序。

被这题调破防了。别去分母了,安心用 long double 吧,128 位的精度还是够的。以及新旧数组不要弄混。

#include<bits/stdc++.h>
#define endl '\n'
#define rep(i, s, e) for(int i = s, i##E = e; i <= i##E; ++i)
#define per(i, s, e) for(int i = s, i##E = e; i >= i##E; --i)
#define F first
#define S second
#define int ll
#define gmin(x, y) (x = min(x, y))
#define gmax(x, y) (x = max(x, y))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double f128;
typedef pair<int, int> pii;
constexpr int N = 5e4 + 5;
int n, m, f[N], q[N], l = 1, r;
pii a[N], b[N];
inline f128 slp(int i, int j) {
    return f128(f[i] - f[j]) / f128(b[j + 1].S - b[i + 1].S);
}
signed main() {
#ifdef ONLINE_JUDGE
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
#endif
    cin >> n;
    rep(i, 1, n) cin >> a[i].F >> a[i].S;
    sort(a + 1, a + n + 1);
    rep(i, 1, n) {
        while(m && a[i].S >= b[m].S) --m;
        b[++m] = a[i];
    }
    q[++r] = 0;
    rep(i, 1, m) {
        while(l < r && slp(q[l], q[l + 1]) <= b[i].F) ++l;
        int p = q[l];
        f[i] = f[p] + b[i].F * b[p + 1].S;
        while(l < r && slp(q[r], q[r - 1]) >= slp(i, q[r])) --r;
        q[++r] = i;
    }
    cout << f[m] << endl;
    return 0;
}
  1. [APIO2010] 特别行动队

求出 x 的前缀和数组 s,容易得到 DP:

fi=maxj<i{fj+a(sisj)2+b(sisj)+c}=maxj<i{fj2asisj+asi2+bsi+asj2bsj+c}

gi,j+asj2bsj=2asisj+fiasi2bsic

二次函数展开还是斜率优化的形式,单调队列维护,时间复杂度 O(n)

#include<bits/stdc++.h>
#define endl '\n'
#define rep(i, s, e) for(int i = s, i##E = e; i <= i##E; ++i)
#define per(i, s, e) for(int i = s, i##E = e; i >= i##E; --i)
#define F first
#define S second
#define int ll
#define gmin(x, y) (x = min(x, y))
#define gmax(x, y) (x = max(x, y))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double f128;
typedef pair<int, int> pii;
constexpr int N = 1e6 + 5;
int n, a, b, c, s[N], f[N], q[N], l = 1, r;
inline int Y(int i) { return f[i] + a * s[i] * s[i] - b * s[i]; }
inline int X(int i) { return s[i]; }
inline int K(int i) { return 2 * a * s[i]; }
inline f128 slp(int i, int j) {
    return f128(Y(i) - Y(j)) / f128(X(i) - X(j)); 
}
signed main() {
#ifdef ONLINE_JUDGE
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
#endif
    cin >> n >> a >> b >> c;
    rep(i, 1, n) cin >> s[i], s[i] += s[i - 1];
    q[++r] = 0;
    rep(i, 1, n) {
        while(l < r && slp(q[l], q[l + 1]) >= K(i)) ++l;
        int p = q[l];
        f[i] = f[p] + a * (s[i] - s[p]) * (s[i] - s[p]) + b * (s[i] - s[p]) + c;
        while(l < r && slp(q[r], q[r - 1]) <= slp(i, q[r])) --r;
        q[++r] = i;
    }
    cout << f[n] << endl;
    return 0;
}
  1. [HNOI2008] 玩具装箱

求出 C 的前缀和数组 s,DP 式显然:

fi=minj<i{fj+(ij1+sisjL)2}=minj<i{fj+[(i+si)(j+sj)(L+1)]2}

vi=i+si,令 LL+1,则有:

fi=minj<i{fj+(vivjL)2}=minj<i{fj2vivj+vi22Lvi+vj2+2Lvj+L2}

单调队列维护,时间复杂度 O(n)

#include<bits/stdc++.h>
#define endl '\n'
#define rep(i, s, e) for(int i = s, i##E = e; i <= i##E; ++i)
#define per(i, s, e) for(int i = s, i##E = e; i >= i##E; --i)
#define F first
#define S second
#define int ll
#define gmin(x, y) (x = min(x, y))
#define gmax(x, y) (x = max(x, y))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double f128;
typedef pair<int, int> pii;
constexpr int N = 5e4 + 5;
int n, L, s[N], f[N], q[N], l = 1, r;
inline int V(int i) { return i + s[i]; }
inline int Y(int i) { return f[i] + V(i) * V(i) + 2 * L * V(i); }
inline int X(int i) { return V(i); }
inline int K(int i) { return 2 * V(i); }
inline f128 slp(int i, int j) {
    return f128(Y(i) - Y(j)) / f128(X(i) - X(j));
}
signed main() {
#ifdef ONLINE_JUDGE
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
#endif
    cin >> n >> L; ++L;
    rep(i, 1, n) cin >> s[i], s[i] += s[i - 1];
    q[++r] = 0;
    rep(i, 1, n) {
        while(l < r && slp(q[l], q[l + 1]) <= K(i)) ++l;
        int p = q[l];
        f[i] = f[p] + (V(i) - V(p) - L) * (V(i) - V(p) - L);
        while(l < r && slp(q[r], q[r - 1]) >= slp(i, q[r])) --r;
        q[++r] = i;
    }
    cout << f[n] << endl;
    return 0;
}
posted @   untitled0  阅读(53)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· DeepSeek 开源周回顾「GitHub 热点速览」
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· AI与.NET技术实操系列(二):开始使用ML.NET
· 单线程的Redis速度为什么快?
点击右上角即可分享
微信分享提示