题解 match
MD 先把我是 SB 重复十遍
考虑暴力怎么打
把所有东西都塞到状态里面可以有一个 \(O(n^4)\) 的暴力 DP
令 \(f_{i, j, k, l, c}\) 为 \(s\) 匹配到了 \(i\),\(t\) 历史最长匹配长度为 \(j\),是在 \([k-j+1, k]\) 匹配到的,
当前 \(t\) 匹配位置为 \(l\),\(t\) 失配的那个字符为 \(c\),若这个字符是未确定的记为 \(\tt f\)
可以过 60 pts
然后观察这个 DP
发现中间两维好像可以直接在原串上枚举子串
那么现在要强制一个 \(s_{l, r}\) 是能匹配的最长前缀且做 \(t\) 的前缀合法
合法可以一遍 KMP check(因为只能匹配这么多,所以若长度到了 \(r-l+1\) 可以强制失配)
最长前缀就是要求统计后面随便填的方案数时其长度 \(> r-l+1\) 的前缀不能是 \(s\) 的子串
上面是后一个条件的完整表述
于是可以 \(O(n^3)\),精细实现加全剪枝貌似可以过
然后考虑优化到 \(O(n^2)\)
官方题解给的什么 shit 一个证明都不带有的
回去考虑给的假算法
发现实际上要求跑 KMP 的时候一旦跳了 border 就一定要跳到 0 而且跳到 0 之后 \(s_{i+1}\) 不能和 \(t_1\) 匹配
那么考虑这个限制
定义失配串为:考虑 \(S[l,r]\) 及其前缀在原串中的所有出现位置,记为集合 \(Z\),存在某个出现位置不被 \(Z\) 中元素包含的,可以通过 \(S[l,r]\) 的一个前缀添加某个字符得到的串
啊对定义是从 zxy 博里粘过来的
这东西说人话就是一个加上最后一个字符之后必须失配,不可能换个匹配位置继续匹配之类的一个串
而且它去掉最后一个字符之后还是 \(s_{l, r}\) 的前缀
它有啥性质呢?
失配串是 \(s\) 的一个子串
用 \(s_{l, r}\) 和 \(s\) 匹配的过程中早晚会遇到它
它不是 \(s_{l, r}\) 的子串,而且是 \(s_{l, r}\) 的前缀
所以匹配到 失配串去掉最后一个字符 的位置上时,失配串长度为 \(len-1\) 的前缀恰好与 \(s_{l, r}\) 长度为 \(len-1\) 的前缀匹配
注意此时原串匹配到失配串的位置,等价于失配串在和 \(s_{l, r}\) 做匹配
那么加上最后那个字符会使其失配
若失配串有长度 \(>0\) 的 border,那一定是 \(s_{l, r}\) 的一个前缀,那么跳 border 会跳到那个位置
那么就一定不合法了
所以判定一个 \(s_{l, r}\) 是否合法的条件之一是其对应的所有失配串最长 border 都为 0
艹对着结论猜原因折腾了我一上午
那么现在就是要维护出能独立于 \(Z\) 中元素出现的最长 boader 不为 0 的失配串了
考虑在字典树上 dfs,同时维护当前每个最长 border 不为 0 的失配串的出现次数
若当前在点 \(u\),要向下转移到 \(v\),设走的转移边上的字母为 \(c\)
那么 \(Z\) 集合扩大了,可能需要从原集合中删去一些失配串
这些串一定是 \(v\) 的后缀,所以它们是 \(x\) 的 border (包括 \(x\))的 \(c\) 方向的儿子
注意此时一个失配串需要被考虑当且仅当它不被 \(Z\) 集合中的任何一个元素包含
而如果跳 border 时发现某一时刻跳到的点在 \(c\) 方向的儿子恰好在根到 \(u\) 的链上(是 \(s_{l, r}\) 的子串)
那么它及其 border 都已经在 \(Z\) 集合中了,不必再考虑
刚才删去了不再合法的失配串,现在考虑需要加入的失配串
我们从 \(u\) 到 \(v\) 走了 \(c\) 的转移边,那么 \(v\) 的所有兄弟代表的点都是合法的失配串
同时,\(v\) 的所有儿子同样满足上述定义,也是合法的失配串
然后这样做的复杂度是什么呢?
发现这个过程不强于对 trie 树上的每一条链都跑了一次 KMP
那么这样就是 \(O(n^2)\) 的
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 2010
#define ll long long
#define ull unsigned long long
//#define int long long
int n, m;
char s[N];
const ll mod=998244353;
inline void md(ll& a, ll b) {a+=b; a=a>=mod?a-mod:a;}
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}
namespace force{
ll ans;
char t[N];
int f[N], g[N], nxt[N];
// void match() {
// int cur=0;
// for (int i=1; i<=n; ++i) {
// if (cur<m&&s[i]==t[cur+1]) ++cur;
// else cur=0;
// g[i]=cur;
// }
// }
// void kmp() {
// nxt[1]=0;
// for (int i=2,j=0; i<=m; ++i) {
// while (j&&t[j+1]!=t[i]) j=nxt[j];
// if (t[j+1]==t[i]) ++j;
// nxt[i]=j;
// }
// for (int i=1,j=0; i<=n; ++i) {
// while (j&&(j==m||s[i]!=t[j+1])) j=nxt[j];
// if (s[i]==t[j+1]) ++j;
// f[i]=j;
// }
// }
bool kmp2() {
nxt[1]=0;
for (int i=2,j=0; i<=m; ++i) {
while (j&&t[j+1]!=t[i]) j=nxt[j];
if (t[j+1]==t[i]) ++j;
nxt[i]=j;
}
for (int i=1,j=0; i<=n; ++i) {
bool flag=0;
while (j&&(j==m||s[i]!=t[j+1])) j=nxt[j], flag=1;
if (flag&&j) return 0;
if (flag&&!j&&s[i]==t[j+1]) return 0;
if (s[i]==t[j+1]) ++j;
}
return 1;
}
void dfs(int u) {
if (u>m) {
// match(); kmp();
// for (int i=1; i<=n; ++i) if (f[i]!=g[i]) return ;
if (kmp2()) ++ans;
return ;
}
for (int i=0; i<5; ++i) t[u]=i, dfs(u+1);
}
void solve() {
dfs(1);
printf("%lld\n", ans*qpow(qpow(5, m), mod-2)%mod);
}
}
// namespace task1{
// int now;
// ll f[2][105][105][105][6];
// void solve() {
// // cout<<double(sizeof(f))/1000/1000<<endl;
// f[now][0][0][0][5]=1;
// for (int i=0; i<n; ++i,now^=1) {
// memset(f[now^1], 0, sizeof(f[now^1]));
// for (int j=0; j<=i; ++j) {
// for (int k=j; k<=i; ++k) {
// for (int l=0; l<=i; ++l) {
// for (int c=0; c<=5; ++c) if (f[now][j][k][l][c]) {
// if (l!=j) {
// if (加上 t[l+1] 后仍能匹配) {
// }
// else if (跳 nxt 会一直跳到 0) {
// }
// }
// else {
// if (c!=5) {
// if (加上 c 仍能匹配) {
// }
// else {
// j 不变,失配
// }
// }
// else {
// for (int t=0; t<5; ++t) {
// if (加上 c 仍能匹配) {
// }
// else {
// j 不变,失配
// }
// }
// }
// }
// }
// }
// }
// }
// }
// }
// }
namespace task1{
int now;
ll f[2][105][105][105][6], ans;
int nxt[105][105][105], top[105][105][105][5];
void kmp(char* t, int len, int* nxt, int (*top)[5]) {
nxt[1]=0;
for (int i=2,j=0; i<=m; ++i) {
while (j&&t[j+1]!=t[i]) j=nxt[j];
if (t[j+1]==t[i]) ++j;
nxt[i]=j;
}
for (int c=0; c<5; ++c) top[1][c]=(len>=2&&t[2]==c);
for (int i=2; i<=m; ++i) {
for (int c=0; c<5; ++c) {
if (t[i+1]!=c) top[i][c]=top[nxt[i]][c];
else top[i][c]=i+1;
}
}
}
bool check(int l, int r, int pos, char c) {
// cout<<"check: "<<l<<' '<<r<<' '<<pos<<endl;
char* t=&s[l-1];
pos=nxt[l][r][pos];
// while (pos&&(t[pos+1]!=c)) pos=nxt[l][r][pos];
pos=top[l][r][pos][c];
// cout<<"pos: "<<pos<<' '<<char('a'+t[pos+1])<<' '<<char('a'+c)<<endl;
if (pos) return 0;
if (!pos&&t[pos+1]==c) return 0;
return 1;
}
void solve() {
// cout<<double(sizeof(f))/1000/1000<<endl;
for (int l=1; l<=n; ++l)
for (int r=l; r<=n; ++r)
kmp((&s[l])-1, r-l+1, nxt[l][r], top[l][r]);
f[now][0][0][0][5]=1;
for (int i=0; i<n; ++i,now^=1) {
memset(f[now^1], 0, sizeof(f[now^1]));
for (int j=0; j<=min(i, m); ++j) {
for (int k=j; k<=i; ++k) {
for (int l=0; l<=j; ++l) {
for (int c=0; c<=5; ++c) if (f[now][j][k][l][c]) {
// cout<<"tr: "<<i<<' '<<j<<' '<<k<<' '<<l<<' '<<char('a'+c)<<' '<<f[now][j][k][l][c]<<endl;
if (l!=j) {
// if (加上 t[l+1] 后仍能匹配) {
if (s[k-j+1+l]==s[i+1]) {
md(f[now^1][j][k][l+1][c], f[now][j][k][l][c]);
}
// else if (跳 nxt 会一直跳到 0) {
else if (check(k-j+1, k, l, s[i+1])) {
md(f[now^1][j][k][0][c], f[now][j][k][l][c]);
}
}
else {
if (c!=5) {
// if (加上 c 仍能匹配) {
if (j!=m&&s[i+1]==c) {
md(f[now^1][j+1][i+1][l+1][5], f[now][j][k][l][c]);
}
// j 不变,失配
else if (!l||check(k-j+1, k, l, s[i+1])) {
md(f[now^1][j][k][0][c], f[now][j][k][l][c]);
}
}
else {
// cout<<"pos1: "<<i<<' '<<j<<' '<<k<<' '<<l<<' '<<char('a'+c)<<' '<<f[now][j][k][l][c]<<endl;
if (j!=m) {
for (int t=0; t<5; ++t) {
// cout<<"t: "<<t<<endl;
// if (加上 t 仍能匹配) {
if (j!=m&&s[i+1]==t) {
// cout<<"tr1"<<endl;
md(f[now^1][j+1][i+1][l+1][5], f[now][j][k][l][c]); //, cout<<1<<endl;
}
// j 不变,失配
else if (!l||check(k-j+1, k, l, s[i+1])) {
// cout<<"tr2"<<endl;
md(f[now^1][j][k][0][t], f[now][j][k][l][c]);
}
}
}
else {
assert(c==5);
if (!l||check(k-j+1, k, l, s[i+1])) {
md(f[now^1][j][k][0][c], f[now][j][k][l][c]);
}
}
}
}
}
}
}
}
}
for (int j=0; j<=m; ++j) {
for (int k=j; k<=n; ++k) {
for (int l=0; l<=j; ++l) {
for (int c=0; c<=5; ++c) if (f[now][j][k][l][c]) {
// printf("f[n][%d][%d][%d][%c]=%lld\n", j, k, l, 'a'+c, f[now][j][k][l][c]);
if (c==5) ans=(ans+qpow(5, m-j))%mod;
else ans=(ans+qpow(5, m-j-1))%mod;
}
}
}
}
// printf("%lld\n", (ans%mod+mod)%mod);
printf("%lld\n", (ans*qpow(qpow(5, m), mod-2)%mod+mod)%mod);
}
}
namespace task{
ll pw[N], ans;
ull h[N][N];
const ull base=13131;
unordered_map<ull, int> mp;
int tr[N*N][5], suf[N*N], nxt[N*N], cnt[N*N], now[N*N], len[N*N], tot, live;
void dfs(int u) {
// cout<<"u: "<<u<<endl;
for (int i=0,v; i<5; ++i) if (v=tr[u][i]) {
// cout<<"v: "<<v<<endl;
nxt[v]=nxt[u];
while (nxt[v]&&(suf[nxt[v]]!=i))
nxt[v]=nxt[nxt[v]];
if (u&&suf[nxt[v]]==i) nxt[v]=tr[nxt[v]][suf[nxt[v]]];
// cout<<nxt[v]<<endl;
}
for (int i=0,v; i<5; ++i) if (v=tr[u][i]) {
assert(!now[v]);
now[v]+=cnt[v];
live+=(bool)nxt[v];
}
for (int i=0,v; i<5; ++i) if (v=tr[u][i]) {
if (!(now[v]-=cnt[v])) live-=(bool)nxt[v];
for (int t=nxt[u]; t&&suf[t]!=i; t=nxt[t])
if (tr[t][i] && !(now[tr[t][i]]-=cnt[v])) live-=(bool)nxt[v];
suf[u]=i; dfs(v);
for (int t=nxt[u]; t&&suf[t]!=i; t=nxt[t])
if (tr[t][i]) live+=((now[tr[t][i]]+=cnt[v])==cnt[v] && (bool)nxt[v]);
if ((now[v]+=cnt[v])==cnt[v]) live+=(bool)nxt[v];
}
for (int i=0,v; i<5; ++i) if (v=tr[u][i]) {
now[v]-=cnt[v];
live-=(bool)nxt[v];
assert(!now[v]);
}
// cout<<"at "<<u<<" live="<<live<<endl;
bool vis=!live; int deg=5;
for (int i=0,v; i<5; ++i) if (v=tr[u][i])
vis&=(nxt[tr[u][i]]==0), --deg;
if (vis&&len[u]<=m) {
// cout<<"at: "<<u<<" add "<<(len[u]==m?cnt[u]:cnt[u]*deg%mod*qpow(5, m-len[u]-1))<<endl;
if (len[u]==m) ans=(ans+1)%mod;
else ans=(ans+deg%mod*pw[m-len[u]-1])%mod;
}
}
void solve() {
pw[0]=1;
for (int i=1; i<=m; ++i) pw[i]=pw[i-1]*5%mod;
for (int i=1; i<=n; ++i) {
ull tem=0;
for (int j=i; j<=n; ++j)
++mp[h[i][j]=tem=tem*base+'a'+s[j]];
}
for (int i=1; i<=n; ++i) {
for (int j=i,p=0,*t; j<=n; ++j,p=*t) {
t=&tr[p][s[j]];
if (!*t) *t=++tot;
// cout<<"ij: "<<i<<' '<<j<<' '<<h[i][j]<<endl;
cnt[*t]=mp[h[i][j]];
len[*t]=j-i+1;
}
}
cnt[0]=1; dfs(0);
// cout<<"nxt: "; for (int i=1; i<=tot; ++i) cout<<nxt[i]<<' '; cout<<endl;
// cout<<"cnt: "; for (int i=1; i<=tot; ++i) cout<<cnt[i]<<' '; cout<<endl;
// printf("%lld\n", ans);
printf("%lld\n", (ans*qpow(qpow(5, m), mod-2)%mod+mod)%mod);
}
}
signed main()
{
freopen("match.in", "r", stdin);
freopen("match.out", "w", stdout);
scanf("%d%d%s", &n, &m, s+1);
for (int i=1; i<=n; ++i) s[i]-='a';
// force::solve();
// task1::solve();
task::solve();
return 0;
}