【DP】区间DP入门

在开始之前我要感谢y总,是他精彩的讲解才让我对区间DP有较深的认识。

简介

一般是线性结构上的对区间进行求解最值,计数的动态规划。大致思路是枚举断点,然后对断点两边求取最优解,然后进行合并从而得解。

原理

结合模板题(合并石子)讲述:https://www.acwing.com/problem/content/284/

因为题目具有合并相邻物品的性质,所以在合并的过程中,必然会在最后一步出现两个物品合二为一的情况,而这两个物品则是分别由左侧的物品、右侧的物品合并而来的。 因此,我们的思路是枚举最后一步合并两个物品时候的断点(记为 \(k\) ),为了方便起见,我们可以将断点放在某个物品上面。

结合样例具体来说:

k
1 3 5 2
  k
1 3 5 2
    k
1 3 5 2
    k
1 3 5 2

上面便是四个断点。


对于本题,我们记f[l][r]为合并 \([l,r]\) 的物品所能得到的最小贡献。
而断点将 \([l,r]\) 分为了 \([l,k],[k+1,r]\) ,这两个区间的贡献分别是 f[l][k],f[k+1][r] 而合并这两个区间的贡献则是 sum(l,r)) (其中sum(l,r) 表示 \([l,r]\) 的物品的权值和)

从而得到递推方程式: f[l][r] = min(f[l][r],f[l][k]+f[k+1][r]+sum(l,r))

可以看出,在枚举断点的过程中,我们已经覆盖了所有情况(根据断点所有可能位置分类),因此这样做能够保证得到答案。

至此,在思维上不会有太大困难。

下面讲一下怎么用递推的方法求解:

由本题的逻辑结构可知,我们要先处理出小区间的 \(f值\) 才能够保证大区间可以得到更新,所以我们第一重循环枚举的是区间的长度len,下面的部分则是枚举起点(即 l), 结合长度我们可以得到 r = l+len-1 ,进而我们得到了相应的区间 \([l,r]\) ,接下来枚举断点 \(k\) 即可。

结合代码理解:

#include<bits/stdc++.h>
using namespace std;

const int INF=0x3f3f3f3f;
const int N=305;

int f[N][N];
int w[N],s[N];
int n;
int main(){
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>w[i];
        s[i]=s[i-1]+w[i];
    }
    
    for(int len=1;len<=n;len++)
        for(int l=1;l+len-1<=n;l++){
            int r=l+len-1;
            if(len==1){
                f[l][r]=0;
            }else{
                f[l][r]=INF;
                for(int k=l;k<r;k++)
                    f[l][r]=min(f[l][r],f[l][k]+f[k+1][r]+s[r]-s[l-1]);
            }
        }
    cout<<f[1][n]<<endl;
    
    return 0;
}

当然,也可以采取记忆化搜索,这样不需要考虑太多。

例题

环形石子合并:https://www.acwing.com/activity/content/problem/content/1297/1/

分析

这题无非是将上题排成一列的物品放在了环上,因此我们可以采取断环成链的技巧:
显然,合并 \(n\) 个物品需要 \(n-1\) 步,因此,必然存在两个物品,它们并没有进行合并,那么它们之间便出现了“断边”,这样的“断边”并不会参与到合并的过程中,问题便由环转化为链的情况,所以我们只需枚举“断边”,然后进行求解即可。

有一个技巧:只需将原有的物品再按顺序“复制”一份,分别得到区间:

对于样例:

4 5 9 4

复制:

4 5 9 4 4 5 9 4

然后依次把区间(记为 \([s,t]\) )取出求解:

s     t
4 5 9 4 4 5 9 4
  s     t
4 5 9 4 4 5 9 4
    s     t
4 5 9 4 4 5 9 4
      s     t
4 5 9 4 4 5 9 4

(最后一个复制的元素是没用的,可以忽略)

这样分别求解四个子问题就行了。

代码:

#include<bits/stdc++.h>
using namespace std;

#define INF 0x3f3f3f3f

const int N = 410;

int f[N][N],g[N][N];
int s[N],w[N];
int n;

int main(){
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>w[i];
        w[i+n]=w[i];
    }
    
    memset(f,0x3f,sizeof f);
    memset(g,0xcf,sizeof g);
    
    for(int i=1;i<=2*n;i++) s[i]=s[i-1]+w[i];
    
    for(int len=1;len<=n;len++){
        for(int l=1;l+len-1<=n*2;l++){
            int r=l+len-1;
            
            if(len==1) f[l][r]=g[l][r]=0;
            else{
                for(int k=l;k<r;k++){
                    f[l][r]=min(f[l][r],f[l][k]+f[k+1][r]+s[r]-s[l-1]);
                    g[l][r]=max(g[l][r],g[l][k]+g[k+1][r]+s[r]-s[l-1]);
                }
            }
                
        }
    }
    
    int maxv=-INF,minv=INF;
    for(int i=1;i<=n;i++){
        maxv=max(maxv,g[i][i+n-1]);
        minv=min(minv,f[i][i+n-1]);
    }
    
    cout<<minv<<endl<<maxv<<endl;
    
    return 0;
}

记忆化搜索版本:(比较久之前写的emm)

#include<bits/stdc++.h>
using namespace std;
#define maxn 101
int n;
int a[maxn<<1];
int f_max[maxn][maxn];
int f_min[maxn][maxn];
int rec[maxn];
int s[maxn];

int sum(int l,int r){
    return s[r]-s[l-1];
}

int dfs_max(int l,int r){
    if(l==r) return f_max[l][r]=0;
    if(f_max[l][r]) return f_max[l][r];

    int res=0;
    for(int k=l;k+1<=r;k++){
        res=max(res,dfs_max(l,k)+dfs_max(k+1,r)+sum(l,r));
    }
    return f_max[l][r]=res;
}

int dfs_min(int l,int r){
    if(l==r) return f_min[l][r]=0;
    if(f_min[l][r]) return f_min[l][r];

    int res=INT_MAX;
    for(int k=l;k+1<=r;k++){
        res=min(res,dfs_min(l,k)+dfs_min(k+1,r)+sum(l,r));
    }
    return f_min[l][r]=res;
}

int main(){
    cin>>n;
    for(int i=1;i<=n-1;i++) cin>>a[i],a[i+n]=a[i];
    cin>>a[n];

    int rec_max=0;
    int rec_min=INT_MAX;

    for(int st=1;st<=n;st++){
        memset(rec,0,sizeof(rec));
        memset(s,0,sizeof(s));
        memset(f_max,0,sizeof(f_max));
        memset(f_min,0,sizeof(f_min));
        for(int i=st;i<=st+n-1;i++) rec[i-st+1]=a[i];

        s[1]=rec[1];
        for(int i=2;i<=n;i++) s[i]=s[i-1]+rec[i];

        rec_max=max(rec_max,dfs_max(1,n));
        rec_min=min(rec_min,dfs_min(1,n));
    }
    cout<<rec_min<<endl;
    cout<<rec_max<<endl;
    return 0;
}

能量项链:https://www.acwing.com/problem/content/322/

分析
和上面题目类似(事实上区间DP的题都差不多),要注意理解是如何合并珠子的。

代码:

#include<bits/stdc++.h>
using namespace std;

const int N=105;

int n;
int w[N<<1];
int f[N<<1][N<<1];

int main(){
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>w[i];
        w[n+i]=w[i];
    }
    
    for(int len=3;len<=n+1;len++)
        for(int l=1;l+len-1<=2*n;l++){
            int r=l+len-1;
            for(int k=l+1;k<=r-1;k++)
                f[l][r]=max(f[l][r],f[l][k]+f[k][r]+w[l]*w[k]*w[r]);
        }
        
    int res=0;
    for(int i=1;i<=n;i++) res=max(res,f[i][i+n]);
    
    cout<<res<<endl;
    
    return 0;
}

记忆化搜索版本:

#include<bits/stdc++.h>
using namespace std;

const int N=210;

int n;
int w[N];
int f[N][N];

int dp(int l,int r){
    if(f[l][r]>=0) return f[l][r];
    if(r==l || r==l+1) return f[l][r]=0;
    
    int &v=f[l][r];
    for(int k=l+1;k<=r-1;k++){
        v=max(v,dp(l,k)+dp(k,r)+w[l]*w[k]*w[r]);
    }
    return v;
}

int main(){
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>w[i];
        w[n+i]=w[i];
    }
    
    memset(f,-1,sizeof f);
    
    int res=0;
    for(int i=1;i<=n;i++) res=max(res,dp(i,i+n));
    
    cout<<res<<endl;
    
    return 0;
}

加分二叉树:https://www.acwing.com/problem/content/481/

分析

g[l][r] 表示 \([l,r]\) 的根节点。
将中序遍历的序列看作是区间求解,然后枚举根节点(将它作为断点),记录答案的过程中要注意当答案得到更新的时候才记录这个区间的根节点。

#include<bits/stdc++.h>
using namespace std;

const int N=35;

int f[N][N]; //dp
int g[N][N]; //path

int n;
int w[N];

void dfs(int l,int r){
    if(l>r) return;
    
    int root=g[l][r];
    cout<<root<<' ';
    dfs(l,root-1);
    dfs(root+1,r);
}
int main(){
    cin>>n;
    for(int i=1;i<=n;i++) cin>>w[i];
    
    for(int len=1;len<=n;len++)
        for(int l=1;l+len-1<=n;l++){
            int r=l+len-1;
            if(len==1){
                f[l][r]=w[l];
                g[l][r]=l;
            }
            else{
                for(int k=l;k<=r;k++){
                    int left= k==l?1:f[l][k-1];
                    int right= k==r?1:f[k+1][r];
                    int score=left*right + w[k];
                    if(score>f[l][r]){
                        f[l][r]=score;
                        g[l][r]=k;
                    }
                }
            }
        }
    
    cout<<f[1][n]<<endl;
    dfs(1,n);
    
    return 0;
}
posted @ 2021-02-15 10:25  HinanawiTenshi  阅读(429)  评论(0编辑  收藏  举报