线性 DP

最长上升子序列问题是一个经典的线性动态规划问题。

例题:B3637 最长上升子序列

分析:设原始数组为 \(a\),定义状态 \(dp_i\) 表示以 \(a_i\) 结尾的上升子序列的最大长度。注意这个状态定义中有两个重点,第一个重点是 \(dp_i\) 只维护所有原始序列中以 \(a_i\) 结尾的上升子序列的信息。这样可以发现,对于每个上升子序列,都会唯一被归类到 \(dp\) 的某个状态中。第二个重点是对于所有以 \(a_i\) 结尾的上升子序列,只记录长度最长的那个子序列的长度。这是因为最优子结构性质,如果以 \(a_i\) 结尾有很多上升子序列,肯定是保留最长的那个更划算,因为它后面接数字之后能得到更长的上升子序列。而且这种方式能够满足无后效性,因为如果在所有以 \(a_i\) 结尾的上升子序列后面再接数字,能接哪个数字完全取决于 \(a_i\),跟 \(a_i\) 前面的数无关。所以这种状态定义方式同时满足无后效性和最优子结构。

考虑如何进行状态转移,也就是寻找一个递推关系,用之前计算过的某些 \(dp\) 值来计算 \(dp_i\)。考虑 \(dp_i\) 这个状态要以 \(a_i\) 结尾,只需要关心它能接到前面哪些子序列的后面。一种情况是,自成一段,则长度为 \(1\),那么 \(dp_i = 1\);另一种情况是,对于所有 \(i\) 前面的位置 \(j\),且满足 \(a_j < a_i\) 的,\(dp_i = dp_j + 1\),即在以 \(a_j\) 结尾的最长上升子序列的基础上,再增加一个自己带来的长度 \(1\)。为了使得 \(dp_i\) 的值最大,显然应该对于所有 \(j\),取 \(dp_j + 1\) 的最大值。即 \(dp_i = \max (dp_j + 1)\),其中要满足 \(j < i\) 并且 \(a_j < a_i\)

最终的答案就是所有 \(dp_i\) 中的最大值,因为不能确定整个序列的最长上升子序列是以哪个数结尾的,所以每个数作为结尾都要考虑一遍。本算法的时间复杂度为 \(O(n^2)\):因为要枚举以第 \(i\) 个数结尾的情况去计算 \(dp_i\),因此需要枚举 \(n\) 次;而在计算每个 \(dp_i\) 时,又需要把 \(i\) 前面的每个位置 \(j\) 枚举一遍。

参考代码
#include <cstdio>
#include <algorithm>
using std::max;
const int N = 5005;
int a[N], dp[N];
int main()
{
    int n; scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    int ans = 0;
    for (int i = 1; i <= n; i++) {
        dp[i] = 1;
        for (int j = 1; j < n; j++) {
            if (a[j] <  a[i]) dp[i] = max(dp[i], dp[j] + 1);
        }
        ans = max(ans, dp[i]);
    }
    printf("%d\n", ans);
    return 0;
}

还有一个时间复杂度更低的做法。用 \(dp_i\) 表示长度为 \(i\) 的上升子序列中最小的结尾。注意,这个 \(dp_i\) 的定义与前一种方式不同。如果有多个长度为 \(i\) 的上升子序列,记录所有这样的子序列中结尾最小的那个。这满足最优子结构,因为拥有最小结尾的上升子序列,更有可能被后面的数接上,形成更长的上升子序列。

在一开始,只考虑 \(a_1\),这时候有唯一的长度为 \(1\) 的上升子序列,它的结尾是 \(a_1\)

假设数组 \(a\) 等于 \([1, 7, 3, 5, 9, 4, 8]\)。接下来,一个数一个数考虑,把数组 \(a\) 中每个数字考虑进来,分析 \(dp\) 数组的变化。下一个数是 \(a_2 = 7\),它可以接在前面的 \(1\) 的后面,形成长度为 \(2\) 的上升子序列,结尾是 \(7\)。因为之前没有过长度为 \(2\) 的上升子序列,所以直接在 \(dp_2\) 位置写入 \(7\)

下一个数是 \(a_3 = 3\),目前长度为 \(1\) 的子序列是以 \(1\) 结尾的,长度为 \(2\) 的子序列最小结尾是 \(7\),那么新来的这个 \(3\) 肯定不能接在 \(7\) 后面,只能接在 \(1\) 后面,得到一个长度为 \(2\) 的上升子序列,结尾是 \(3\),比之前的 \(dp_2 = 7\) 要小,所以修改 \(dp_2 = 3\)

下一个数是 \(a_4 = 5\),它可以接在长度为 \(2\) 结尾为 \(3\) 的子序列后面,得到长度为 \(3\),结尾为 \(5\) 的上升子序列。

下一个数是 \(a_5 = 9\),它可以接在长度为 \(3\) 结尾为 \(5\) 的子序列后面,得到长度为 \(4\),结尾为 \(9\) 的上升子序列。

到目前为止,大概可以总结出一个算法。一个接一个地考虑数组 \(a\) 中的每个数,对于当前的 \(a_i\),首先看它是否比 \(dp\) 中目前最后一个有效元素大,如果是,那么就可以接在最后面,相当于得到了一个更长的子序列,以 \(a_i\) 结尾;如果 \(a_i\) 不比 \(dp\) 最后一个有效元素大,那么就在 \(dp\) 中,从右往左找到最靠右边的、比 \(a_i\) 小的数,接到它的后面。相当于把 \(dp\) 中最靠左的第一个大于或等于 \(a_i\) 的数修改为 \(a_i\)

例如,下一个考虑的数是 \(a_6 = 4\),就会将 \(dp_3\) 替换成 \(4\)

同理,对于 \(a_7 = 8\),它会替换 \(dp_4\)

image

最终,最长上升子序列的长度是 \(4\),并且最小以 \(8\) 结尾。

分析一下这个做法的时间复杂度,对于每个 \(a_i\),要么接在 \(dp\) 的末尾,要么遍历数组 \(dp\) 寻找最靠左的大于或等于 \(a_i\) 的数进行替换,最坏情况下时间复杂度是 \(O(n)\),总的时间复杂度是 \(O(n^2)\),看起来并没有变优。

实际上,可以发现 \(dp\) 是单调的,所以“遍历 \(dp\) 寻找最靠左的大于或等于 \(a_i\) 的数进行替换”这一操作,是不需要完整遍历的,可以在有序数组上进行二分查找,每次查找的时间复杂度变为 \(O(\log n)\),总的时间复杂度为 \(O(n \log n)\)

参考代码
#include <cstdio>
#include <algorithm>
using std::max;
using std::lower_bound;
const int N = 5005;
int a[N], dp[N];
int main()
{
    int n; scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    int ans = 0; // 记录最长上升子序列的长度
    for (int i = 1; i <= n; i++) {
        // 在dp[1]~dp[ans]间进行二分查找
        int idx = lower_bound(dp + 1, dp + ans + 1, a[i]) - dp; 
        if (idx > ans) ans++; // 可以接在dp数组最后一个有效元素后面,长度加1
        dp[idx] = a[i]; // 将二分出的位置替换为a[i]
    }
    printf("%d\n", ans);
    return 0;
}

例题:P1020 [NOIP1999 提高组] 导弹拦截

分析:先考虑第 \(1\) 问,只有 \(1\) 套系统的话,最多可以拦截多少导弹。题目要求“每一发炮弹都不能高于前一发的高度”,其实就是找一个最长的子序列,满足子序列中后一个元素不能比前一个大,只能比前一个小或相等,可以称为最长不上升子序列。

题目第 \(2\) 问是需要多少套系统可以拦截所有的导弹,其实是问最少使用多少个不上升子序列可以覆盖整个区间。针对这类问题,有一个 Dilworth 定理。要求这样的子序列最少多少个,等价于求原序列的最长上升子序列的长度

参考代码
#include <cstdio>
#include <algorithm>
using std::lower_bound;
using std::upper_bound;
const int N = 100005;
int a[N], dp[N];
int main()
{
    int n = 0, x;
    while (scanf("%d", &x) != -1) {
        a[++n] = x;
    }
    // 第1问
    // 求最长不上升子序列的长度,相当于倒过来求最长不下降子序列的长度
    int ans = 0;
    for (int i = n; i >= 1; i--) {
        // 注意:最长上升子序列是lower_bound,最长不下降子序列是upper_bound 
        int idx = upper_bound(dp + 1, dp + ans + 1, a[i]) - dp;
        if (idx > ans) ans++;
        dp[idx] = a[i];
    }
    printf("%d\n", ans);
    // 第2问
    // 等价于求最长上升子序列的长度
    ans = 0;
    for (int i = 1; i <= n; i++) {
        int idx = lower_bound(dp + 1, dp + ans + 1, a[i]) - dp;
        if (idx > ans) ans++;
        dp[idx] = a[i];
    }
    printf("%d\n", ans);
    return 0;
}

例题:最长公共子序列

给出两个字符串,求最长的这样的子序列,要求满足子序列的每个字符都能在两个原字符串中找到,而且每个字符的先后顺序和原字符串中的先后顺序一致。
例如,两个字符串分别是 abcfbcabfcab,它们的最长公共子序列长度是 \(4\),如 abfc

设两个字符串分别为 \(s1\)\(s2\),长度分别为 \(len1\)\(len2\)。定义二维状态 \(dp_{i,j}\) 表示 \(s1\) 的前 \(i\) 个字符串形成的子串与 \(s2\) 的前 \(j\) 个字符形成的子串的最长公共子序列的长度。

这个状态定义,还是遵循最优子结构的思想。要解决的是两个比较长的字符串之间的问题,对两个字符串各自截取前若干个字符形成的子串,看看子串里面的答案能否计算出来。如果能,把子串延长一些,看看能否转移,最终计算出的 \(dp_{len1,len2}\) 就是想求的结果。

状态转移方程:\(dp_{i,j} = \begin{cases} dp_{i-1,j-1} + 1, & s1_i = s2_j \\ \max (dp_{i,j-1}, dp_{i-1,j}), & s1_i \ne s2_j \end{cases}\)

考虑两个子串的最后一位 \(s1_i\)\(s2_j\),如果它们相等,那么就可以对答案贡献 \(1\) 的长度。\(s1\) 的前 \(i-1\) 个字符与 \(s2\) 的前 \(j-1\) 个字符能形成的最长公共子序列的长度,再接上新贡献的 \(1\),也就是 \(dp_{i-1,j-1} + 1\)

若两个子串的最后一位 \(s1_i\)\(s2_j\) 不想等,既然它们不能配对为答案做出贡献,不如丢弃其中的某一个。如丢弃 \(s2\) 的第 \(j\) 个字符,看 \(s1\) 的前 \(i\) 个字符与 \(s2\) 的前 \(j-1\) 个字符能够形成的答案是多少,再考虑 \(s1\) 的前 \(i-1\) 位和 \(s2\) 的前 \(j\) 位形成的答案是多少,比较这两个里面哪个更大,那么就构成当前的结果,也就是 \(\max (dp_{i,j-1}, dp_{i-1,j})\)

考虑边界情况,容易发现 \(i=0\)\(j=0\) 时是初始状态,显然这些结果都是 \(0\),因为此时至少有其中一个是空串,无法形成公共子序列。

总的时间复杂度是 \(O(n^2)\)

例题:AT_dp_f LCS

分析:本题需要在求最长公共子序列时把这个序列找出来。一个直观的想法是:除了记录每个状态的最长公共子序列的长度,再配一个相应的数组记录每个状态对应的字符串。状态转移时,除了转移长度,也转移相应的字符串。由于涉及到大量的字符串复制,这个做法比较慢,并且要占用很大的空间。

另一个思路是,记录每个状态是转移自前面的哪个状态的,也就是记录每个状态的父亲状态。在状态转移方程中,可以看到,对于 \(dp_{i,j}\),它的值是从 \(dp_{i-1,j-1}, dp_{i-1,j}, dp_{i,j-1}\) 三个中的某一个转移过来的。所以对于每个状态,可以区分这三种转移来源。最后的结果是看 \(dp_{len1, len2}\),则根据该状态是三种转移中的哪一种倒推回去,直到边界条件。在这个过程中,每当发现某个 \(dp_{i,j}\) 的来源是 \(dp_{i-1,j-1}\) 时就说明最长公共子序列中包含 \(s_i\)\(t_j\) 这个字符(因为此时两者相等,取哪个都一样),把这个过程中涉及到的字符连起来倒序输出即为答案(因为第一个连接到的字符实际上是整个最长公共子序列中的最后一个)。

参考代码
#include <cstdio>
#include <cstring>
const int N = 3005;
char s[N], t[N], ans[N];
int dp[N][N], from[N][N];
int main()
{
    scanf("%s%s", s + 1, t + 1);
    int lens = strlen(s + 1), lent = strlen(t + 1);
    for (int i = 1; i <= lens; i++) {
        for (int j = 1; j <= lent; j++) {
            if (s[i] == t[j]) {
                dp[i][j] = dp[i - 1][j - 1] + 1;
                from[i][j] = 0; 
            } else {
                if (dp[i - 1][j] > dp[i][j - 1]) {
                    dp[i][j] = dp[i - 1][j];
                    from[i][j] = 1;
                } else {
                    dp[i][j] = dp[i][j - 1];
                    from[i][j] = 2;
                }
            }
        }
    }
    int x = lens, y = lent;
    int n = 0;
    while (x > 0 && y > 0) {
        if (from[x][y] == 0) { // 转移来源标记等于0表示是一次公共字符
            ans[++n] = s[x];
            x--; y--;
        } else if (from[x][y] == 1) {
            x--;
        } else {
            y--;
        }
    }
    for (int i = n; i >= 1; i--) printf("%c", ans[i]);
    return 0;
}

习题:P9753 [CSP-S 2023] 消消乐

解题思路(35 分)

对于一个固定的字符串,怎么判断它“可消除”?

可以采用类似括号匹配的方法:维护一个栈,按顺序遍历字符串,若当前字符等于栈顶,则将栈顶弹出,否则将当前字符入栈。如果最终栈为空则说明整个串是“可消除的”。

因此最直接的做法就是枚举所有的子串,对每个子串用一个栈来模拟这个过程,验证是否“可消除”。

时间复杂度为 \(O(n^3)\),期望得分 \(35\) 分。

参考代码
#include <cstdio>
#include <stack>
using std::stack;
using ll = long long;
const int N = 2000005;
char s[N];
int main()
{
    int n; scanf("%d", &n);
    scanf("%s", s + 1);
    ll ans = 0;
    for (int i = 1; i <= n; i++) {
        for (int j = i; j <= n; j++) {
            // 子串i~j
            stack<char> stk;
            for (int k = i; k <= j; k++) {
                if (!stk.empty() && stk.top() == s[k]) stk.pop();
                else stk.push(s[k]);
            }
            if (stk.empty()) ans++;
        }
    }
    printf("%lld\n", ans);
    return 0;
}
解题思路(50 分)

在前面那个做法中可以发现,考虑对于子串 \([i,j]\)\([i,j+1]\) 的验证过程,除了第 \(j+1\) 个字符以外,其余字符处理的逻辑是一样的,所以不需要对每个子串都重新开始维护一个栈。当枚举某个左端点时,维护一个栈,遍历这个左端点右侧的每个字符,每当处理完当前字符后看栈是否为空即可判断是否“可消除”。

时间复杂度为 \(O(n^2)\),期望得分 \(50\) 分。

参考代码
#include <cstdio>
#include <stack>
using std::stack;
using ll = long long;
const int N = 2000005;
char s[N];
int main()
{
    int n; scanf("%d", &n);
    scanf("%s", s + 1);
    ll ans = 0;
    for (int i = 1; i <= n; i++) {
        stack<char> stk; 
        for (int j = i; j <= n; j++) {
            // 子串i~j
            if (!stk.empty() && stk.top() == s[j]) stk.pop();
            else stk.push(s[j]);
            
            if (stk.empty()) ans++;
        }
    }
    printf("%lld\n", ans);
    return 0;
}
解题思路

分析数据范围,站在常见的线性 DP 问题视角思考这个问题。

\(dp_i\) 表示以第 \(i\) 个字符结尾的“可消除”子串数量。

那么对于每个 \(dp_i\),从哪个位置转移过来呢?考虑 \(i\) 左侧的某个位置 \(j\),如果可以转移过来,说明 \([j+1, i]\) 是一个“可消除的”子串。对于计数问题,要保证不重不漏,则 \(j\) 需要是最后一个可以满足 \([j+1,i]\) “可消除”的子串。

\(last_i\) 表示以第 \(i\) 个字符结尾的最短“可消除”字符串,则 \(dp_i = dp_{last_i - 1} + 1\)

考虑如何计算 \(last_i\),首先,因为是要最短的“可消除”字符串,那么必然有 \(s_{last_i} = s_i\),也就是说 \([last_i + 1, i - 1]\) 是一个“可消除”字符串。因此可以持续迭代 \(last_i \leftarrow last_{last_i} - 1\),其中初始值是 \(i-1\),直到 \(s_{last_i} = s_i\) 或跳出字符串范围(即说明以 \(s_i\) 结尾不可消除)。

image

这个做法的时间复杂度是 \(An\),其中 \(A\) 是字符集大小,在本题中为 \(26\)

如何证明这个时间复杂度?可以参考 暴力跳做法的复杂度证明,可以证明每一个位置最多被后面 \(A\) 个位置跳过来。

参考代码
#include <cstdio>
#include <stack>
using std::stack;
using ll = long long;
const int N = 2000005;
char s[N];
int dp[N], last[N];
int main()
{
    int n; scanf("%d%s", &n, s + 1);
    ll ans = 0;
    for (int i = 1; i <= n; i++) {
        int j = i - 1;
        while (j > 0 && s[j] != s[i]) {
            j = last[j] - 1;
        }
        if (j > 0) {
            dp[i] = dp[j - 1] + 1;
            last[i] = j;
        }
        ans += dp[i];
    }
    printf("%lld\n", ans);
    return 0;
}
posted @ 2024-10-23 21:51  RonChen  阅读(37)  评论(0编辑  收藏  举报