String Circle

String Circle

题目描述

给定 $k$ 个长度均为 $n$ 的字符串 $s_0, s_1, \ldots, s_{k−1}$。

请计算有多少种不同的长度为 $k$ 的字符串序列 $[t_0, t_1, \ldots, t_{k−1}]$,使得:

  • $s_0 = t_0 + t_1 + \cdots + t_{k-2} + t_{k-1}$;
  • $s_1 = t_1 + t_2 + \cdots + t_{k-1} + t_0$;
  • $\cdots$
  • $s_i = t_i + t_{i+1} + \cdots + t_{k-1} + t_0 + t_1 + \cdots + t_{k-1}$;
  • $\cdots$
  • $s_{k-1} = t_{k-1} + t_{0} + \cdots + t_{k-3} + t_{k-2}$;

对 $998244353$ 取模。注意 $[t_0, t_1, \ldots, t_{k−1}]$ 中可以存在空串。

其中,$s+t$ 表示字符串 $s$ 和字符串 $t$ 顺次拼接。

$2 \leq n, k, n \cdot k \leq 5 \times 10^6$。

输入描述:

第一行两个整数 $n,k$,分别表示字符串的长度和个数。

接下来 $k$ 行每行一个长度为 $n$,仅包含小写字母的字符串,依次分别为 $s_0, s_1, \ldots, s_{k−1}$。

输出描述:

一行一个整数表示答案。

示例1

输入

3 3
abc
bca
cab

输出

1

示例2

输入

3 3
aaa
aaa
aaa

输出

10

 

解题思路

  太难力,一个 $O(nm^2)$ 的 dp 想了几天都不知道怎么优化,后面还是看群友代码知道怎么做的。

  这里重新定义 $n$ 表示字符串个数,$m$ 表示字符串长度。

  先给出一开始的做法,定义状态 $f(i,j)$ 表示匹配了前 $i$ 个字符串且 $t$ 循环左移了 $j$ 个字符的方案数,容易知道初始时 $t = s_1$。根据 $s_{i-1}$ 可以循环左移多少个字符得到 $s_{i}$ 进行状态转移。记集合 $P$ 表示 $s_{i-1}$ 通过循环左移得到 $s_{i}$ 的所有移动次数,即对于 $p_k \in P$ 有 $s_{i-1}[1, p_k] = s_{i}[m, m-p_k+1]$ 且 $s_{i-1}[p_k+1, m] = s_{i}[1, m-p_k]$。因此状态转移方程就是 $f(i,j) = \sum\limits_{p_k \leq j}{f(i-1, j - p_k)}$。如果直接暴力转移的话整个 dp 的时间复杂度是 $O(nm^2)$,反正我是不知道怎么优化了,如果您知道麻烦留言告诉笨蛋博主。

  下面给出正解。可以知道对于任意的 $s_i$ 本质是通过 $s_1$ 循环左移得到的,因此每次状态转移可以通过与 $s_1$ 比较进行转移,而不是与 $s_{i-1}$ 比较。重新定义状态 $f(i,j)$ 表示匹配了前 $i$ 个字符串且 $s_i$ 是通过 $s_1$ 循环左移 $j$ 次得到的方案数。这里的 $j$ 必然满足 $s_{1}[1, j] = s_{i}[m, m-j+1]$ 且 $s_{1}[j+1, m] = s_{i}[1, m-j]$。根据 $s_{i-1}$ 通过 $s_1$ 循环左移 $k$ 次得到进行状态转移($k \leq j$,且如果 $s_1$ 无法循环左移 $k$ 次得到 $s_i$ 那么有 $f(i-1,k)=0$)。状态转移方程就是 $f(i,j) = \sum\limits_{k=0}^{j}{f(i-1,k)}$。通过累加前缀就可以实现 $O(1)$ 转移。

  另外一个关键问题是如何快速求出 $s_1$ 通过循环左移得到 $s_i$ 的次数。方法有字符串哈希(不卡自然溢出),kmp,z 函数,具体实现可以看代码。

  最后答案就是 $\sum\limits_{i=0}^{m}{f(n,i)}$。

  自然溢出的字符串哈希做法 AC 代码如下,时间复杂度为 $O(nm)$:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;
typedef unsigned long long ULL;

const int N = 2500005, P = 13331, mod = 998244353;

char s[N], t[N];
ULL hs[N], ht[N], p[N];
int f[N];

ULL query(ULL *h, int l, int r) {
    if (l > r) return 0;
    return h[r] - h[l - 1] * p[r - l + 1];
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m;
    cin >> m >> n >> s;
    memmove(s + 1, s, m + 1);
    p[0] = 1;
    for (int i = 1; i <= m; i++) {
        p[i] = p[i - 1] * P;
        hs[i] = hs[i - 1] * P + s[i];
    }
    f[0] = 1;
    for (int i = 1; i < n; i++) {
        cin >> t;
        memmove(t + 1, t, m + 1);
        for (int i = 1; i <= m; i++) {
            ht[i] = ht[i - 1] * P + t[i];
        }
        for (int j = 0, s = 0; j <= m; j++) {
            s = (s + f[j]) % mod;
            if (query(hs, 1, j) == query(ht, m - j + 1, m) && query(hs, j + 1, m) == query(ht, 1, m - j)) f[j] = s;
            else f[j] = 0;
        }
    }
    int ret = 0;
    for (int i = 0; i <= m; i++) {
        ret = (ret + f[i]) % mod;
    }
    cout << ret;
    
    return 0;
}

  kmp 做法 AC 代码如下,时间复杂度为 $O(nm)$:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int N = 2500005, mod = 998244353;

char s[N * 2], t[N];
int ne[N];
int f[N], g[N];

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m;
    cin >> m >> n >> s;
    memcpy(s + m, s, m);
    f[0] = 1;
    for (int i = 1; i < n; i++) {
        cin >> t;
        ne[0] = -1;
        for (int i = 1, j = -1; i < m; i++) {
            while (j != -1 && t[j + 1] != t[i]) {
                j = ne[j];
            }
            if (t[j + 1] == t[i]) j++;
            ne[i] = j;
        }
        for (int i = 0; i <= m; i++) {
            if (i) g[i] = (g[i - 1] + f[i]) % mod;
            else g[i] = f[i];
            f[i] = 0;
        }
        for (int i = 0, j = -1; i < m << 1; i++) {
            while (j != -1 && t[j + 1] != s[i]) {
                j = ne[j];
            }
            if (t[j + 1] == s[i]) j++;
            if (j == m - 1) {
                f[i - m + 1] = g[i - m + 1];
                j = ne[j];
            }
        }
    }
    int ret = 0;
    for (int i = 0; i <= m; i++) {
        ret = (ret + f[i]) % mod;
    }
    cout << ret;
    
    return 0;
}

  z 函数做法 AC 代码如下,时间复杂度为 $O(nm)$:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int N = 2500005, mod = 998244353;

int z[N * 3];
int f[N];

void z_function(string s) {
    int n = s.size();
    for (int i = 1, j = 0; i < n; i++) {
        if (i < j + z[j]) z[i] = min(j + z[j] - i, z[i - j]);
        else z[i] = 0;
        while (i + z[i] < n && s[i + z[i]] == s[z[i]]) {
            z[i]++;
        }
        if (i + z[i] > j + z[j]) j = i;
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m;
    string s;
    cin >> m >> n >> s;
    s += s;
    f[0] = 1;
    for (int i = 1; i < n; i++) {
        string t;
        cin >> t;
        z_function(t + s);
        for (int j = 0, s = 0; j <= m; j++) {
            s = (s + f[j]) % mod;
            if (z[m + j] >= m) f[j] = s;
            else f[j] = 0;
        }
    }
    int ret = 0;
    for (int i = 0; i <= m; i++) {
        ret = (ret + f[i]) % mod;
    }
    cout << ret;
    
    return 0;
}

 

参考资料

  Accepted极限代码巅峰赛 lgkm39 提交的代码:https://ac.nowcoder.com/acm/contest/view-submission?submissionId=72962880

posted @ 2024-11-08 17:16  onlyblues  阅读(2)  评论(0编辑  收藏  举报
Web Analytics