区间 dp

模板区间 dp

  • 一个长 \(n(n \le 248)\) 的序列,选择数列中两个相邻且相等的元素,删去其中一个元素并使另一个元素的值 \(+1\),求数次操作后数列中的最大值
  • 将这看做是合并的过程
  • \(dp_{i, j}\) 表示区间 \([i, j]\) 和为一个答案的取值,这里的取值其实是唯一的,所以可以之间判断
  • 对于每个 \(dp_{i, j}\) 找到一个合法的 \(mid(i \le mid < r)\),使得 \(dp_{i, m} = dp_{m + 1, j}\),那么 \(dp_{i, j} = dp_{i, m} + 1\)
点击查看 AC 代码
#include <bits/stdc++.h>

using namespace std;
using LL = long long;
using PII = pair<int, int>;

const int MAXN = 300 + 3, MAXX = 502;

int n;
int a[MAXN];
int dp[MAXN][MAXN];

int main() {
  ios::sync_with_stdio(0), cin.tie(0);
  cin >> n;
  int ANS = 0;
  memset(dp, -1, sizeof(dp));
  for(int i = 1; i <= n; i++){
    cin >> a[i], ANS = max(ANS, a[i]), dp[i][i] = a[i];
  }
  for(int l = 2; l <= n; l++){
    for(int i = n - l + 1; i >= 1; i--){
      int j = i + l - 1;
      for(int m = i; m < j; m++){
        if(dp[i][m] != -1 && dp[m + 1][j] != -1 && dp[i][m] == dp[m + 1][j]){
          dp[i][j] = dp[i][m] + 1;
          ANS = max(ANS, dp[i][m] + 1);
        }
      }
    }
  }
  cout << ANS;
  return 0;
}

dp 的转移会有多余、重复的操作

例题 1

  • 对于这一题,\(dp_{i, j}\) 有两种转移
  • 一种可以直接 \(dp_{i, j} = dp_{i, m} + dp_{m+1, j}\)
  • 另一种可能会出现合并操作,那么你其实只需要判断 \(a_i\)\(a_j\) 是否有相同
  • 如果那个合并的中间,那么只有可能与 \(a_i\)\(a_j\) 有关,不然就一定存在更优的转移
点击查看 AC 代码
#include <bits/stdc++.h>

using namespace std;
using LL = long long;
using PII = pair<int, int>;

const int MAXN = 300 + 3;

int n, m;
int a[MAXN], b[MAXN];
int dp[MAXN][MAXN];

int main() {
  cin >> n, m = 0;
  for(int i = 0; i <= n; i++){
    for(int j = 0; j <= n; j++){
      for(int x = 0; x <= n; x++) dp[i][j] = 1e9;
    }
  }
  for(int i = 1; i <= n; i++){
    cin >> a[i];
  }
  for(int i = 1; i <= n; i++){
    if(i == 1 || a[i] != a[i - 1]){
      b[++m] = a[i];
    }
  }
  for(int i = 1; i <= m; i++){
    dp[i][i] = 1;
  }
  for(int l = 2; l <= m; l++){
    for(int i = 1; i <= m - l + 1; i++){
      int j = i + l - 1;
      for(int m = i; m < j; m++){
        dp[i][j] = min(dp[i][j], dp[i][m] + dp[m + 1][j] - (b[i] == b[j]));
      }
    }
  }
  cout << dp[1][m];
  return 0;
}

例题 2

  • 可以将回文串拆分一下,那么就是 \(dp_{i, j}\) 可以拆分为 \(s + dp_{x, y} + m + rs\)
  • 其中 \(m\) 是回文串,\(rs\)\(s\) 翻转过后得到的串
  • \(O(n^4)\) 暴力便出来了:预先 \(O(n^2)\) 找出所有回文串,再 \(O(n^4)\) 区间 dp(枚举状态 \(O(n^2)\),转移 \(O(n^2)\)
点击查看 $O(n^4)$ 代码
#include <bits/stdc++.h>

using namespace std;
using LL = long long;
using PII = pair<int, int>;

const int MAXN = 500 + 3;

int n;
int a[MAXN];
int f[MAXN][MAXN];
int dp[MAXN][MAXN];

int main() {
  ios::sync_with_stdio(0), cin.tie(0);
  //freopen("temp5___.out", "w", stdout);
  cin >> n;
  for(int i = 1; i <= n; i++){
    cin >> a[i];
  }
  for(int i = 0; i <= n; i++){
    for(int j = 0; j <= n; j++) dp[i][j] = 1e9;
  }
  for(int i = 1; i <= n; i++){
    for(int j = i; j <= n; j++){
      int _j = i - (j - i);
      if(_j < 1 || a[j] != a[_j]) break;
      f[_j][j] = 1, dp[_j][j] = 1;
    }
    for(int j = i + 1; j <= n; j++){
      int _j = i - (j - i - 1);
      if(_j < 1 || a[j] != a[_j]) break;
      f[_j][j] = 1, dp[_j][j] = 1;
    }
  }
  for(int i = 1; i <= n; i++) f[i][i] = 1, dp[i][i] = 1;
  for(int l = 2; l <= n; l++){
    for(int i = 1; i <= n - l + 1; i++){
      int j = i + l - 1;
      for(int m = i; m <= j; m++){
        int _m = j - (m - i);
        if(m > _m) break;
        for(int x = m; x <= _m; x++){
          if(f[x + 1][_m]) dp[i][j] = min(dp[i][j], dp[m][x] + 1);
        }
        for(int x = m; x <= _m; x++){
          if(f[m][x - 1]) dp[i][j] = min(dp[i][j], dp[x][_m] + 1);
        }
        if(a[m] != a[_m]) break;
      }
      for(int m = i; m < j; m++){
        dp[i][j] = min(dp[i][j], dp[i][m] + dp[m + 1][j]);
      }
    }
  }
  cout << dp[1][n];
  return 0;
}
  • 现在的时间复杂度瓶颈就是枚举字符串 \(m\) 的时间
  • 发现枚举字符串 \(m\) 其实是多余的!!!
  • 如果 \(m\) 是回文的,那么必定有 \(dp_{x, z} = dp_{x, y} + m\)
  • \(dp_{i, j}\) 可以直接转移到 \(dp_{x, z}\)
  • 暴力转移就可以了。
点击查看 AC 代码
#include <algorithm>
#include <iostream>

using namespace std;
using LL = long long;
using PII = pair<int, int>;

const int MAXN = 500 + 3;

int n;
int a[MAXN];
int f[MAXN][MAXN];
int dp[MAXN][MAXN];

int main() {
  ios::sync_with_stdio(0), cin.tie(0);
  cin >> n;
  for(int i = 1; i <= n; i++){
    cin >> a[i];
  }
  for(int i = 0; i <= n; i++){
    for(int j = 0; j <= n; j++) dp[i][j] = 3e8;
  }
  for(int i = 1; i <= n; i++){
    for(int j = i; j <= n; j++){
      int _j = i - (j - i);
      if(_j < 1 || a[j] != a[_j]) break;
      f[_j][j] = 1, dp[_j][j] = 1;
    }
    for(int j = i + 1; j <= n; j++){
      int _j = i - (j - i - 1);
      if(_j < 1 || a[j] != a[_j]) break;
      f[_j][j] = 1, dp[_j][j] = 1;
    }
  }
  for(int i = 1; i <= n; i++) f[i][i] = 1, dp[i][i] = 1;
  for(int l = 1; l <= n; l++){
    for(int i = 1; i <= n - l + 1; i++){
      int j = i + l - 1;
      for(int m = i; m <= j; m++){
        int _m = j - (m - i);
        if(m > _m) break;
        if(m > i) dp[i][j] = min(dp[i][j], dp[m][_m]);
        if(a[m] != a[_m]) break;
      }
      for(int m = i; m < j; m++){
        dp[i][j] = min(dp[i][j], dp[i][m] + dp[m + 1][j]);
      }
    }
  }
  cout << dp[1][n];
  return 0;
}

树上区间 dp

  • 树上的每一颗子树就是一段区间
  • 这一题是求总代价最小值
  • 可以将总贡献分散开,求对于每一条边对答案的总贡献
  • 树上,一条边就可以将一棵树分为两个连通块
  • 所以一条边的贡献,就是 一个连通块 与 另一个连通块 两两配对求和
  • 题目要求构造一颗二叉树,可以暴力计算,然后前缀和优化
点击查看 $O(n^4)$ 暴力代码(这题里好像可以过...)
#include <algorithm>
#include <iostream>

using namespace std;
using LL = long long;
using PII = pair<int, int>;

const int MAXN = 200 + 3;
const LL Inf = 2e17;

struct DP{
  LL w;
  int root;
  int fa[MAXN];
};

int n;
LL a[MAXN][MAXN], sum[MAXN][MAXN];
DP dp[MAXN][MAXN];

int main() {
  cin >> n;
  for(int i = 1; i <= n; i++){
    for(int j = 1; j <= n; j++){
      cin >> a[i][j];
      sum[i][j] = sum[i][j - 1] + a[i][j];
    }
  }
  for(int i = 1; i <= n; i++){
    for(int j = i; j <= n; j++){
      dp[i][j].w = Inf, dp[i][j].root = -1;
      for(int x = 1; x <= n; x++) dp[i][j].fa[x] = -1;
    }
  }
  for(int i = 1; i <= n; i++){
    dp[i][i].w = 0, dp[i][i].root = i, dp[i][i].fa[i] = 0;
  }
  for(int l = 2; l <= n; l++){
    for(int i = 1; i <= n - l + 1; i++){
      int j = i + l - 1;
      for(int m = i; m <= j; m++){
        LL _w = (i < m ? dp[i][m - 1].w : 0) + (m < j ? dp[m + 1][j].w : 0);
        for(int x = i; x < m; x++){
           _w += (sum[x][i - 1] + sum[x][n] - sum[x][m - 1]);
        }
        for(int x = m + 1; x <= j; x++){
          _w += (sum[x][m] + sum[x][n] - sum[x][j]);
        }
        if(_w < dp[i][j].w){
          dp[i][j].w = _w;
          dp[i][j].root = m;
          for(int x = i; x < m; x++){
            dp[i][j].fa[x] = dp[i][m - 1].fa[x];
          }
          for(int x = m + 1; x <= j; x++){
            dp[i][j].fa[x] = dp[m + 1][j].fa[x];
          }
          dp[i][j].fa[m] = 0;
          if(i < m) dp[i][j].fa[dp[i][m - 1].root] = m;
          if(m < j) dp[i][j].fa[dp[m + 1][j].root] = m;
        }
      }
    }
  }
  for(int i = 1; i <= n; i++){
    cout << dp[1][n].fa[i] << " ";
  }
  return 0;
}
点击查看前缀和优化后的 $O(n^3)$ 代码
#include <algorithm>
#include <iostream>

#pragma GCC optimize(fast)

using namespace std;
using LL = long long;
using PII = pair<int, int>;

const int MAXN = 200 + 3;
const LL Inf = 2e17;

struct DP{
  LL w;
  int root;
}dp[MAXN][MAXN];

int n, ans[MAXN];
LL a[MAXN][MAXN], sum[MAXN][MAXN];
LL ssum[MAXN][MAXN];

void dfs(int l, int r, int dad){
  if(l > r) return;
  ans[dp[l][r].root] = dad;
  dfs(l, dp[l][r].root - 1, dp[l][r].root);
  dfs(dp[l][r].root + 1, r, dp[l][r].root);
}

int main(){
  ios::sync_with_stdio(0), cin.tie(0);
  cin >> n;
  for(int i = 1; i <= n; i++){
    for(int j = 1; j <= n; j++){
      cin >> a[i][j];
      sum[i][j] = sum[i][j - 1] + a[i][j];
      ssum[i][j] = ssum[i - 1][j] + sum[i][j];
      dp[i][j].w = Inf, dp[i][j].root = -1;
    }
  }
  for(int i = 1; i <= n; i++) dp[i][i].w = 0, dp[i][i].root = i;
  for(int l = 2; l <= n; l++){
    for(int i = 1; i <= n - l + 1; i++){
      int j = i + l - 1;
      for(int m = i; m <= j; m++){
        LL _w = (i < m ? dp[i][m - 1].w : 0) + (m < j ? dp[m + 1][j].w : 0);
        _w += (ssum[m-1][n] - ssum[i-1][n]) + (ssum[m-1][i-1] - ssum[i-1][i-1]) - (ssum[m-1][m-1] - ssum[i-1][m-1]);
        _w += (ssum[j][n] - ssum[m][n]) + (ssum[j][m] - ssum[m][m]) - (ssum[j][j] - ssum[m][j]);
        if(_w < dp[i][j].w) dp[i][j].w = _w, dp[i][j].root = m;
      }
    }
  }
  dfs(1, n, 0);
  for(int i = 1; i <= n; i++){
    cout << ans[i] << " ";
  }
  return 0;
}
posted @ 2023-08-10 20:46  hhhqx  阅读(9)  评论(0编辑  收藏  举报