区间 DP、环形 DP

区间 DP

区间 DP 是可以由小区间的结果往两边扩展一位得到大区间的结果,或者由两个小区间的结果可以拼出大区间的结果的一类 DP 问题

往往设 \(dp[i][j]\) 表示处理完 \([i,j]\) 区间得到的答案,按长度从小到大转移

因此一般是先写一层循环从小到大枚举长度 \(len\),再写一层循环枚举左端点 \(i\),算出右端点 \(j\),然后写 \(dp[i][j]\) 的状态转移方程

区间 DP 也可以由大区间推到小区间,此时可以从大到小枚举 \(len\)

区间 DP 的提示信息:

  1. 从两端取出或在两端插入,这就是大区间变到小区间或者小区间变到大区间
  2. 合并相邻的,这样的一步相当于把已经处理好的两个小区间得到的结果合并为当前大区间的结果
  3. 消去连续一段使两边接起来,可以枚举最后一次消哪个区间,这样就可以把大区间拆成小区间
  4. 两个东西可以配对消掉,这时往往可以按左端点和哪个东西配对,把当前区间拆成两个子区间的问题
  5. 时间复杂度通常为 \(O(n^2)\)\(O(n^3)\)

例:P2858 [USACO06FEB] Treats for the Cows G/S

解题思路

考虑操作过程,以第一步为例,你会把 \([1,n]\) 通过拿走最左边一个或最右边一个变为 \([2,n]\)\([1,n-1]\),这就是区间的变化

我们可以考虑最后一次拿零食,此时一定是只剩一件零食了,这就是长度为 \(1\) 的区间,由于它一定是最后一天出售,此时它的售价为 \(n*v[i]\)

由此我们设计 \(dp[i][j]\) 表示卖光 \([i,j]\) 区间内的零食的最大售价

那么状态转移方程就是 \(dp[i][j] = \max (dp[i+1][j] + v[i] * (n-(j-i)), dp[i][j-1]+v[j]*(n-(j-i)))\)

初始化 \(dp[i][i]=v[i]*n\),从小区间往大区间推,最后答案为 \(dp[1][n]\)

时间复杂度 \(O(n^2)\)

参考代码
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 2005;
int v[N], dp[N][N];
int main()
{
    int n;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &v[i]);
        dp[i][i] = v[i] * n;
    }
    for (int len = 2; len <= n; len++) {
        for (int i = 1; i <= n - len + 1; i++) {
            int j = i + len - 1, a = n - len + 1;
            dp[i][j] = max(dp[i + 1][j] + v[i] * a, dp[i][j - 1] + v[j] * a);
        }
    }
    printf("%d\n", dp[1][n]);
    return 0;
}

例:P3205 [HNOI2010] 合唱队

解题思路

每次插入到队伍最左边或最右边,也就是说如果 \([i,j]\) 排好了,接下来一个人插到左边就排好了 \([i-1,j]\) 区间,如果插到右边就排好了 \([i,j+1]\) 区间,这就是小区间推到大区间

但是能不能插进来还要看这次加入的数和上一次插入的数是否符合对应的大小关系,因此我们还需要知道插入的最后一个数是最左边的还是最右边的

可以设 \(dp_{i,j,0/1}\) 表示把 \([i,j]\) 区间排好且最后一个人是在左边/右边的方案数,初始化 \(dp_{i,i,0}=1\),即只有一个人的时候强制认为它是插入在左边

考虑转移,对于 \(dp_{i,j,0}\),此时就是看 \(h_i\) 插进来的时候能否符合题目中的条件,此时需要知道 \([i+1,j]\) 中最后插进来的是哪个,如果是 \(h_{i+1}\),并且 \(h_i \lt h_{i+1}\),那么 \(dp_{i,j,0}\) 加上 \(dp_{i+1,j,0}\),如果是 \(h_j\),并且 \(h_i \lt h_j\),那么 \(dp_{i,j,0}\) 加上 \(dp_{i+1,j,1}\)

对于 \(dp_{i,j,1}\),此时就是看 \(h_j\) 插进来的时候能否符合题目中的条件,此时需要知道 \([i,j-1]\) 中最后插进来的是哪个,如果是 \(h_i\),并且 \(h_j \gt h_i\),那么 \(dp_{i,j,1}\) 加上 \(dp_{i,j-1,0}\),如果是 \(h_{j-1}\),并且 \(h_j \gt h_{j-1}\),那么 \(dp_{i,j,1}\) 加上 \(dp_{i,j-1,1}\)

时间复杂度 \(O(n^2)\)

参考代码
#include <cstdio>
const int N = 1005;
const int MOD = 19650827;
int dp[N][N][2], h[N];
int main()
{
	int n;
	scanf("%d", &n);
	for (int i = 1; i <= n; i++) {
		scanf("%d", &h[i]);
		dp[i][i][0] = 1;
	}
	for (int len = 2; len <= n; len++) {
		for (int i = 1; i <= n - len + 1; i++) {
			int j = i + len - 1;
			// [i,j] from [i+1,j] [i,j-1]
			if (h[i] < h[i+1]) dp[i][j][0] = (dp[i][j][0] + dp[i+1][j][0]) % MOD;
			if (h[i] < h[j]) dp[i][j][0] = (dp[i][j][0] + dp[i+1][j][1]) % MOD;
			if (h[j] > h[i]) dp[i][j][1] = (dp[i][j][1] + dp[i][j-1][0]) % MOD;
			if (h[j] > h[j-1]) dp[i][j][1] = (dp[i][j][1] + dp[i][j-1][1]) % MOD;
		}
	}
	printf("%d\n", (dp[1][n][0] + dp[1][n][1]) % MOD);
	return 0;
}

例:P3146 [USACO16OPEN] 248 G

这个问题和之前扩展一位的问题略有不同,这是由两个区间的结果合并推到更大区间的结果

解题思路

可以设 \(dp_{i,j}\) 表示把 \([i,j]\) 合并得到的最大数,如果这一段无法合并成一个数,则 \(dp\) 值为 \(0\)

初始化 \(dp_{i,i}=a_i\)

考虑转移,对于区间 \([i,j]\),我们需要枚举分界点 \(k\),将 \([i,j]\) 拆成 \([i,k]\)\([k+1,j]\) 这两部分,先让 \([i,k]\) 合成一个数,\([k+1,j]\) 合成一个数,再让这两个数合并

这样就可以写出状态转移方程:如果 \(dp_{i,k}\)\(dp_{k+1,j}\) 相等且非 \(0\),则 \(dp_{i,j}= \max (dp_{i,j}, dp_{i,k}+1)\),最后答案为 \(\max \{dp_{i,j}\}\)

时间复杂度 \(O(n^3)\)

参考代码
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 250;
int dp[N][N];
int main()
{
    int n, ans = 0;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &dp[i][i]);
        ans = max(ans, dp[i][i]);
    }
    for (int len = 2; len <= n; len++) {
        for (int i = 1; i <= n - len + 1; i++) {
            int j = i + len - 1;
            for (int k = i; k < j; k++) 
                if (dp[i][k] == dp[k + 1][j] && dp[i][k]) 
                    dp[i][j] = max(dp[i][j], dp[i][k] + 1);
            ans = max(ans, dp[i][j]);
        }
    }
    printf("%d\n", ans);
    return 0;
}

例:P4170 [CQOI2007] 涂色

解题思路

考虑染色过程,因为是一段一段染的,长段可以看成是两个短段拼起来,并且如果某一次染了一段之后,可以在这段内部继续染色,这都提示我们可以考虑区间 DP

\(dp_{i,j}\) 表示染完 \(i\)\(j\) 的最少次数,初始化 \(dp_{i,i}=1\),一段拆成两段染,则有 \(dp_{i,j} = \min \{dp_{i,k}+dp_{k+1,j}\}\)

特殊情况:如果 \(s_i = s_j\),可以在染完 \([i+1,j]\)\([i,j-1]\) 的时候顺带把 \(i\)\(j\) 染了,这样的结果一定优于拆两段,此时 \(dp_{i,j}= \min (dp_{i,j-1}, dp_{i+1,j})\)

分析:拆段意味着存在两步分别染 \([i,x_1]\)\([x_2,j]\),而不拆段则可以直接改成在第一步染一次 \([i,j]\),染这一次也不会干扰到后续染色过程,因为后续染色是直接覆盖中间的某段区域,所以之前被这次 \([i,j]\) 染过也没有关系

对于其他题目,有可能出现在可以不拆段时依然是拆段取到最优解的情况,注意分析,如果想简化分析过程可以统一枚举拆段转移最优解的过程,因为在可能需要拆段的题目中这样做不会影响时间复杂度

时间复杂度 \(O(n^3)\)

参考代码
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 55;
char s[N];
int dp[N][N];
int main()
{
    scanf("%s", s + 1);
    int n = strlen(s + 1);
    for (int i = 1; i <= n; i++) dp[i][i] = 1;
    for (int len = 2; len <= n; len++) {
        for (int i = 1; i <= n - len + 1; i++) {
            int j = i + len - 1;
            dp[i][j] = min(dp[i + 1][j], dp[i][j - 1]) + (s[i] != s[j]);
            for (int k = i; k < j; k++) dp[i][j] = min(dp[i][j], dp[i][k] + dp[k + 1][j]);
        }
    }
    printf("%d\n", dp[1][n]);
    return 0;
}

例:CF607B Zuma

解题思路

\(dp_{i,j}\) 为移除 \([i,j]\) 的最短时间,则有初始化

\(\begin{cases} dp_{i,i}=1 & \\ dp_{i,i+1}=1 & c_i=c_{i+1} \\ dp_{i,i+1}=2 & c_i \ne c_{i+1} \end{cases}\)

状态转移方程

\(\begin{cases} dp_{i,j}=dp_{i+1,j-1} & c_i=c_j \\ dp_{i,j}=\min \{ dp_{i,k}+dp_{k+1,j} \} \end{cases}\)

注意:就算是 \(c_i=c_j\),也有可能是拆段更优,比如 \([1, 2, 1, 1, 3, 1]\)

因此无论 \(c_i\) 是否等于 \(c_j\),都必须做拆段的这种转移,时间复杂度 \(O(n^3)\)

参考代码
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 505;
int c[N], dp[N][N];
int main()
{
    int n;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &c[i]);
        dp[i][i] = 1; 
    }
    for (int len = 2; len <= n; len++) {
        for (int i = 1; i <= n - len + 1; i++) {
            int j = i + len - 1;
            dp[i][j] = min(dp[i + 1][j], dp[i][j - 1]) + 1;
            if (c[i] == c[j]) dp[i][j] = min(dp[i][j], len == 2 ? 1 : dp[i + 1][j - 1]);
            for (int k = i; k < j; k++) dp[i][j] = min(dp[i][j], dp[i][k] + dp[k + 1][j]);
        }
    }
    printf("%d\n", dp[1][n]);
    return 0;
}

例:CF149D Coloring Brackets

解题思路

\(dp[l][r][cl][cr]\) 代表对 \([l,r]\) 区间染色且端点 \(l\) 处颜色为 \(cl\),端点 \(r\) 处颜色为 \(cr\) 情况下的染色方案数,这里颜色的取值可以设在 \(0 \sim 2\) 之间,如 \(0\) 代表不染色,\(1\) 代表染红色,\(2\) 代表染蓝色

\(l\) 位置的括号和 \(r\) 位置的括号形成匹配关系时,此时可以将问题转化为先去计算 \(dp[l+1][r-1]\) 的值,再处理匹配括号、相邻扩号的约束关系

\(l\) 位置的括号和 \(r\) 位置的括号不构成匹配关系时,此时应当将问题转化为两段独立形成括号匹配的串的染色方案合并起来的结果,为了方便拆分,可以预处理出原串中每个括号的对应匹配关系

以上两种转移方式要想组合到一起不容易用循环来实现,实际上这里更容易写的方式是利用递归回溯的过程完成计算

参考代码
#include <cstdio>
#include <cstring>
#include <stack>
using namespace std;
const int N = 705;
const int MOD = 1000000007;
char s[N];
// 0: none, 1: red, 2: blue
int dp[N][N][3][3], match[N];
void dfs(int l, int r) {
    if (l >= r) return;
    if (l + 1 == r) {
        // 会执行到这个位置只能是 ()
        dp[l][r][0][1] = dp[l][r][0][2] = 1;
        dp[l][r][1][0] = dp[l][r][2][0] = 1;
        return;
    }
    if (match[l] != r) {
        dfs(l, match[l]);
        dfs(match[l] + 1, r);
        for (int c1 = 0; c1 < 3; c1++) {
            for (int c2 = 0; c2 < 3; c2++) {
                for (int c3 = 0; c3 < 3; c3++) {
                    if (c2 == c3 && c2 != 0) continue;
                    for (int c4 = 0; c4 < 3; c4++) {
                        int left = dp[l][match[l]][c1][c2];
                        int right = dp[match[l] + 1][r][c3][c4];
                        dp[l][r][c1][c4] += 1ll * left * right % MOD;
                        dp[l][r][c1][c4] %= MOD;
                    }
                }  
            }               
        }          
    } else {
        dfs(l + 1, r - 1);
        for (int c1 = 0; c1 < 3; c1++) {
            for (int c2 = 0; c2 < 3; c2++) {
                if (c1 == c2 && c1 != 0) continue;
                for (int c3 = 0; c3 < 3; c3++) {
                    for (int c4 = 0; c4 < 3; c4++) {
                        if (c3 == c4 && c3 != 0) continue;
                        if (c1 == 0 && c4 == 0) continue;
                        if (c1 != 0 && c4 != 0) continue;
                        dp[l][r][c1][c4] += dp[l + 1][r - 1][c2][c3];
                        dp[l][r][c1][c4] %= MOD;
                    }
                }
            }
        }
    }
}
int main()
{
    scanf("%s", s + 1);
    int n = strlen(s + 1);
    stack<int> stk;
    for (int i = 1; i <= n; i++) {
        dp[i][i][0][0] = dp[i][i][1][1] = dp[i][i][2][2] = 1;
        if (s[i] == '(') {
            stk.push(i);
        } else if (s[i] == ')') {
            int t = stk.top(); stk.pop();
            match[i] = t; match[t] = i;
        }
    }
    dfs(1, n);
    int ans = 0;
    for (int i = 0; i < 3; i++)
        for (int j = 0; j < 3; j++)
            ans = (ans + dp[1][n][i][j]) % MOD;
    printf("%d\n", ans);
    return 0;
}

例:P7914 [CSP-S 2021] 括号序列

解题思路

我们设符合要求的序列成为 \(A\) 型序列,连续不超过 \(k\)\(*\) 的序列为 \(S\) 型序列

根据题意,符合要求的括号序列可以分成 \((),(A),(S),(AS),(SA),AA,ASA\)

对于 \((),(A),(S)\),区间 \([i,j]\) 的结果可以由区间 \([i+1,j-1]\) 的结果转移而来

对于 \((AS),(SA),AA,ASA\),需要枚举断点,把区间 \([i,j]\) 拆成两个子区间,对于 \(ASA\),可以拆成 \(A\)\(SA\) 两部分,因此同时需要维护 \(SA\) 型序列的方案数

然而,对于 \(AA\) 型序列,如 \(()()()\)\([1,2]\) 符合 \(A\) 型序列,\([3,6]\) 符合 \(A\) 型序列,而 \([1,4]\) 符合 \(A\) 型序列,\([5,6]\) 符合 \(A\) 型序列,枚举断点时产生了重复计算;同理对于 \(ASA\) 型序列,如 \(()*()*()\),也会产生重复计算

解决方法:计算 \(AA\) 型序列方案数时,将其拆分为 \(A'A\),其中 \(A\) 型序列指的是所有符合要求的括号序列,而 \(A'\) 型序列指的是最外层为一对匹配括号条件下的 \(A\) 型序列,即这种序列不是由 \(AA\)\(ASA\) 拼接得到的合法括号序列,这样枚举断点时不会导致重复计算;同理 \(ASA\) 型序列的计算可以拆成 \(A'\) 型序列的计算结果和 \(SA\) 型序列的计算结果的组合

参考代码
#include <cstdio>
const int N = 505;
const int MOD = 1000000007;
char s[N];
int dp_a[N][N], dp_s[N][N], dp_ba[N][N], dp_sa[N][N];
bool check(int idx, char ch) {
    return s[idx] == '?' || s[idx] == ch;
}
int main()
{
    int n, k;
    scanf("%d%d%s", &n, &k, s + 1);
    for (int i = 1; i <= n; i++)
        if (check(i, '*')) dp_s[i][i] = 1;
    if (k >= 2) {
        for (int i = 1; i < n; i++)
            if (check(i, '*') && check(i + 1, '*')) dp_s[i][i + 1] = 1;
        for (int len = 3; len <= k; len++) {
            for (int i = 1; i <= n - len + 1; i++) {
                int j = i + len - 1;
                if (check(i, '*') && check(j, '*')) 
                    dp_s[i][j] = dp_s[i + 1][j - 1];
            }
        }
    }
    for (int i = 1; i < n; i++) {
        if (check(i, '(') && check(i + 1, ')')) 
            dp_ba[i][i + 1] = dp_a[i][i + 1] = 1;
    }
    for (int len = 3; len <= n; len++) {
        for (int i = 1; i <= n - len + 1; i++) {
            int j = i + len - 1; 
            if (check(i, '(') && check(j, ')')) {
                // (A)
                dp_ba[i][j] += dp_a[i + 1][j - 1];
                dp_ba[i][j] %= MOD;
                // (S)
                dp_ba[i][j] += dp_s[i + 1][j - 1];
                dp_ba[i][j] %= MOD;
                // (AS) (SA)
                for (int k = i + 1; k < j - 1; k++) {
                    dp_ba[i][j] += 1ll * dp_a[i + 1][k] * dp_s[k + 1][j - 1] % MOD;
                    dp_ba[i][j] %= MOD;
                    dp_ba[i][j] += 1ll * dp_s[i + 1][k] * dp_a[k + 1][j - 1] % MOD;
                    dp_ba[i][j] %= MOD;
                }
                dp_a[i][j] = dp_ba[i][j];
            }   
            // AA ASA
            for (int k = i; k < j; k++) {
                dp_a[i][j] += 1ll * dp_ba[i][k] * dp_a[k + 1][j] % MOD;
                dp_a[i][j] %= MOD;
                dp_a[i][j] += 1ll * dp_ba[i][k] * dp_sa[k + 1][j] % MOD;
                dp_a[i][j] %= MOD;
                dp_sa[i][j] += 1ll * dp_s[i][k] * dp_a[k + 1][j] % MOD;
                dp_sa[i][j] %= MOD;
            }
        }
    }
    printf("%d\n", dp_a[1][n]);
    return 0;
}

环形 DP

有时我们会面临输入是环形数组的情况,即认为 \(a[n]\)\(a[1]\) 是相邻的,此时应该如何处理?

  • 如果是线性 DP,比如选数问题,可以对 \(a[n]\) 是否选进行分类,假设 \(a[n]\) 不选,把 \(dp[1][\dots]\) 初始化好,最后推到 \(dp[n][\dots]\) 时,只留下 \(a[n]\) 不选的情况计入答案;再假设 \(a[n]\) 选,把 \(dp[1][\dots]\) 初始化好,最后推到 \(dp[n][\dots]\) 时,只留下 \(a[n]\) 选的情况计入答案

  • 如果是区间 DP,一种常见的方法是破环成链,将数组复制一倍接在原数组之后,然后对这个长度为 \(2n\) 的数组进行长度不超过 \(n\) 的区间 DP

    \(dp[i][j]\)\(i \le n\)\(j>n\) 的情况就是把 \(a[n]\)\(a[1]\) 也看成了相邻的,比如 \(dp[2][n+1]\),就代表原数组 \(a[2],\dots,a[n],a[1]\) 这样一个环上的结果

例:P1880 [NOI1995] 石子合并

解题思路

破环成链后,对产生的长度为 \(2n\) 的数组中区间长度 \(\le n\) 的所有区间进行 DP

\(dpmax_{i,j}\) 表示将 \([i,j]\) 区间合并的最大得分,则有 \(dpmax_{i,j} = \max \{ dpmax_{i,k} + dpmax_{k+1,j} + (a_i + \dots + a_j) \}\)

\(dpmin_{i,j}\) 表示将 \([i,j]\) 区间合并的最小得分,则有 \(dpmin_{i,j} = \min \{ dpmin_{i,k} + dpmin_{k+1,j} + (a_i + \dots + a_j) \}\)

其中 \(a_i + \dots + a_j\) 可以利用前缀和预处理 \(O(1)\) 得到,总时间复杂度 \(O(n^3)\),最后答案为 \(dpmax_{i,i+n-1}\) 中的最大值和 \(dpmin_{i,i+n-1}\) 中的最小值

参考代码
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 205;
const int INF = 1e9;
int a[N], sum[N];
int dp1[N][N];
int dp2[N][N];
int main()
{
    int n;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
        sum[i] = sum[i - 1] + a[i];
        a[i + n] = a[i];
    }
    for (int i = n + 1; i <= 2 * n; i++) sum[i] = sum[i - 1] + a[i];
    for (int i = 1; i < 2 * n; i++)
        for (int j = 1; j < 2 * n; j++) {
            dp1[i][j] = INF;
            dp2[i][j] = 0;
        }
    for (int i = 1; i < 2 * n; i++) dp1[i][i] = 0;
    for (int len = 2; len <= n; len++) {
        for (int i = 1; i <= 2 * n - len; i++) {
            int j = i + len - 1;
            for (int k = i; k < j; k++) {
				int total = sum[j] - sum[i - 1];
                dp1[i][j] = min(dp1[i][j], dp1[i][k] + dp1[k + 1][j] + total);
                dp2[i][j] = max(dp2[i][j], dp2[i][k] + dp2[k + 1][j] + total);
            }
        }
    }
    int ans1 = INF, ans2 = 0;
    for (int i = 1; i <= n; i++) {
        ans1 = min(ans1, dp1[i][i + n - 1]);
        ans2 = max(ans2, dp2[i][i + n - 1]);
    }
    printf("%d\n%d\n", ans1, ans2);
    return 0;
}

例:P1063 [NOIP2006 提高组] 能量项链

解题思路

与石子合并基本相同,先破环成链,然后设 \(dp_{i,j}\) 表示把 \([i,j]\) 的能量珠合成一个时释放的最大总能量

则有 \(dp_{i,j} = \max \{ dp_{i,k} + dp_{k+1,j} + a_i * a_{k+1} * a_{j+1} \}\)

需要注意要乘的三个数是哪三个,尤其是最后一个数,因为需要 \(a_{j+1}\),所以可以在刚开始复制 \(a\) 数组时多复制一位,让 \(a_{2*n+1}\) 也等于 \(a_1\)

最后答案为 \(\max \{ dp_{i,i+n-1} \}\),时间复杂度 \(O(n^3)\)

参考代码
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 205;
int a[N], dp[N][N];
int main()
{
    int n;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
        a[i + n] = a[i];
    }
    a[2 * n + 1] = a[1];
    for (int len = 2; len <= n; len++) {
        for (int i = 1; i <= 2 * n - len; i++) {
            int j = i + len - 1;
            for (int k = i; k < j; k++) {
				int energy = a[i] * a[k + 1] * a[j + 1];
                dp[i][j] = max(dp[i][j], dp[i][k] + dp[k + 1][j] + energy);
            }
        }
    }
    int ans = 0;
    for (int i = 1; i <= n; i++) ans = max(ans, dp[i][i + n - 1]);
    printf("%d\n", ans);
    return 0;
}

例:P9119 [春季测试 2023] 圣诞树

前 6 个测试点(前 30 分),直接枚举全排列即可
前 12 个测试点(前 60 分),哈密顿路问题(后续在状态压缩 DP 中会讲)
特殊性质 B(额外 10 分),答案为 \(1, 2, 3, \dots, n\)

解题思路

首先要分析出一个重要结论

考虑将该凸多边形按最高点分为两边,最优路径一定不会出现交叉的情况,例如针对题图中的 \(3,4,6,7\) 这四个点,\(6 \rightarrow 3 \rightarrow 4 \rightarrow 7\) 这样的连线方式必然不如 \(6 \rightarrow 4 \rightarrow 3 \rightarrow 7\)

也不会出现 \(4 \rightarrow 2 \rightarrow 3\) 这样的路径,它显然不如 \(4 \rightarrow 3 \rightarrow 2\),且那样一来再从 \(3\) 出发连向另一边的点会导致出现交叉的情况

因此,最优路径一定会是先在一边从头按顺序走几步,再去另一边从头按顺序走几步,再回初始那一边从没到过的点开始再按着顺序走几步,再去另一边走几步,这样不断交替(注意第 \(n\) 个点可以接着走 \(1, 2, \dots\) 这些点,所以要考虑环形数组)

可以将输入的数据复制一份,破环成链以后最优路径走过的点一定是包含起点的一段长度为 \(n\) 的连续区间 \([l,r]\),且最后一定停在 \(l\) 或者 \(r\)

可以设 \(dp[i][j][0/1]\) 表示走完了 \([i,j]\) 区间,最后停在 \(i/j\) 时候的最优解,枚举最后一步是从 \(i+1\)\(i\),还是从 \(j\)\(i\),还是从 \(i\)\(j\),还是从 \(j-1\)\(j\) 完成转移,时间复杂度 \(O(n^2)\)

这道题最后不是要输出最优解的那个值,而是输出路径,因此我们需要在 DP 过程中记录方案

常用方法是开一个和 \(dp\) 一样大的 \(from\) 数组,记录转移点,比如这题 \(dp[i][j][0]\) 可以从 \(dp[i+1][j][0]\)\(dp[i+1][j][1]\) 转移过来,\(dp[i][j][1]\) 可以从 \(dp[i][j-1][0]\)\(dp[i][j-1][1]\) 转移过来,我们只需要在 \(from[i][j][0/1]\) 中存好 \(0/1\) 就知道它具体选的是哪种方案了

最后先扫一遍所有长度为 \(n\) 的区间,找到最优解,设为 \(dp[l][r][f]\),如果 \(f\)\(0\),代表当前点是 \(l\),如果 \(from[l][r][f]\) 也是 \(0\),就说明是从 \(dp[l+1][r][0]\) 且上一个点是 \(l+1\) 转移过来的,其它几种情况可以类似地判断,不断往前找,就能倒着把路径记录下来,再倒序输出即可

参考代码
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
const int N = 2005;
const double INF = 1e12;
double x[N], y[N], dp[N][N][2];
// from 记录从上一个状态的“左端点”还是“右端点”转移过来
int from[N][N][2], ans[N];
double distance(int i, int j) {
    double dx = x[i] - x[j], dy = y[i] - y[j];
    return sqrt(dx * dx + dy * dy);
}
int main()
{
    int n, k = 1;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%lf%lf", &x[i], &y[i]);
        x[i + n] = x[i]; y[i + n] = y[i];
        if (y[i] > y[k]) k = i;
    }
    for (int i = 1; i <= 2 * n; i++)
        for (int j = 1; j <= 2 * n; j++)
            dp[i][j][0] = dp[i][j][1] = INF;
    dp[k][k][0] = dp[k][k][1] = 0; 
    dp[k + n][k + n][0] = dp[k + n][k + n][1] = 0; 
    for (int len = 2; len <= n; len++) {
        for (int i = 1; i <= 2 * n - len + 1; i++) {
            int j = i + len - 1;
            // dp[i][j][0]: to i
            double tmp = dp[i + 1][j][0] + distance(i, i + 1); // i+1 -> i
            if (tmp < dp[i][j][0]) {
                dp[i][j][0] = tmp; from[i][j][0] = 0;
            }
            tmp = dp[i + 1][j][1] + distance(i, j); // j -> i
            if (tmp < dp[i][j][0]) {
                dp[i][j][0] = tmp; from[i][j][0] = 1;
            }
            // dp[i][j][1]: to j
            tmp = dp[i][j - 1][0] + distance(i, j); // i -> j
            if (tmp < dp[i][j][1]) {
                dp[i][j][1] = tmp; from[i][j][1] = 0;
            }
            tmp = dp[i][j - 1][1] + distance(j - 1, j); // j-1 -> j
            if (tmp < dp[i][j][1]) {
                dp[i][j][1] = tmp; from[i][j][1] = 1;
            }
        }
    }
    int mini = 0, f = 0;
    double mindis = INF;
    for (int i = 1; i <= n; i++) {
        if (dp[i][i + n - 1][0] < mindis) {
            mindis = dp[i][i + n - 1][0]; mini = i; f = 0;
        }
        if (dp[i][i + n - 1][1] < mindis) {
            mindis = dp[i][i + n - 1][1]; mini = i; f = 1;
        }
    }
    int l = mini, r = mini + n - 1;
    for (int i = 1; i <= n; i++) {
        if (f == 0) {
            ans[i] = l;
            f = from[l][r][0]; l++;
        } else {
            ans[i] = r;
            f = from[l][r][1]; r--;
        }
    }
    for (int i = n; i >= 1; i--) 
        printf("%d%c", ans[i] > n ? ans[i] - n : ans[i], i == 1 ? '\n' : ' ');
    return 0;
}
posted @ 2023-11-10 17:05  RonChen  阅读(254)  评论(0编辑  收藏  举报