DP 的优化

DP 的优化

本文主要介绍 DP 的一些优化方法。

决策单调性优化DP

要学习决策单调性,你首先知道四边形不等式:

四边形不等式

现在有一个函数 w(l,r),若 l1l2r1r2,满足 w(l1,r1)+w(l2,r2)w(l1,r2)+w(l2,r1),则称 w(l,r) 满足四边形不等式,简记为交叉小于包含

特别的,如果等号恒成立,则称 w(l,r) 满足四边形恒等式。另外,若 l1l2r2r1,满足 w(l2,r2)w(l1,r1),则称 w(l,r) 满足区间包含单调性

下面给出几条关于四边形不等式的重要性质:

  • w1(l,r)w2(l,r) 均满足四边形不等式区间包含单调性,那么对于任意 c1,c20,均满足 w(l,r)=c1w1(l,r)+c2w2(l,r) 满足四边形不等式区间单调包含性。证明显然,把拼凑出来的函数的式子拆开就能发现依然满足。
  • l<r,满足 w(l,r)+w(l+1,r+1)w(l,r+1)+w(l+1,r),那么 w(l,r) 满足四边形不等式,证明的话用归纳法推一下即可。

如果你要证一个函数满足四边形不等式,一般就是用第二条性质列出来看一下即可。但是如果是在考场上遇到的题,更常见的是打表看规律,或者直接大胆猜测满足((

那么四边形不等式有什么用呢,马上你就知道了:

四边形不等式优化区间 DP

对于一些区间 DP,我们一般会列出如下的转移式子:

fi,j=minik<j{fi,k+fk+1,j}+w(i,j)

这个 DP 直接做是 O(n3) 的,但是如果 w 满足一些性质,那么可以优化这个 DP。

定理 1:若 w 满足区间包含单调性和四边形不等式,则状态 f(i,j) 满足四边形不等式。

证明(有些证明过程比较繁琐,可以视情况跳过)

不妨设 abcd。下证 f(a,d)+f(b,c)f(a,c)+f(b,d)。考虑依 da 归纳。当 a=bc=d 时,所求即一等式。对于一般的情形,根据 d=opt(a,d) 的位置分类讨论。

第一种情况,cdd<b,即 [b,c] 包含于 [a,d][d+1,d] 之中。

不妨假设 cd,另一种情形同理。此时有

f(a,d)+f(b,c)=f(a,d)+f(d+1,d)+w(a,d)+f(b,c)f(a,c)+f(b,d)+f(d+1,d)+w(a,d)f(a,c)+f(b,d)+f(d+1,d)+w(b,d)f(a,c)+f(b,d).

这里,第一个不等式来自于归纳假设 f(a,c)+f(b,d)f(a,d)+f(b,c),第二个不等式来自于区间包含单调性 w(b,d)w(a,d),第三个不等式来自于最优性条件 f(b,d)f(b,d)+f(d+1,d)+w(b,d)

第二种情况,bd<c,即 d 位于 [b,c] 之中。此时,考虑 c=opt(b,c) 的位置。

不妨假设 cd,即 [b,c] 包含于 [a,d] 之中,另一种情形同理。此时有

f(a,d)+f(b,c)=f(a,d)+f(d+1,d)+w(a,d)+f(b,c)+f(c+1,c)+w(b,c)f(a,c)+f(c+1,c)+w(b,c)+f(b,d)+f(d+1,d)+w(a,d)f(a,c)+f(c+1,c)+w(a,c)+f(b,d)+f(d+1,d)+w(b,d)f(a,c)+f(b,d).

这里,第一个不等式来自于归纳假设 f(a,c)+f(b,d)f(a,d)+f(b,c),第二个不等式来自于四边形不等式 w(a,c)+w(b,d)w(a,d)+w(b,c),第三个不等式来自于 f(a,c)f(b,d) 的最优性条件。

定理 2

w 满足区间包含单调性和四边形不等式,则 f(i,j) 的最优决策点 opt(i,j) 满足

opt(i,j1)opt(i,j)opt(i+1,j).(i+1<j)

证明

上面已经证得 f(i,j) 满足四边形不等式,所以目标函数 f(i,k)+f(k+1,j)+w(i,j) 对于给定 i 作为 (k,j) 的函数满足四边形不等式,所以由定理 1 有,opt(i,j1)opt(i,j)。注意,不同时含有 (k,j) 的项并不影响四边形不等式成立。类似地,它对于给定 j 作为 (k,i) 的函数也满足四边形不等式,所以 opt(i,j)opt(i+1,j)。即得所证。

利用这一结论,我们在区间 DP 时,首先还是枚举区间长度 len,在求 f(i,j) 时暴力搜索 opt(i,j1)opt(i+1,j) 之间的所有 k 求得最优解 f(i,j) 并记录最小最优决策 opt(i,j)。我们发现,对于每一个 len,决策点 k 最多都是从 1 枚举到 n,所以总时间复杂度为 O(n2)

例题:P1880 [NOI1995] 石子合并

函数 w(l,r)=sumrsuml1,这个式子显然满足区间包含单调性和四边形不等式,直接套用上面的做法即可。

代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;
const int N = 205;
int f1[N][N],g[N][N],f2[N][N],a[N],s[N],n;
int w(int l,int r){return s[r]-s[l-1];}
inline int rd()
{
    char c;int f = 1;
    while(!isdigit(c = getchar()))if(c=='-')f = -1;
    int x = c-'0';
    while(isdigit(c = getchar()))x = x*10+(c^48);
    return x*f;
}
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    n = rd();
    for(int i = 1;i <= n;i++)a[i] = a[i+n] = rd();
    for(int i = 1;i <= n*2;i++)
        s[i] = s[i-1]+a[i],f1[i][i] = 0,g[i][i] = i;
    for(int l = 1;l < n;l++)
        for(int i = 1;i <= n*2-l;i++)
        {
            int j = i+l;
            f1[i][j] = 1e9;
            f2[i][j] = max(f2[i+1][j],f2[i][j-1])+w(i,j);
            for(int k = g[i][j-1];k <= g[i+1][j];k++)
            {
                int now = f1[i][k]+f1[k+1][j]+w(i,j);
                if(now < f1[i][j])
                    g[i][j] = k,f1[i][j] = now;
            }
        }
    int mi = 1e9,mx = 0;
    for(int i = 1;i <= n;i++)
        mi = min(mi,f1[i][i+n-1]),mx = max(mx,f2[i][i+n-1]);
    cout << mi << endl << mx << endl;
    return 0;
}

四边形不等式优化区间划分 DP

区间划分类的 DP,即将区间 [1,n] 划分成很多段区间 [li,ri],每一段的贡献为 w(li,ri),你需要最小化/最大化每一段的贡献之和。

对于任意划分区间(不限制区间个数),我们将在下文的四边形不等式优化 1D/1D DP 中提到,这里要说的是限制了区间个数恰好为 m 时的做法。

我们设 fk,i 表示将前 i 个数划分为 k 段的答案,则有转移方程:

fk,i=min0j<ifk1,j+w(j+1,i)

和上面的区间 DP 一样,我们有如下的定理:

定理 3:若 w 满足四边形不等式,则有 opt(k1,i)opt(k,i)opt(k,i+1)。(opt 的定义与上面相同)

证明

第二个不等式只是第 k 层的决策单调性。关键在于第一个不等式。

下证 opt(k,i)opt(k+1,i)。假设有如下两个区间 [1,i] 的分划(逆序标号):[ak,dk],,[a1,d1][bk+1,ck+1],,[b1,c1]。这里,每个区间的左端点都是其右端点处对应问题的最小最优决策;同样地,从右向左考虑可能的分划,应该有右端点也是左端点对应问题的最小最优决策。例如,djcj 分别是将 [aj,i][bj,i] 分成 j 段左起第一个区间右端点的最小最优决策。根据决策单调性,如果 aj1>bj1,亦即 dj>cj,那么必然有 aj>cj。由此,如果所证不成立,则有 a1>b1。进而可以归纳地证明 ak>bk。这显然与所设矛盾。由此得证。

第一个不等式可以另证如下。同样考虑上面证明中的两个分划。如果所证命题不成立,则有 a1>b1,但是由于有 ak<bk,我们可以找到最小的 j>1 使得 ajbj。进而,此时有 aj1>bj1,故 dj>cj。我们找到了一组区间满足 ajbjcj<dj。考虑将这两个分拆重新组合的结果。考虑分拆 [bk+1,ck+1],,[bj+1,cj+1],[bj,dj],[aj1,dj1],,[a1,d1],共 (k+1) 段,于是由前设的最优性可推知,

w(bk+1,ck+1)++w(bj+1,cj+1)+w(bj,cj)+w(bj1,cj1)++w(b1,c1)w(bk+1,ck+1)++w(bj+1,cj+1)+w(bj,dj)+w(aj1,dj1)++w(a1,d1).

同样地,考虑分拆 [ak,dk],,[aj+1,dj+1],[aj,cj],[bj1,cj1],,[b1,c1],共 k 段,则有

w(ak,dk)++w(aj+1,dj+1)+w(aj,dj)+w(aj1,dj1)++w(a1,d1)<w(ak,dk)++w(aj+1,dj+1)+w(aj,cj)+w(bj1,cj1)++w(b1,c1).

此时,不等号是严格的,因为 a1>b1,但是按假设,a1 是所有 k 段分拆最末一段的左端点中最小最优的。两个不等式条件相加,得到 w(bj,cj)+w(aj,dj)<w(bj,dj)+w(aj,cj),这有悖于四边形不等式。故而原结论得证。

这样,我们可以像区间 DP 一样在一个区间内枚举决策点。具体实现时,应正序枚举 k,倒序枚举 i,然后在区间 [opt(k1,i),opt(k,i+1)] 内枚举决策点 j。时间复杂度依然为 O(n2)

* 注意:这个时间复杂度是 O(n2) 的,不要记成 O(nm) 了。其实严格来讲复杂度应该写为 O(n(n+m))

例题:P4767 [IOI2000] 邮局 加强版

首先自己手玩一下发现 w(l,r)=w(l,r1)+aral+r2(比如四个数,每个 ai 的贡献为 --++,五个数就是 --0++,依此类推)。所以现在就是要证 w(l,r)+w(l+1,r+1)w(l,r+1)+w(l+1,r),拆开式子抵消得:al+r2al+r+12,这个是显然的,所以我们就证得了 w 满足四边形不等式,用上面的方法做即可。

代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;
const int N = 3005,K = 305;
int w[N][N],f[K][N],g[K][N],a[N],n,k;
inline int rd()
{
    char c;int f = 1;
    while(!isdigit(c = getchar()))if(c=='-')f = -1;
    int x = c-'0';
    while(isdigit(c = getchar()))x = x*10+(c^48);
    return x*f;
}
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    n = rd();k = rd();memset(f,0x3f,sizeof(f));
    for(int i = 1;i <= n;i++)a[i] = rd();
    for(int l = 1;l < n;l++)for(int r = l+1;r <= n;r++)
        w[l][r] = w[l][r-1]+a[r]-a[l+r>>1];
    f[0][0] = 0;
    for(int i = 1;i <= k;i++)for(int j = n;j;j--)
        for(int p = g[i-1][j],up = min(j-1,j==n?n:g[i][j+1]);p <= up;p++)
        {
            int now = f[i-1][p]+w[p+1][j];
            if(now < f[i][j])f[i][j] = now,g[i][j] = p;
        }
    cout << f[k][n] << endl;
    return 0;
}

例题:CF321E Ciel and Gondolas

可以发现 w(l,r) 表示的是 (l,l)(r,r) 的子矩阵的和,很显然 w 是满足四边形不等式的,因为包含比交叉多出两坨东西。设 sumi,j 表示矩阵的前缀和,那么 w 就可以 O(1) 算了。

代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;
const int N = 4005,K = 805;
int s[N][N],f[K][N],g[K][N],n,k;
int w(int j,int i)
{return s[i][i]-s[i][j-1]-s[j-1][i]+s[j-1][j-1]>>1;}
inline int rd()
{
    char c;int f = 1;
    while(!isdigit(c = getchar()))if(c=='-')f = -1;
    int x = c-'0';
    while(isdigit(c = getchar()))x = x*10+(c^48);
    return x*f;
}
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    n = rd();k = rd();memset(f,0x3f,sizeof(f));
    for(int i = 1;i <= n;i++)for(int j = 1;j <= n;j++)
        s[i][j] = s[i][j-1]+s[i-1][j]-s[i-1][j-1]+rd();
    f[0][0] = 0;
    for(int i = 1;i <= k;i++)for(int j = n;j;j--)
        for(int p = g[i-1][j],up = min(j-1,j==n?n:g[i][j+1]);p <= up;p++)
        {
            int now = f[i-1][p]+w(p+1,j);
            if(now < f[i][j])f[i][j] = now,g[i][j] = p;
        }
    cout << f[k][n] << endl;
    return 0;
}

练习:P4072 [SDOI2016] 征途

*** 这个题下文还会提到另一种做法。**

四边形不等式优化 1D/1D DP(分治)

首先还是来看一个问题,现在有 DP 数组 fi,下面是 fi 的转移方式:

fi=min1jiw(j,i)

这个 DP 的朴素做法是 O(n2) 的,但是如果 w(i,j) 满足四边形不等式,就有更优秀的做法。

我们定义 gi 表示 fi 的最优决策点 j,即 fi 是由 j 转移过来的。若 i1<i2,满足 gi1gi2,则称这个 DP 是满足决策单调性的。

定理 4:若 w 满足四边形不等式,则这个 DP 满足决策单调性。

证明

要证明这一点,可采用反证法。假设对某些 c<d,成立 a=opt(d)<opt(c)=b。此时有 a<bc<d。根据最优化条件,w(a,d)w(b,d)w(b,c)<w(a,c),于是,w(a,d)w(b,d)0<w(a,c)w(b,c),这与四边形不等式矛盾。

实现

接下来考虑具体实现,我们要用到一个很重要的思想:分治

首先我们考虑暴力求出 fn2 的决策点 opt(n2),然后,根据决策单调性,对于 1i<n2 的部分,一定有 opt(i)opt(n2);对于 n2<in 的部分,一定有 opt(n2)opt(i),我们就可以分治下去做了。

设分治函数 solve(l,r,L,R) 表示当前要考虑的区间为 [l,r],决策点的范围为 [L,R],每次找到 mid 的决策点 p,然后再 solve(l,mid1,L,p)solve(mid+1,r,p,R) 即可。

代码:

int w(int l,int r);
void solve(int l,int r,int L,int R)
{
    if(l > r)return ;
    int mid = l+r>>1,j = 0,mi = 1e9;
    for(int i = L;i <= min(mid,R);i++)
    {
        ll now = w(i,mid);
        if(now < mi)mi = now,j = i;
    }
    f[mid] = mi;
    solve(l,mid-1,L,j);
    solve(mid+1,r,j,R);
}

对于每一层,决策点范围都是从 1n,所以总时间复杂度就是 O(nlogn)

例题:P3515 [POI2011] Lightning Conductor

给定一个长度为 n 的序列 {an},对于每个 i[1,n] ,求出一个最小的非负整数 p ,使得 j[1,n],都有 ajai+p|ij|

1n5×1050ai109

首先我们考虑正着做一次,将序列翻转再做一次,两次的结果取 max,这样子就可以去掉绝对值的限制。

于是根据题意,有:

p=max1ji{ajai+ij}

所以 w(j,i)=ajai+ij,下面考虑证 w(l,r) 满足决策单调性(令 d=rl,即区间长度):

w(l,r+1)+w(l+1,r)w(l,r)w(l+1,r+1)=rl+1+rl1rlrl=d+1+d1dd=(d+1d)(dd1)0

因为 x 是上凸的(即斜率单调递减,二阶导恒为负),所以 f(x)=xx1 是单调递减的,所以原式恒小于 0

所以有 w(l,r)+w(l+1,r+1)w(l,r+1)+w(l+1,r),因为这里是取 max,而符号又刚好和四边形不等式相反,所以原 DP 是满足决策单调性的。采用决策单调性分治即可。

代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#define ll long long
using namespace std;
const int N = 5e5+5;
int a[N],n;
double f[N],sqr[N];
double w(int j,int i){return a[j]+sqr[i-j];}
void solve(int l,int r,int L,int R)
{
    if(l > r)return ;
    int mid = l+r>>1,j = 0;
    double mx = 0;
    for(int i = L;i <= min(mid,R);i++)
    {
        double now = w(i,mid);
        if(now > mx)mx = now,j = i;
    }
    f[mid] = max(f[mid],mx);
    solve(l,mid-1,L,j);
    solve(mid+1,r,j,R);
}
inline int rd()
{
    char c;int f = 1;
    while(!isdigit(c = getchar()))if(c=='-')f = -1;
    int x = c-'0';
    while(isdigit(c = getchar()))x = x*10+(c^48);
    return x*f;
}
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    n = rd();
    for(int i = 1;i <= n;i++)
        a[i] = rd(),sqr[i] = sqrt(i);
    solve(1,n,1,n);
    reverse(a+1,a+n+1);
    reverse(f+1,f+n+1);
    solve(1,n,1,n);
    for(int i = n;i;i--)
        printf("%d\n",(int)ceil(f[i])-a[i]);
    return 0;
}

练习:P4360 [CEOI2004] 锯木厂选址

例题:P4072 [SDOI2016] 征途

这题可以用枚举决策点区间来做,这样子是 O(n2) 的,下面我们用分治的方法来做这道题。

给定一个长为 n 的序列 a,你需要将 a 划分为 m 段,每段的代价为这一段和的平方,使得总代价最小。

mn3000

fk,i 表示前 i 个数划分成 k 段的代价,sumi 表示 a 的前缀和,那么有:

fk,i=min0j<ifk1,j+(sumisumj)2

于是我们可以做 m 次的 DP,每次都是一个决策单调性分治。设当前是第 k 次,则 w(l,r)=fk1,l+(sumrsuml)2。现在就是要证 w(l,r)+w(l+1,r+1)w(l,r+1)+w(l+1,r),把式子展开后消一下项,最后就是 2alar+10,这个是显然的,因为 ai0

然后就可做了。时间复杂度为 O(mnlogn)

代码:作者没写这个做法((

从这题可以看出,对于区间划分类的问题,如果是固定了区间个数,且 w 满足决策单调性,那么都可以有两种写法,一种是枚举决策点,时间复杂度 O(n2),另一种是做 m 次 DP,每次分治来做,时间复杂度 O(mnlogn)。一般来说,如果 n,m 同阶,就用第一种做法;如果 m 远小于 n,那么就用第二种做法。

再看几道不一样的例题:

例题:CF868F Yet Another Minimization Problem

给定一个长为 n 的序列 a,你需要将 a 划分为 m 段,每段的代价是其中相同元素的对数,使得总代价最小。

n105,mmin(20,n)

首先看到 m 很小,就想着用 m 次决策单调性分治,而且易证函数 w 是满足四边形不等式的,但是这道题的难点在于无法快速求出 w(l,r) 的值。

考虑一个类似于莫队的思路。我们维护两个端点 l,r 和一个当前的 w(l,r),然后每次查询 w(l,r) 时就像莫队一样,将两个端点一位一位地平移到要求地区间。这个做法看起来很暴力,但是我们来仔细分析一下复杂度:

考虑分治的过程:左右端点先从父亲区间移到左儿子,再从左儿子区间移到右儿子区间。显然对于每一层,两个端点的移动次数是 O(n) 的,那么总的移动次数是 O(nlogn) 的,所以这个做法的时间复杂度依然是 O(mnlogn) 的。

代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;
const int N = 1e5+5;
int a[N],c[N],n,k,l = 1,r;
ll f[N],g[N],sum;
void add(int x,int v)
{sum += ~v?c[x]:-c[x]+1;c[x] += v;}
ll w(int L,int R)
{
    while(l > L)add(a[--l],1);
    while(r < R)add(a[++r],1);
    while(l < L)add(a[l++],-1);
    while(r > R)add(a[r--],-1);
    return g[L-1]+sum;
}
void solve(int l,int r,int L,int R)
{
    if(l > r)return ;
    int mid = l+r>>1,j = 0;
    ll mi = 1e10;
    for(int i = L;i <= min(mid,R);i++)
    {
        ll now = w(i,mid);
        if(now < mi)mi = now,j = i;
    }
    f[mid] = mi;
    solve(l,mid-1,L,j);
    solve(mid+1,r,j,R);
}
inline int rd()
{
    char c;int f = 1;
    while(!isdigit(c = getchar()))if(c=='-')f = -1;
    int x = c-'0';
    while(isdigit(c = getchar()))x = x*10+(c^48);
    return x*f;
}
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    n = rd();k = rd()-1;
    for(int i = 1;i <= n;i++)a[i] = rd(),f[i] = w(1,i);
    while(k--)
    {
        for(int i = 1;i <= n;i++)
            g[i] = f[i],c[i] = 0;
        l = 1;r = sum = 0;
        solve(1,n,1,n);
    }
    cout << f[n];
    return 0;
}

例题:P5574 [CmdOI2019] 任务分配问题

给定一个长为 n 的序列 a,你需要将 a 划分为 m 段,每段的代价是其中 i<j,ai<aj 的对数,使得总代价最小。

n2.5×104,mmin(25,n)

跟上题一模一样的思路,做莫队的时候用树状数组维护即可,时间复杂度为 O(mnlog2n)

代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;
const int N = 25005;
int a[N],n,k,l = 1,r;
int t[N],f[N],g[N],sum;
void up(int i,int v){for(;i <= n;i += i&-i)t[i] += v;}
int get(int i)
{int s = 0;for(;i;i -= i&-i)s += t[i];return s;}
void add(int x,int v,bool tp)
{
    int now = tp?get(x-1):get(n)-get(x);
    sum += v*now;up(x,v);
}
ll w(int L,int R)
{
    while(l > L)add(a[--l],1,0);
    while(r < R)add(a[++r],1,1);
    while(l < L)add(a[l++],-1,0);
    while(r > R)add(a[r--],-1,1);
    return g[L-1]+sum;
}
void solve(int l,int r,int L,int R)
{
    if(l > r)return ;
    int mid = l+r>>1,j = 0,mi = 1e9;
    for(int i = L;i <= min(mid-1,R);i++)
    {
        int now = w(i+1,mid);
        if(now < mi)mi = now,j = i;
    }
    f[mid] = mi;
    solve(l,mid-1,L,j);
    solve(mid+1,r,j,R);
}
inline int rd()
{
    char c;int f = 1;
    while(!isdigit(c = getchar()))if(c=='-')f = -1;
    int x = c-'0';
    while(isdigit(c = getchar()))x = x*10+(c^48);
    return x*f;
}
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    n = rd();k = rd()-1;
    for(int i = 1;i <= n;i++)a[i] = rd(),f[i] = w(1,i);
    while(k--)
    {
        for(int i = 1;i <= n;i++)g[i] = f[i],t[i] = 0;
        l = 1;r = sum = 0;
        solve(1,n,1,n);
    }
    cout << f[n];
    return 0;
}

练习:P10861 [HBCPC2024] MACARON Likes Happy Endings

代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;
const int N = 1e5+5,M = 2e6+5;
int a[N],c[M],n,k,l = 1,r,d;
ll f[N],g[N],sum;
void add(int x,int v)
{sum += ~v?c[x^d]:-c[x^d]+!d;c[x] += v;}
ll w(int L,int R)
{
    L--;
    while(l > L)add(a[--l],1);
    while(r < R)add(a[++r],1);
    while(l < L)add(a[l++],-1);
    while(r > R)add(a[r--],-1);
    return g[L]+sum;
}
void solve(int l,int r,int L,int R)
{
    if(l > r)return ;
    int mid = l+r>>1,j = 0;
    ll mi = 1e10;
    for(int i = L;i <= min(mid,R);i++)
    {
        ll now = w(i,mid);
        if(now < mi)mi = now,j = i;
    }
    f[mid] = mi;
    solve(l,mid-1,L,j);
    solve(mid+1,r,j,R);
}
inline int rd()
{
    char c;int f = 1;
    while(!isdigit(c = getchar()))if(c=='-')f = -1;
    int x = c-'0';
    while(isdigit(c = getchar()))x = x*10+(c^48);
    return x*f;
}
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    n = rd();k = rd()-1;d = rd();
    for(int i = 1;i <= n;i++)a[i] = a[i-1]^rd(),f[i] = w(1,i);
    while(k--)
    {
        for(int i = 1;i <= n;i++)g[i] = f[i];
        for(int i = 0;i < M;i++)c[i] = 0;
        l = 1;r = sum = 0;
        solve(1,n,1,n);
    }
    cout << f[n];
    return 0;
}

四边形不等式优化 1D/1D DP(分治)

还是来看一个问题,有 DP 数组 fi,下面是 fi 的转移方式:

fi=min0j<i{fj+w(j,i)}

这个 DP 是半在线的,即你不能用分治去做,因为你在求 mid 的决策点时还需要求出前面的 fi,所以是不可做的。在这之前,我们还是要说明这个 DP 的性质:

定理 5:若 w 满足四边形不等式,则这个 DP 满足决策单调性。

证明

pifi 的最优决策点,那么根据定义,0j<pi,满足:

fpi+w(pi,i)fj+w(j,i)(1)

i<in,因为 j<pi<i<i,根据四边形不等式,有:

w(j,i)+w(pi,i)w(j,i)+w(pi,i)(2)

(1),(2) 式相加可得:

fpi+w(pi,i)fj+w(j,i)

所以对于 ipi 之前的决策点都没有 pi 优,所以 pipi,所以有决策单调性。

这时候就可以用二分队列的做法了。

实现

因为决策单调性,我们发现序列中所有决策点为 ji 构成了一段区间,我们考虑维护所有的区间。

现在有一个队列,队列种的每个元素都是一个三元组 [i,l,r],表示 i 可以作为 [l,r] 的决策点。下面是具体流程:

  • 首先向队列加入 [0,1,n],表示当前 0 可以作为所有 fi 的决策点。

对于每个 i,执行以下几个步骤:

  • 如果队头的 r<i,那么就弹出队头,因为队头肯定不再有贡献了(这里最多弹出一个,用 if 判断即可)。
  • 用队头的决策点计算当前的 fi
  • 计算 i 可能成为哪个区间的决策点,首先判断 i 是否比整个队尾的决策点优,即判断是否有 w(i,q[r].l)<=w(q[r].i,q[r].l),如果是,就弹出队尾,一直重复直到不满足条件。
  • 此时队尾的的区间 [l,r] 中一部分的决策点会变成当前的 i,而根据决策点调性,一定存在一个位置 pos,满足 [l,pos) 的决策点都是队尾,[pos,n] 的决策点是 i,那么二分这个 pos 即可,然后将队尾的 r 改为 pos1
  • 如果 posn 的,说明 i 可能会成为后面的决策点,于是将 [i,pos,n] 加入到队尾。

时间复杂度为 O(nlogn)

代码:

int l = 1,r = 1;q[1] = {0,1,n};
for(int i = 1;i <= n;i++)
{
    if(q[l].r < i)l++;
    int j = q[l].i;
    f[i] = w(j,i);g[i] = j;
    while(l <= r&&w(i,q[r].l) <= w(q[r].i,q[r].l))r--;
    int nl = q[r].l,nr = n+1;
    while(nl < nr)
    {
        int mid = nl+nr>>1;
        if(w(i,mid) <= w(q[r].i,mid))nr = mid;
        else nl = mid+1; 
    }
    q[r].r = nl-1;
    if(nl <= n)q[++r] = {i,nl,n};
}

例题:P1912 [NOI2009] 诗人小G

fi 为 DP 数组,leni 表示每个串的长度,sileni+1 的前缀和,则有:

fi=min0j<i{fj+|sisj1L|P}

则有 w(j,i)=|sisj1L|P打表可得发现 w 满足四边形不等式,具体证明需要大力分讨,这里不过多说明。于是直接套用上面提到的二分队列即可。

代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#define ll long double
using namespace std;
const int N = 1e5+5;
int g[N],s[N],n,L,p;
char str[N][35];
ll f[N];
struct node{int i,l,r;}q[N];
ll qp(ll x,int y)
{
    ll ans = 1;
    for(;y;y >>= 1,x = x*x)
        if(y&1)ans *= x;
    return ans;
}
ll w(int j,int i){return f[j]+qp(abs(s[i]-s[j]-1-L),p);}
void pri(int i,char c)
{
    int n = strlen(str[i]);
    for(int j = 0;j < n;j++)putchar(str[i][j]);
    putchar(c);
}
inline int rd()
{
    char c;int f = 1;
    while(!isdigit(c = getchar()))if(c=='-')f = -1;
    int x = c-'0';
    while(isdigit(c = getchar()))x = x*10+(c^48);
    return x*f;
}
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    for(int t = rd();t--;puts("--------------------"))
    {
        n = rd();L = rd();p = rd();
        for(int i = 1;i <= n;i++)
            scanf("%s",str[i]),s[i] = s[i-1]+strlen(str[i])+1;
        int l = 1,r = 1;q[1] = {0,1,n};
        for(int i = 1;i <= n;i++)
        {
            if(q[l].r < i)l++;
            int j = q[l].i;
            f[i] = w(j,i);g[i] = j;
            while(l <= r&&w(i,q[r].l) <= w(q[r].i,q[r].l))r--;
            int nl = q[r].l,nr = n+1;
            while(nl < nr)
            {
                int mid = nl+nr>>1;
                if(w(i,mid) <= w(q[r].i,mid))nr = mid;
                else nl = mid+1; 
            }
            q[r].r = nl-1;
            if(nl <= n)q[++r] = {i,nl,n};
        }
        if(f[n] > 1e18){puts("Too hard to arrange");continue;}
        int tot = 0;
        for(int i = n;i;i = g[i])s[++tot] = i;
        s[tot+1] = 0;
        printf("%lld\n",(long long)f[n]);
        for(int i = tot;i;i--)
            for(int j = s[i+1]+1;j <= s[i];j++)
                pri(j,j==s[i]?'\n':' ');
    }
    return 0;
}

决策单调性的万能做法:李超线段树

看这里之前你需要先学会李超线段树

如果你觉得决策单调性的各种做法,包括分治,二分队列,二分栈等太杂乱,那么可以考虑使用李超线段树。因为只要是涉及决策单调性的题,如果你能快速求贡献,那就可以使用李超线段树。下面来讲解具体做法:

回想一下李超线段树的定义:有很多条直线,求所有直线在 x 点的最大取值,具体做法是用标记永久化的思想,线段树上每个区间有一条直线,表示这条直线可能成为这个区间的答案。每次新来一条直线时,判断与原来的直线在 l,mid,r 上的取值,然后递归。因为每次递归只可能是单侧递归,所以时间复杂度是 O(nlogn)

考虑怎么把李超线段树用到决策单调性上:现在有很多的决策点,求所有决策点转移到 i 的最小值,那么线段树上每个区间就维护一个数 j,表示 j 可能作为这个区间的决策点。

如果新来一个决策点 k,就是判断分别以 j,k 作为决策点,fl,fmid,fr 的取值。不妨设 j 作为 mid 的决策点比 k 优,那么如果 k 作为 l 的决策点比 j 优,就递归左区间;同理,如果 k 作为 r 的决策点比 j 优,就递归右区间。可以发现,不可能出现 j 作为 mid 的决策点比 k 优,k 作为 l,r 的决策点都比 j 优,这样子就不满足决策单调性了,每次递归都是单侧递归,时间复杂度依然为 O(nlogn),但是这个做法的常数会比较大。

代码:

int t[N << 2];
ll w(int l,int r);
bool cmp(int x,int f,int g)
{
    ll yf = w(f,x),yg = w(g,x);
    return yf != yg?yf < yg:f > g;
}
int minn(int x,int f,int g){return cmp(x,f,g)?f:g;}
void pushtag(int x,int l,int r,int f)
{
    int mid = l+r>>1;
    if(cmp(mid,f,t[x]))swap(f,t[x]);
    if(cmp(l,f,t[x]))pushtag(lson,f);
    else if(cmp(r,f,t[x]))pushtag(rson,f);
}
int query(int x,int l,int r,int i)
{
    if(l == r)return t[x];
    int mid = l+r>>1;
    return minn(i,t[x],i<=mid?query(lson,i):query(rson,i));
}

例题将在后面提到。

斜率优化

咕了,看这篇博客吧。

WQS 二分

学这之前你需要先会斜率优化。

来看这样一道题:

给定一个长为 n 的序列 a,你需要将 a 划分为 m 段,每段的代价为这一段和的平方,使得总代价最小。

mn2×105

P4072 [SDOI2016] 征途的加强版。

这道题用之前说的分治或者是枚举决策点范围的方法都不可做,主要原因是这道题有个限制 m,我们有没有方法能去掉这个限制 m?这时候就要用到 wqs 二分降维了。

wqs 二分用途

wqs 二分一般用于以下的题目:

  • 将一个序列划分成恰好 m,每段有一个代价 w(l,r),求最小的总代价。
  • 如果没有 m 的限制,一般可以 O(n) 或者 O(nlogn) 去做。
  • 设将序列恰好分为 k 段时的答案为 g(i),那么 (i,g(i)) 拟合出的图形是一个凸包。

至于第三点如何去判断,一般是打表,或者大胆猜测,或者使用以下的定理:

定理 6:如果 w 满足四边形不等式,则 g(k) 是一个凸函数。

证明

下证 g(k1)+g(k+1)2g(k)。为此,考虑长度为 (k1) 段和 (k+1) 段的最优分划,分别是 [a1,d1],,[ak1,dk1][b1,c1],,[bk+1,ck+1]。取最小的 1jk1 使得 cj+1dj,其存在性可由 ck<n=dk1 推知。根据其最小性得知,bj+1>aj。所以,aj<bj+1cj+1dj。与上文类似,交换两个现有分拆的后半段,可以得到如下两个区间分拆:

[a1,d1],,[aj1,dj1],[aj,cj+1],[bj+2,cj+2],,[bk+1,ck+1],[b1,c1],,[bj,cj],[bj+1,dj],[aj+1,dj+1],,[ak1,dk1].

两个所得区间都是 k 段的,所以由最优性条件可知

2g(k)w(a1,d1)++w(aj1,dj1)+w(aj,cj+1)+w(bj+2,cj+2)++w(bk+1,ck+1)+w(b1,c1)++w(bj,cj)+w(bj+1,dj)+w(aj+1,dj+1)++w(ak1,dk1)w(a1,d1)++w(aj1,dj1)+w(aj,dj)+w(aj+1,dj+1)++w(ak1,dk1)+w(b1,c1)++w(bj,cj)+w(bj+1,cj+1)+w(bj+2,cj+2)++w(bk+1,ck+1)=g(k1)+g(k+1).

这里第二个不等式正是四边形不等式。所求凸性由此得证。

wqs 二分做法

现在假设有一个 (i,g(i)) 构成的上凸包,我们就是要求 i=m 时的答案。但问题是我们不能很快求出某个点g 的值,也就是说这个凸包的形状是求不出来的,我们只知道它的形状是一个上凸包。

如图(盗一张图):

在这里插入图片描述

既然我们无法求出凸包上某个点的值,我们就考虑用一条斜率为 k直线来切这个凸包。

此时我们可以得到一个值 x,表示这条直线切到了横坐标为 x 的点上(比如两条黑色的线的 x=12):

在这里插入图片描述

于是就可以二分斜率,如果此时切到的点 x<m,即斜率大了,就将斜率调小,反之就将斜率调大,直到切到的点为 m,此时就是答案。

现在考虑怎么计算斜率为 k 的直线会切到哪一个点:

我们假设这条直线经过某个点的直线的表达式为:y=kx+b,那么这条直线切到的点,一定是所有点中截距 b 最大的那一个:

在这里插入图片描述

因为 b=ykx,于是我们就是要找一个 i 使得 g(i)i×k 最大。而 g(i) 的定义为恰好选 i 段,总代价最小。现在减了 i×k,就相当于给每一段的代价减 k,也就是每多选一段,总代价就减 k,同时记录当前选了多少段,最后切到的点 i 就是选了多少段。这个 DP 没有了 m 的限制,于是就好做了。

综上所述,wqs 二分的流程为:

  • 二分一个斜率
  • 计算这个斜率的直线会切到哪一个点。
  • 判断这个点与 m 的关系,并调整斜率。
  • 最后进行一次 check(l),此时会得到斜率为 l 切到的点的截距,应输出 l×m+fn

代码(主函数):

ll l = -1e7,r = 1e7;//具体取决于题目中可能的最大和最小的斜率。
while(l < r)
{
    ll mid = l+r>>1;
    if(check(mid) >= m)r = mid;
    else l = mid+1;
}
check(l);
cout << l*m+f[n] << endl;

wqs 二分的两点注意事项

如果你仔细阅读了上面的过程,你可能会发现几点问题:

  • 最后斜率为 l 的点切到的可能不只是 m,最后返回的可能是切到的另外一个点,可是为什么要输出 l×m+fn
  • 在 check(mid) 时,如果有多个点满足要求,我应该返回哪一个点。

对于第一个问题:

在这里插入图片描述

此时假设 m=3,但算这条直线时 3,4,5,6 都可以是被切到的点。也就是说,一条直线可能会切到一个范围内的点。但此时你会发现,check 这条直线算出的最大截距都是相同的,都是 5。所以你不用管最后 check(l) 切到的点是哪一个,只需要关心这个截距是多少,即 fn。此时我们直接假设切到的点就是 m,那么答案就是 l×m+fn

对于第二种情况,一种做法是用小数二分,但是这很可能会 TLE,其实直接用整数也是可以的。

我们考虑在 check 一条直线时,如果有很多点都满足答案,不要随便返回一个点,比如可以返回最靠右的点。

现在假设我们 check 都返回最靠右的点,即在 check 中如果有多种方案都满足要求,那么选的越多越好。

此时如果 check(mid)m,那么将 r=mid;否则的话,因为连最右边的点都要比 m 小,那么这个点之前肯定都不满足了,所以 l=mid+1

但是如果换一种写法,想想会有什么样的结果,比如 check 返回的是所有点中最左边的点,而二分的部分不变。

首先如果 check(mid)m,则 r=mid,这个是可以的。而如果 check(mid)<m,则 l=mid+1,这一部分就会有问题。

我们假设斜率为 mid 的直线切到了一些点,而 m 正好就在其中,但是此时 check(mid) 返回的是最左边的点,所以有 check(mid)<m,而此时二分的写法 l=mid+1,相当于直接排除掉了斜率为 mid 的直线,也就排除掉了答案。现在再看一下原来的做法,你会发现原来的做法就不存在这样的情况。

也就是说,要么 check 都返回最靠右的点,然后 chech(mid)<m 时有 l=mid+1,要么反过来,只有这两种写法(建议每次都固定一个写法,比如 check 都钦定返回最靠右的点,养成习惯)。同时,也要注意 check 里面各种地方要不要取等,因为你要钦定选最多或最少。

例题:P4983 忘情

这个就直接是征途的加强版,稍微推一下可得每段代价为这一段的和加一的平方。于是直接用 wqs 二分即可,check 的部分使用斜率优化。

代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;
const int N = 1e5+5;
int q[N],g[N],n,m;
ll s[N],f[N];
ll sq(ll x){return x*x;}
ll Y(int j){return f[j]+sq(s[j]);}
ll k(int i){return 2*(s[i]+1);}
ll X(int j){return s[j];}
double slope(int i,int j){return (Y(i)-Y(j))*1.0/(X(i)-X(j));}
int check(ll x)
{
    int l,r;q[l = r = 1] = 0;
    for(int i = 1;i <= n;i++)
    {
        while(l < r&&slope(q[l],q[l+1]) <= k(i))l++;
        int j = q[l];g[i] = g[j]+1;
        f[i] = f[j]+sq(s[i]-s[j]+1)-x;
        while(l < r&&slope(q[r],i) <= slope(q[r-1],q[r]))r--;
        q[++r] = i;
    }
    return g[n];
}
inline int rd()
{
    char c;int f = 1;
    while(!isdigit(c = getchar()))if(c=='-')f = -1;
    int x = c-'0';
    while(isdigit(c = getchar()))x = x*10+(c^48);
    return x*f;
}
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    n = rd();m = rd();
    for(int i = 1;i <= n;i++)s[i] = s[i-1]+rd();
    ll l = -sq(s[n]+1),r = 0;
    while(l < r)
    {
        ll mid = l+r+1>>1;
        if(check(mid) <= m)l = mid;
        else r = mid-1;
    }
    check(l);
    cout << l*m+f[n];
    return 0;
}

如果你的二分部分写的是上面说的错误的做法,你就会 WA on #10,90分的记录

练习:P4072 [SDOI2016] 征途

在本篇中,针对于这种区间划分的题就已经有了 3 种做法了,分别是分治、枚举决策点范围、wqs 二分,三种做法的复杂度分别为 O(mnlogn)O(n2)O(nlogC)C 是二分的斜率范围)。

代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;
const int N = 1e5+5;
int q[N],g[N],n,m;
ll s[N],f[N];
ll sq(ll x){return x*x;}
ll Y(int j){return f[j]+sq(s[j]);}
ll k(int i){return 2*s[i];}
ll X(int j){return s[j];}
double slope(int i,int j){return (Y(i)-Y(j))*1.0/(X(i)-X(j));}
int check(ll x)
{
    int l,r;q[l = r = 1] = 0;
    for(int i = 1;i <= n;i++)
    {
        while(l < r&&slope(q[l],q[l+1]) <= k(i))l++;
        int j = q[l];g[i] = g[j]+1;
        f[i] = f[j]+sq(s[i]-s[j])-x;
        while(l < r&&slope(q[r],i) <= slope(q[r-1],q[r]))r--;
        q[++r] = i;
    }
    return g[n];
}
inline int rd()
{
    char c;int f = 1;
    while(!isdigit(c = getchar()))if(c=='-')f = -1;
    int x = c-'0';
    while(isdigit(c = getchar()))x = x*10+(c^48);
    return x*f;
}
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    n = rd();m = rd();
    for(int i = 1;i <= n;i++)s[i] = s[i-1]+rd();
    ll l = -sq(s[n]),r = 0;
    while(l < r)
    {
        ll mid = l+r>>1;
        if(check(mid) >= m)r = mid;
        else l = mid+1;
    }
    check(l);
    cout << (l*m+f[n])*m-sq(s[n]);
    return 0;
}

例题:P6246 [IOI2000] 邮局 加强版 加强版

siai 的前缀和,手推一下有 w(l,r)=(srsl+r+12)(sl+r2sl),也是 wqs 二分即可,check 部分可以使用二分队列来完成(当然李超线段树也可以),下面附上两种做法的代码。

代码(二分队列)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;
const int N = 5e5+5;
int g[N],n,m;
ll s[N],f[N],k;
struct node{int i,l,r;}q[N];
ll w(int l,int r){return f[l]+s[r]-s[l+r+1>>1]-s[l+r>>1]+s[l]-k;}
int check(ll x)
{
    int l = 1,r = 1;k = x;
    q[1] = {0,1,n};
    for(int i = 1;i <= n;i++)
    {
        if(q[l].r < i)l++;
        int j = q[l].i;
        f[i] = w(j,i);g[i] = g[j]+1;
        while(l <= r&&w(i,q[r].l) <= w(q[r].i,q[r].l))r--;
        int nl = q[r].l,nr = n+1;
        while(nl < nr)
        {
            int mid = nl+nr>>1;
            if(w(i,mid) <= w(q[r].i,mid))nr = mid;
            else nl = mid+1;
        }
        q[r].r = nl-1;
        if(nl <= n)q[++r] = {i,nl,n};
    }
    return g[n];
}
inline int rd()
{
    char c;int f = 1;
    while(!isdigit(c = getchar()))if(c=='-')f = -1;
    int x = c-'0';
    while(isdigit(c = getchar()))x = x*10+(c^48);
    return x*f;
}
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    n = rd();m = rd();
    for(int i = 1;i <= n;i++)s[i] = s[i-1]+rd();
    ll l = -1e7,r = 0;
    while(l < r)
    {
        ll mid = l+r>>1;
        if(check(mid) >= m)r = mid;
        else l = mid+1;
    }
    check(l);
    cout << l*m+f[n];
    return 0;
}
代码(李超线段树)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
#define lson x<<1,l,mid
#define rson x<<1|1,mid+1,r
using namespace std;
const int N = 5e5+5;
int t[N << 2],g[N],n,m;
ll s[N],f[N],k;
ll w(int l,int r){return f[l]+s[r]-s[l+r+1>>1]-s[l+r>>1]+s[l];}
bool cmp(int x,int f,int g)
{
    ll yf = w(f,x),yg = w(g,x);
    return yf != yg?yf < yg:f > g;
}
int minn(int x,int f,int g){return cmp(x,f,g)?f:g;}
void pushtag(int x,int l,int r,int f)
{
    int mid = l+r>>1;
    if(cmp(mid,f,t[x]))swap(f,t[x]);
    if(cmp(l,f,t[x]))pushtag(lson,f);
    else if(cmp(r,f,t[x]))pushtag(rson,f);
}
int query(int x,int l,int r,int i)
{
    if(l == r)return t[x];
    int mid = l+r>>1;
    return minn(i,t[x],i<=mid?query(lson,i):query(rson,i));
}
int check(ll x)
{
    k = x;
    for(int i = 1;i <= n*4;i++)t[i] = 0;
    for(int i = 1;i <= n;i++)
    {
        int j = query(1,1,n,i);
        f[i] = w(j,i)-k;g[i] = g[j]+1;
        pushtag(1,1,n,i);
    }
    return g[n];
}
inline int rd()
{
    char c;int f = 1;
    while(!isdigit(c = getchar()))if(c=='-')f = -1;
    int x = c-'0';
    while(isdigit(c = getchar()))x = x*10+(c^48);
    return x*f;
}
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    n = rd();m = rd();
    for(int i = 1;i <= n;i++)s[i] = s[i-1]+rd();
    ll l = -1e7,r = 0;
    while(l < r)
    {
        ll mid = l+r>>1;
        if(check(mid) >= m)r = mid;
        else l = mid+1;
    }
    check(l);
    cout << l*m+f[n] << endl;
    return 0;
}

练习:P5308 [COCI2018-2019#4] Akvizna

代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;
const int N = 1e5+5;
int q[N],g[N],n,m;
double f[N];
double Y(int j){return f[j];}
double k(int i){return 1.0/i;}
int X(int j){return j;}
double slope(int i,int j){return (Y(i)-Y(j))*1.0/(X(i)-X(j));}
int check(double x)
{
    int l,r;q[l = r = 1] = 0;
    for(int i = 1;i <= n;i++)
    {
        while(l < r&&slope(q[l],q[l+1]) >= k(i))l++;
        int j = q[l];g[i] = g[j]+1;
        f[i] = f[j]+1-j*1.0/i-x;
        while(l < r&&slope(q[r],i) >= slope(q[r-1],q[r]))r--;
        q[++r] = i;
    }
    return g[n];
}
inline int rd()
{
    char c;int f = 1;
    while(!isdigit(c = getchar()))if(c=='-')f = -1;
    int x = c-'0';
    while(isdigit(c = getchar()))x = x*10+(c^48);
    return x*f;
}
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    n = rd();m = rd();
    double l = 0,r = 1.2e6;
    for(int i = 1;i <= 100;i++)
    {
        double mid = (l+r)/2;
        if(check(mid) >= m)l = mid;
        else r = mid;
    }
    check(l);
    printf("%.9lf\n",l*m+f[n]);
    return 0;
}

练习:CF739E Gosha is hunting

例题:P5633 最小度限制生成树

给你一个有 n 个节点,m 条边的带权无向图,你需要求得一个生成树,使边权总和最小,且满足编号为 s 的节点正好连了 k 条边。

n5×104,m5×105

首先编号为 s 的点每多连一条边,那么这个增长量是越来越小的,所以满足答案是一个下凸包。我们可以把 wqs 二分推广到一般问题上,假设现在是二分斜率 mid,那么在 check(mid) 前,先将所有与点 s 相连的边的权值减 k,然后做 kruskal,最后返回选了多少条与点 s 相连的边。

还有一点就是因为我们要使切到的点尽量靠右,即与 s 相连的边越多越好,所以如果两条边的边权相同,应该优先选与 s 相连的边。

注意,如果直接排序,复杂度是 O(mlogmlogC) 的,但是我们发现每次都是修改与 s 相连的边的权值,于是可以先将与 s 相连的边排一遍序,其它的边排一遍序,每次将所有与 s 相连的边的权值减 mid,然后两部分归并排序即可,复杂度为 O(mlogm+mlogC)

但是作者懒得写归并排序的做法了,直接写的两个 log 的做法,最终 999ms 卡过去了。

代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;
const int N = 5e4+5,M = 5e5+5;
int f[N],n,m,s,k;
ll ans;
int fd(int x){return x==f[x]?x:f[x] = fd(f[x]);}
struct node{int u,v;ll w;}e[M];
bool cmp(node x,node y)
{return x.w != y.w?x.w < y.w:x.u == s&&y.u != s;}
void up(int val)
{for(int i = 1;i <= m;i++)if(e[i].u == s)e[i].w += val;}
int check(int k)
{
    for(int i = 1;i <= n;i++)f[i] = i;
    up(-k);sort(e+1,e+m+1,cmp);
    int cnt = 0;ans = 0;
    for(int i = 1;i <= m;i++)
    {
        int u = fd(e[i].u),v = fd(e[i].v);
        if(u != v){f[u] = v;cnt += e[i].u==s;ans += e[i].w;}
    }
    up(k);
    return cnt;
}
inline int rd()
{
    char c;int f = 1;
    while(!isdigit(c = getchar()))if(c=='-')f = -1;
    int x = c-'0';
    while(isdigit(c = getchar()))x = x*10+(c^48);
    return x*f;
}
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    n = rd();m = rd();s = rd();k = rd();
    for(int i = 1;i <= n;i++)f[i] = i;
    for(int i = 1;i <= m;i++)
    {
        e[i] = {rd(),rd(),rd()};
        int u = fd(e[i].u),v = fd(e[i].v);
        if(u != v)ans++,f[u] = v;
        if(e[i].v == s)swap(e[i].u,e[i].v);
    }
    int l = -1e9,r = 1e9;
    if(ans != n-1||!(check(l) <= k&&k <= check(r)))
        return puts("Impossible"),0;
    while(l < r)
    {
        int mid = l+r>>1;
        if(check(mid) >= k)r = mid;
        else l = mid+1;
    }
    check(l);
    cout << l*k+ans << endl;
    return 0;
}

练习:P2619 [国家集训队] Tree I

代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;
const int N = 5e4+5,M = 1e5+5;
int f[N],n,m,k;
ll ans;
int fd(int x){return x==f[x]?x:f[x] = fd(f[x]);}
struct node{int u,v;ll w;int c;}e[M];
bool cmp(node x,node y)
{return x.w != y.w?x.w < y.w:x.c < y.c;}
void up(int val)
{for(int i = 1;i <= m;i++)if(!e[i].c)e[i].w += val;}
int check(int k)
{
    for(int i = 1;i <= n;i++)f[i] = i;
    up(-k);sort(e+1,e+m+1,cmp);
    int cnt = 0;ans = 0;
    for(int i = 1;i <= m;i++)
    {
        int u = fd(e[i].u),v = fd(e[i].v);
        if(u != v){f[u] = v;cnt += !e[i].c;ans += e[i].w;}
    }
    up(k);
    return cnt;
}
inline int rd()
{
    char c;int f = 1;
    while(!isdigit(c = getchar()))if(c=='-')f = -1;
    int x = c-'0';
    while(isdigit(c = getchar()))x = x*10+(c^48);
    return x*f;
}
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    n = rd();m = rd();k = rd();
    for(int i = 1;i <= n;i++)f[i] = i;
    for(int i = 1;i <= m;i++)
        e[i] = {rd()+1,rd()+1,rd(),rd()};
    int l = -5e6,r = 5e6;
    while(l < r)
    {
        int mid = l+r>>1;
        if(check(mid) >= k)r = mid;
        else l = mid+1;
    }
    check(l);
    cout << l*k+ans << endl;
    return 0;
}

* 注意:你写的 cmp 函数必须要满足 cmp(x,x)=0,即自己与自己比较返回 0,不然 sort 会 RE。

例题:[P4383 八省联考 2018] 林克卡特树

转化一下题意,相当于在树上选择 k+1 条链,使得和最大。

首先,每多选一条链,答案的增长量肯定是越来越小的,所以答案是上凸的。

现在题目就是,在树上任意选链,每选一条链答案就会减一个权值,求最大的和。

我们设 DP 数组 fu,0/1/2 表示当前当前点是 u,并且 u 的度数为 0/1/2 的答案,然后分别转移即可。

注意,因为要在 check 时选尽量多的点,所以你可以用一个结构体来存储答案,一结构体内存答案和个数,然后重载小于号,加法。

代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;
const int N = 3e5+5;
int hd[N],cnt,n,k;
struct edge{int to,nex;ll w;}e[N << 1];
void add(int u,int v,int w)
{e[++cnt] = {v,hd[u],w};hd[u] = cnt;}
struct node
{
    ll v;int c;
    friend bool operator < (node x,node y)
    {return x.v != y.v?x.v < y.v:x.c < y.c;}
    friend node operator + (node x,node y)
    {return {x.v+y.v,x.c+y.c};}
    friend node operator + (node x,ll y)
    {return {x.v+y,x.c};}
}f[N][3],tmp;
void dfs(int u,int fa)
{
    f[u][0] = f[u][1] = f[u][2] = {0,0};
    f[u][2] = max(f[u][2],tmp);
    for(int i = hd[u],v;i;i = e[i].nex)
    {
        if((v = e[i].to) == fa)continue;
        dfs(v,u);ll w = e[i].w;
        f[u][2] = max(f[u][2]+f[v][0],f[u][1]+f[v][1]+w+tmp);
        f[u][1] = max(f[u][1]+f[v][0],f[u][0]+f[v][1]+w);
        f[u][0] = f[u][0]+f[v][0];
    }
    f[u][0] = max({f[u][0],f[u][1]+tmp,f[u][2]});
}
int check(ll x){return tmp = {-x,1},dfs(1,0),f[1][0].c;}
inline int rd()
{
    char c;int f = 1;
    while(!isdigit(c = getchar()))if(c=='-')f = -1;
    int x = c-'0';
    while(isdigit(c = getchar()))x = x*10+(c^48);
    return x*f;
}
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    n = rd();k = rd()+1;
    for(int i = 1;i < n;i++)
    {
        int u = rd(),v = rd(),w = rd();
        add(u,v,w);add(v,u,w);
    }
    check(0);
    ll l = -1e12,r = 1e12;
    while(l < r)
    {
        ll mid = l+r+1>>1;
        if(check(mid) >= k)l = mid;
        else r = mid-1;
    }
    check(l);
    cout << l*k+f[1][0].v;
    return 0;
}

练习:CF802O April Fools' Problem

wqs二分+反悔贪心

代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
#define ll long long
using namespace std;
const int N = 5e5+5;
int a[N],b[N],n,m;
ll sum;
struct node
{
    ll v;bool tp;
    friend bool operator < (node x,node y)
    {return x.v != y.v?x.v > y.v:x.tp < y.tp;}
};priority_queue<node> q;
int check(ll x)
{
    int cnt = 0;sum = 0;
    while(!q.empty())q.pop();
    for(int i = 1;i <= n;i++)
    {
        q.push({a[i]-x,1});
        node now = q.top();ll s = now.v+b[i];
        if(s < 0)
        {
            sum += s;cnt += now.tp;
            q.pop();q.push({-b[i],0});
        }
    }
    return cnt;
}
inline int rd()
{
    char c;int f = 1;
    while(!isdigit(c = getchar()))if(c=='-')f = -1;
    int x = c-'0';
    while(isdigit(c = getchar()))x = x*10+(c^48);
    return x*f;
}
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    n = rd();m = rd();
    for(int i = 1;i <= n;i++)a[i] = rd();
    for(int i = 1;i <= n;i++)b[i] = rd();
    ll l = 0,r = 2e9;
    while(l < r)
    {
        ll mid = l+r>>1;
        if(check(mid) >= m)r = mid;
        else l = mid+1;
    }
    check(l);
    cout << l*m+sum << endl;
    return 0;
}

总结

在做决策单调性优化 DP 的题中,一般是先写出转移方程式,然后看转移式是哪种类型的,判断 w 符合哪些性质,然后再决定用哪种方法去做,如分治,枚举决策点范围,wqs 二分等等。总之,这类题还是要多练习才能熟练掌握。

posted @   max0810  阅读(52)  评论(2编辑  收藏  举报
相关博文:
阅读排行:
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】
点击右上角即可分享
微信分享提示