String Circle

String Circle

题目描述

给定 k 个长度均为 n 的字符串 s0,s1,,sk1

请计算有多少种不同的长度为 k 的字符串序列 [t0,t1,,tk1],使得:

  • s0=t0+t1++tk2+tk1
  • s1=t1+t2++tk1+t0
  • si=ti+ti+1++tk1+t0+t1++tk1
  • sk1=tk1+t0++tk3+tk2

998244353 取模。注意 [t0,t1,,tk1] 中可以存在空串。

其中,s+t 表示字符串 s 和字符串 t 顺次拼接。

2n,k,nk5×106

输入描述:

第一行两个整数 n,k,分别表示字符串的长度和个数。

接下来 k 行每行一个长度为 n,仅包含小写字母的字符串,依次分别为 s0,s1,,sk1

输出描述:

一行一个整数表示答案。

示例1

输入

3 3
abc
bca
cab

输出

1

示例2

输入

3 3
aaa
aaa
aaa

输出

10

 

解题思路

  太难力,一个 O(nm2) 的 dp 想了几天都不知道怎么优化,后面还是看群友代码知道怎么做的。

  这里重新定义 n 表示字符串个数,m 表示字符串长度。

  先给出一开始的做法,定义状态 f(i,j) 表示匹配了前 i 个字符串且 t 循环左移了 j 个字符的方案数,容易知道初始时 t=s1。根据 si1 可以循环左移多少个字符得到 si 进行状态转移。记集合 P 表示 si1 通过循环左移得到 si 的所有移动次数,即对于 pkPsi1[1,pk]=si[m,mpk+1]si1[pk+1,m]=si[1,mpk]。因此状态转移方程就是 f(i,j)=pkjf(i1,jpk)。如果直接暴力转移的话整个 dp 的时间复杂度是 O(nm2),反正我是不知道怎么优化了,如果您知道麻烦留言告诉笨蛋博主。

  下面给出正解。可以知道对于任意的 si 本质是通过 s1 循环左移得到的,因此每次状态转移可以通过与 s1 比较进行转移,而不是与 si1 比较。重新定义状态 f(i,j) 表示匹配了前 i 个字符串且 si 是通过 s1 循环左移 j 次得到的方案数。这里的 j 必然满足 s1[1,j]=si[m,mj+1]s1[j+1,m]=si[1,mj]。根据 si1 通过 s1 循环左移 k 次得到进行状态转移(kj,且如果 s1 无法循环左移 k 次得到 si 那么有 f(i1,k)=0)。状态转移方程就是 f(i,j)=k=0jf(i1,k)。通过累加前缀就可以实现 O(1) 转移。

  另外一个关键问题是如何快速求出 s1 通过循环左移得到 si 的次数。方法有字符串哈希(不卡自然溢出),kmp,z 函数,具体实现可以看代码。

  最后答案就是 i=0mf(n,i)

  自然溢出的字符串哈希做法 AC 代码如下,时间复杂度为 O(nm)

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;
typedef unsigned long long ULL;

const int N = 2500005, P = 13331, mod = 998244353;

char s[N], t[N];
ULL hs[N], ht[N], p[N];
int f[N];

ULL query(ULL *h, int l, int r) {
    if (l > r) return 0;
    return h[r] - h[l - 1] * p[r - l + 1];
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m;
    cin >> m >> n >> s;
    memmove(s + 1, s, m + 1);
    p[0] = 1;
    for (int i = 1; i <= m; i++) {
        p[i] = p[i - 1] * P;
        hs[i] = hs[i - 1] * P + s[i];
    }
    f[0] = 1;
    for (int i = 1; i < n; i++) {
        cin >> t;
        memmove(t + 1, t, m + 1);
        for (int i = 1; i <= m; i++) {
            ht[i] = ht[i - 1] * P + t[i];
        }
        for (int j = 0, s = 0; j <= m; j++) {
            s = (s + f[j]) % mod;
            if (query(hs, 1, j) == query(ht, m - j + 1, m) && query(hs, j + 1, m) == query(ht, 1, m - j)) f[j] = s;
            else f[j] = 0;
        }
    }
    int ret = 0;
    for (int i = 0; i <= m; i++) {
        ret = (ret + f[i]) % mod;
    }
    cout << ret;
    
    return 0;
}

  kmp 做法 AC 代码如下,时间复杂度为 O(nm)

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int N = 2500005, mod = 998244353;

char s[N * 2], t[N];
int ne[N];
int f[N], g[N];

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m;
    cin >> m >> n >> s;
    memcpy(s + m, s, m);
    f[0] = 1;
    for (int i = 1; i < n; i++) {
        cin >> t;
        ne[0] = -1;
        for (int i = 1, j = -1; i < m; i++) {
            while (j != -1 && t[j + 1] != t[i]) {
                j = ne[j];
            }
            if (t[j + 1] == t[i]) j++;
            ne[i] = j;
        }
        for (int i = 0; i <= m; i++) {
            if (i) g[i] = (g[i - 1] + f[i]) % mod;
            else g[i] = f[i];
            f[i] = 0;
        }
        for (int i = 0, j = -1; i < m << 1; i++) {
            while (j != -1 && t[j + 1] != s[i]) {
                j = ne[j];
            }
            if (t[j + 1] == s[i]) j++;
            if (j == m - 1) {
                f[i - m + 1] = g[i - m + 1];
                j = ne[j];
            }
        }
    }
    int ret = 0;
    for (int i = 0; i <= m; i++) {
        ret = (ret + f[i]) % mod;
    }
    cout << ret;
    
    return 0;
}

  z 函数做法 AC 代码如下,时间复杂度为 O(nm)

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int N = 2500005, mod = 998244353;

int z[N * 3];
int f[N];

void z_function(string s) {
    int n = s.size();
    for (int i = 1, j = 0; i < n; i++) {
        if (i < j + z[j]) z[i] = min(j + z[j] - i, z[i - j]);
        else z[i] = 0;
        while (i + z[i] < n && s[i + z[i]] == s[z[i]]) {
            z[i]++;
        }
        if (i + z[i] > j + z[j]) j = i;
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m;
    string s;
    cin >> m >> n >> s;
    s += s;
    f[0] = 1;
    for (int i = 1; i < n; i++) {
        string t;
        cin >> t;
        z_function(t + s);
        for (int j = 0, s = 0; j <= m; j++) {
            s = (s + f[j]) % mod;
            if (z[m + j] >= m) f[j] = s;
            else f[j] = 0;
        }
    }
    int ret = 0;
    for (int i = 0; i <= m; i++) {
        ret = (ret + f[i]) % mod;
    }
    cout << ret;
    
    return 0;
}

 

参考资料

  Accepted极限代码巅峰赛 lgkm39 提交的代码:https://ac.nowcoder.com/acm/contest/view-submission?submissionId=72962880

posted @   onlyblues  阅读(4)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 单线程的Redis速度为什么快?
· 展开说说关于C#中ORM框架的用法!
· Pantheons:用 TypeScript 打造主流大模型对话的一站式集成库
· SQL Server 2025 AI相关能力初探
· 为什么 退出登录 或 修改密码 无法使 token 失效
历史上的今天:
2023-11-08 城市环路
Web Analytics
点击右上角即可分享
微信分享提示