String Circle
String Circle
题目描述
给定 $k$ 个长度均为 $n$ 的字符串 $s_0, s_1, \ldots, s_{k−1}$。
请计算有多少种不同的长度为 $k$ 的字符串序列 $[t_0, t_1, \ldots, t_{k−1}]$,使得:
- $s_0 = t_0 + t_1 + \cdots + t_{k-2} + t_{k-1}$;
- $s_1 = t_1 + t_2 + \cdots + t_{k-1} + t_0$;
- $\cdots$
- $s_i = t_i + t_{i+1} + \cdots + t_{k-1} + t_0 + t_1 + \cdots + t_{k-1}$;
- $\cdots$
- $s_{k-1} = t_{k-1} + t_{0} + \cdots + t_{k-3} + t_{k-2}$;
对 $998244353$ 取模。注意 $[t_0, t_1, \ldots, t_{k−1}]$ 中可以存在空串。
其中,$s+t$ 表示字符串 $s$ 和字符串 $t$ 顺次拼接。
$2 \leq n, k, n \cdot k \leq 5 \times 10^6$。
输入描述:
第一行两个整数 $n,k$,分别表示字符串的长度和个数。
接下来 $k$ 行每行一个长度为 $n$,仅包含小写字母的字符串,依次分别为 $s_0, s_1, \ldots, s_{k−1}$。
输出描述:
一行一个整数表示答案。
示例1
输入
3 3
abc
bca
cab
输出
1
示例2
输入
3 3
aaa
aaa
aaa
输出
10
解题思路
太难力,一个 $O(nm^2)$ 的 dp 想了几天都不知道怎么优化,后面还是看群友代码知道怎么做的。
这里重新定义 $n$ 表示字符串个数,$m$ 表示字符串长度。
先给出一开始的做法,定义状态 $f(i,j)$ 表示匹配了前 $i$ 个字符串且 $t$ 循环左移了 $j$ 个字符的方案数,容易知道初始时 $t = s_1$。根据 $s_{i-1}$ 可以循环左移多少个字符得到 $s_{i}$ 进行状态转移。记集合 $P$ 表示 $s_{i-1}$ 通过循环左移得到 $s_{i}$ 的所有移动次数,即对于 $p_k \in P$ 有 $s_{i-1}[1, p_k] = s_{i}[m, m-p_k+1]$ 且 $s_{i-1}[p_k+1, m] = s_{i}[1, m-p_k]$。因此状态转移方程就是 $f(i,j) = \sum\limits_{p_k \leq j}{f(i-1, j - p_k)}$。如果直接暴力转移的话整个 dp 的时间复杂度是 $O(nm^2)$,反正我是不知道怎么优化了,如果您知道麻烦留言告诉笨蛋博主。
下面给出正解。可以知道对于任意的 $s_i$ 本质是通过 $s_1$ 循环左移得到的,因此每次状态转移可以通过与 $s_1$ 比较进行转移,而不是与 $s_{i-1}$ 比较。重新定义状态 $f(i,j)$ 表示匹配了前 $i$ 个字符串且 $s_i$ 是通过 $s_1$ 循环左移 $j$ 次得到的方案数。这里的 $j$ 必然满足 $s_{1}[1, j] = s_{i}[m, m-j+1]$ 且 $s_{1}[j+1, m] = s_{i}[1, m-j]$。根据 $s_{i-1}$ 通过 $s_1$ 循环左移 $k$ 次得到进行状态转移($k \leq j$,且如果 $s_1$ 无法循环左移 $k$ 次得到 $s_i$ 那么有 $f(i-1,k)=0$)。状态转移方程就是 $f(i,j) = \sum\limits_{k=0}^{j}{f(i-1,k)}$。通过累加前缀就可以实现 $O(1)$ 转移。
另外一个关键问题是如何快速求出 $s_1$ 通过循环左移得到 $s_i$ 的次数。方法有字符串哈希(不卡自然溢出),kmp,z 函数,具体实现可以看代码。
最后答案就是 $\sum\limits_{i=0}^{m}{f(n,i)}$。
自然溢出的字符串哈希做法 AC 代码如下,时间复杂度为 $O(nm)$:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
const int N = 2500005, P = 13331, mod = 998244353;
char s[N], t[N];
ULL hs[N], ht[N], p[N];
int f[N];
ULL query(ULL *h, int l, int r) {
if (l > r) return 0;
return h[r] - h[l - 1] * p[r - l + 1];
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n, m;
cin >> m >> n >> s;
memmove(s + 1, s, m + 1);
p[0] = 1;
for (int i = 1; i <= m; i++) {
p[i] = p[i - 1] * P;
hs[i] = hs[i - 1] * P + s[i];
}
f[0] = 1;
for (int i = 1; i < n; i++) {
cin >> t;
memmove(t + 1, t, m + 1);
for (int i = 1; i <= m; i++) {
ht[i] = ht[i - 1] * P + t[i];
}
for (int j = 0, s = 0; j <= m; j++) {
s = (s + f[j]) % mod;
if (query(hs, 1, j) == query(ht, m - j + 1, m) && query(hs, j + 1, m) == query(ht, 1, m - j)) f[j] = s;
else f[j] = 0;
}
}
int ret = 0;
for (int i = 0; i <= m; i++) {
ret = (ret + f[i]) % mod;
}
cout << ret;
return 0;
}
kmp 做法 AC 代码如下,时间复杂度为 $O(nm)$:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 2500005, mod = 998244353;
char s[N * 2], t[N];
int ne[N];
int f[N], g[N];
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n, m;
cin >> m >> n >> s;
memcpy(s + m, s, m);
f[0] = 1;
for (int i = 1; i < n; i++) {
cin >> t;
ne[0] = -1;
for (int i = 1, j = -1; i < m; i++) {
while (j != -1 && t[j + 1] != t[i]) {
j = ne[j];
}
if (t[j + 1] == t[i]) j++;
ne[i] = j;
}
for (int i = 0; i <= m; i++) {
if (i) g[i] = (g[i - 1] + f[i]) % mod;
else g[i] = f[i];
f[i] = 0;
}
for (int i = 0, j = -1; i < m << 1; i++) {
while (j != -1 && t[j + 1] != s[i]) {
j = ne[j];
}
if (t[j + 1] == s[i]) j++;
if (j == m - 1) {
f[i - m + 1] = g[i - m + 1];
j = ne[j];
}
}
}
int ret = 0;
for (int i = 0; i <= m; i++) {
ret = (ret + f[i]) % mod;
}
cout << ret;
return 0;
}
z 函数做法 AC 代码如下,时间复杂度为 $O(nm)$:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 2500005, mod = 998244353;
int z[N * 3];
int f[N];
void z_function(string s) {
int n = s.size();
for (int i = 1, j = 0; i < n; i++) {
if (i < j + z[j]) z[i] = min(j + z[j] - i, z[i - j]);
else z[i] = 0;
while (i + z[i] < n && s[i + z[i]] == s[z[i]]) {
z[i]++;
}
if (i + z[i] > j + z[j]) j = i;
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n, m;
string s;
cin >> m >> n >> s;
s += s;
f[0] = 1;
for (int i = 1; i < n; i++) {
string t;
cin >> t;
z_function(t + s);
for (int j = 0, s = 0; j <= m; j++) {
s = (s + f[j]) % mod;
if (z[m + j] >= m) f[j] = s;
else f[j] = 0;
}
}
int ret = 0;
for (int i = 0; i <= m; i++) {
ret = (ret + f[i]) % mod;
}
cout << ret;
return 0;
}
参考资料
Accepted极限代码巅峰赛 lgkm39 提交的代码:https://ac.nowcoder.com/acm/contest/view-submission?submissionId=72962880
本文来自博客园,作者:onlyblues,转载请注明原文链接:https://www.cnblogs.com/onlyblues/p/18535298