[动态规划] Codeforces 1499E Chaotic Merge
题目大意
给定两个仅由小写字母组成的字符串 \(x\) 和 \(y\)。
如果一个序列仅包含 \(|x|\) 个 \(0\) 和 \(|y|\) 个 \(1\),则称这个序列为合并序列。
字符串 \(z\) 初始为空,按如下规则由合并序列 \(a\) 生成:
- 如果 \(a_i=0\),则把 \(x\) 开头的一个字符加到 \(z\) 的末尾;
- 如果 \(a_i=1\),则把 \(y\) 开头的一个字符加到 \(z\) 的末尾。
两个合并序列 \(a\) 和 \(b\) 被认为是不同的,如果存在某个 \(i\),使得 \(a_i\neq b_i\)。
若一个字符串任意两个相邻位置上的字符都不同,则我们称该字符串是混乱的。
定义 \(f(l_1,r_1,l_2,r_2)\) 表示能从 \(x\) 的子串 \(x[l_1,r_1]\) 和 \(y\) 的子串 \(y[l_2,r_2]\) 生成混乱的字符串的不同的合并序列的数量,要求子串非空。
求 \(\sum \limits_{1 \le l_1 \le r_1 \le |x| , 1 \le l_2 \le r_2 \le |y|} f(l_1, r_1, l_2, r_2)\),答案对 \(998244353\) 取模。
\((1\leq |x|,|y|\leq 1000)\)
题解
直接枚举子串然后算有多少个不同的合并序列肯定不好做,所以我们考虑dp。
设 \(dp[i][j][0]\) 表示以第 \(i\) 个位置结尾的 \(x\) 的子串和以第 \(j\) 个位置结尾的 \(y\) 的子串进行合并,并且合并后的混乱的字符串以 \(x\) 的第 \(i\) 个位置结尾,不同的合并序列的数量。
设 \(dp[i][j][1]\) 表示以第 \(i\) 个位置结尾的 \(x\) 的子串和以第 \(j\) 个位置结尾的 \(y\) 的子串进行合并,并且合并后的混乱的字符串以 \(y\) 的第 \(j\) 个位置结尾,不同的合并序列的数量。
那么我们只需枚举倒数第二个位置上是什么,同时要满足它和最后一个位置上的字符不同。
若 \(x[i-1]\neq x[i]\),则 \(dp[i][j][0]+=dp[i-1][j][0]\)。
若 \(y[j]\neq x[i]\),则 \(dp[i][j][0]+=dp[i-1][j][1]\)。
若 \(x[i]\neq y[j]\),则 \(dp[i][j][1]+=dp[i][j-1][0]\)。
若 \(y[j-1]\neq y[j]\),则 \(dp[i][j][1]+=dp[i][j-1][1]\)。
注意到以上转移必须满足上一个状态中 \(x\) 和 \(y\) 的两个子串都非空。但我们可以只取一个字符作为一个子串合并到最后,所以上一个状态的 \(x\) 或 \(y\) 是可以为空的,但是空串我们又不计入答案。所以我们维护 \(dpx[i]\) 表示 \(x\) 中有多少个以第 \(i\) 个位置结尾的混乱的子串,\(dpy[j]\) 表示 \(y\) 中有多少个以第 \(j\) 个位置结尾的混乱的子串,则有:
若 \(x[i]\neq x[i-1]\),则 \(dpx[i]=dpx[i-1]+1\),否则 \(dpx[i]=1\)。
若 \(y[j]\neq y[j-1]\),则 \(dpy[j]=dpy[j-1]+1\),否则 \(dpy[j]=1\)。
若 \(x[i]\neq y[j]\), \(dp[i][j][0]+=dpy[j]\)。
若 \(x[i]\neq y[j]\), \(dp[i][j][1]+=dpx[i]\)。
时间复杂度 \(O(|x||y|)\)。
Code
#include <bits/stdc++.h>
using namespace std;
#define RG register int
#define LL long long
const LL MOD = 998244353;
char x[1005], y[1005];
LL dp[1001][1001][2], dpx[1001], dpy[1001];
int n, m;
int main() {
scanf("%s", x + 1);
scanf("%s", y + 1);
n = strlen(x + 1);
m = strlen(y + 1);
LL ans = 0;
for (int i = 1;i <= m;++i) {
dpy[i] = 1;
if (y[i - 1] != y[i]) dpy[i] = (dpy[i] + dpy[i - 1]) % MOD;
}
for (int i = 1;i <= n;++i) {
dpx[i] = 1;
if (x[i - 1] != x[i]) dpx[i] = (dpx[i] + dpx[i - 1]) % MOD;
for (int j = 1;j <= m;++j) {
if (x[i] != y[j]) dp[i][j][0] = (dp[i][j][0] + dpy[j]) % MOD;
if (x[i] != y[j]) dp[i][j][1] = (dp[i][j][1] + dpx[i]) % MOD;
if (x[i - 1] != x[i]) dp[i][j][0] = (dp[i][j][0] + dp[i - 1][j][0]) % MOD;
if (y[j] != x[i]) dp[i][j][0] = (dp[i][j][0] + dp[i - 1][j][1]) % MOD;
if (x[i] != y[j]) dp[i][j][1] = (dp[i][j][1] + dp[i][j - 1][0]) % MOD;
if (y[j - 1] != y[j]) dp[i][j][1] = (dp[i][j][1] + dp[i][j - 1][1]) % MOD;
ans = (ans + dp[i][j][0] + dp[i][j][1]) % MOD;
}
}
printf("%I64d\n", ans);
return 0;
}