题解 match

传送门

MD 先把我是 SB 重复十遍

考虑暴力怎么打
把所有东西都塞到状态里面可以有一个 \(O(n^4)\) 的暴力 DP
\(f_{i, j, k, l, c}\)\(s\) 匹配到了 \(i\)\(t\) 历史最长匹配长度为 \(j\),是在 \([k-j+1, k]\) 匹配到的,
当前 \(t\) 匹配位置为 \(l\)\(t\) 失配的那个字符为 \(c\),若这个字符是未确定的记为 \(\tt f\)
可以过 60 pts

然后观察这个 DP
发现中间两维好像可以直接在原串上枚举子串
那么现在要强制一个 \(s_{l, r}\) 是能匹配的最长前缀且做 \(t\) 的前缀合法
合法可以一遍 KMP check(因为只能匹配这么多,所以若长度到了 \(r-l+1\) 可以强制失配)
最长前缀就是要求统计后面随便填的方案数时其长度 \(> r-l+1\) 的前缀不能是 \(s\) 的子串
image
上面是后一个条件的完整表述
于是可以 \(O(n^3)\),精细实现加全剪枝貌似可以过

然后考虑优化到 \(O(n^2)\)
官方题解给的什么 shit 一个证明都不带有的
回去考虑给的假算法
发现实际上要求跑 KMP 的时候一旦跳了 border 就一定要跳到 0 而且跳到 0 之后 \(s_{i+1}\) 不能和 \(t_1\) 匹配
那么考虑这个限制
定义失配串为:考虑 \(S[l,r]\) 及其前缀在原串中的所有出现位置,记为集合 \(Z\),存在某个出现位置不被 \(Z\) 中元素包含的,可以通过 \(S[l,r]\) 的一个前缀添加某个字符得到的串
啊对定义是从 zxy 博里粘过来的
这东西说人话就是一个加上最后一个字符之后必须失配,不可能换个匹配位置继续匹配之类的一个串
而且它去掉最后一个字符之后还是 \(s_{l, r}\) 的前缀
它有啥性质呢?

失配串是 \(s\) 的一个子串
\(s_{l, r}\)\(s\) 匹配的过程中早晚会遇到它
它不是 \(s_{l, r}\) 的子串,而且是 \(s_{l, r}\) 的前缀
所以匹配到 失配串去掉最后一个字符 的位置上时,失配串长度为 \(len-1\) 的前缀恰好与 \(s_{l, r}\) 长度为 \(len-1\) 的前缀匹配
注意此时原串匹配到失配串的位置,等价于失配串在和 \(s_{l, r}\) 做匹配
那么加上最后那个字符会使其失配
若失配串有长度 \(>0\) 的 border,那一定是 \(s_{l, r}\) 的一个前缀,那么跳 border 会跳到那个位置
那么就一定不合法了
所以判定一个 \(s_{l, r}\) 是否合法的条件之一是其对应的所有失配串最长 border 都为 0
艹对着结论猜原因折腾了我一上午

那么现在就是要维护出能独立于 \(Z\) 中元素出现的最长 boader 不为 0 的失配串了
考虑在字典树上 dfs,同时维护当前每个最长 border 不为 0 的失配串的出现次数
若当前在点 \(u\),要向下转移到 \(v\),设走的转移边上的字母为 \(c\)
那么 \(Z\) 集合扩大了,可能需要从原集合中删去一些失配串
这些串一定是 \(v\) 的后缀,所以它们是 \(x\) 的 border (包括 \(x\))的 \(c\) 方向的儿子
注意此时一个失配串需要被考虑当且仅当它不被 \(Z\) 集合中的任何一个元素包含
而如果跳 border 时发现某一时刻跳到的点在 \(c\) 方向的儿子恰好在根到 \(u\) 的链上(是 \(s_{l, r}\) 的子串)
那么它及其 border 都已经在 \(Z\) 集合中了,不必再考虑

刚才删去了不再合法的失配串,现在考虑需要加入的失配串
我们从 \(u\)\(v\) 走了 \(c\) 的转移边,那么 \(v\) 的所有兄弟代表的点都是合法的失配串
同时,\(v\) 的所有儿子同样满足上述定义,也是合法的失配串

然后这样做的复杂度是什么呢?
发现这个过程不强于对 trie 树上的每一条链都跑了一次 KMP
那么这样就是 \(O(n^2)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 2010
#define ll long long
#define ull unsigned long long
//#define int long long

int n, m;
char s[N];
const ll mod=998244353;
inline void md(ll& a, ll b) {a+=b; a=a>=mod?a-mod:a;}
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}

namespace force{
	ll ans;
	char t[N];
	int f[N], g[N], nxt[N];
	// void match() {
	// 	int cur=0;
	// 	for (int i=1; i<=n; ++i) {
	// 		if (cur<m&&s[i]==t[cur+1]) ++cur;
	// 		else cur=0;
	// 		g[i]=cur;
	// 	}
	// }
	// void kmp() {
	// 	nxt[1]=0;
	// 	for (int i=2,j=0; i<=m; ++i) {
	// 		while (j&&t[j+1]!=t[i]) j=nxt[j];
	// 		if (t[j+1]==t[i]) ++j;
	// 		nxt[i]=j;
	// 	}
	// 	for (int i=1,j=0; i<=n; ++i) {
	// 		while (j&&(j==m||s[i]!=t[j+1])) j=nxt[j];
	// 		if (s[i]==t[j+1]) ++j;
	// 		f[i]=j;
	// 	}
	// }
	bool kmp2() {
		nxt[1]=0;
		for (int i=2,j=0; i<=m; ++i) {
			while (j&&t[j+1]!=t[i]) j=nxt[j];
			if (t[j+1]==t[i]) ++j;
			nxt[i]=j;
		}
		for (int i=1,j=0; i<=n; ++i) {
			bool flag=0;
			while (j&&(j==m||s[i]!=t[j+1])) j=nxt[j], flag=1;
			if (flag&&j) return 0;
			if (flag&&!j&&s[i]==t[j+1]) return 0;
			if (s[i]==t[j+1]) ++j;
		}
		return 1;
	}
	void dfs(int u) {
		if (u>m) {
			// match(); kmp();
			// for (int i=1; i<=n; ++i) if (f[i]!=g[i]) return ;
			if (kmp2()) ++ans;
			return ;
		}
		for (int i=0; i<5; ++i) t[u]=i, dfs(u+1);
	}
	void solve() {
		dfs(1);
		printf("%lld\n", ans*qpow(qpow(5, m), mod-2)%mod);
	}
}

// namespace task1{
// 	int now;
// 	ll f[2][105][105][105][6];
// 	void solve() {
// 		// cout<<double(sizeof(f))/1000/1000<<endl;
// 		f[now][0][0][0][5]=1;
// 		for (int i=0; i<n; ++i,now^=1) {
// 			memset(f[now^1], 0, sizeof(f[now^1]));
// 			for (int j=0; j<=i; ++j) {
// 				for (int k=j; k<=i; ++k) {
// 					for (int l=0; l<=i; ++l) {
// 						for (int c=0; c<=5; ++c) if (f[now][j][k][l][c]) {
// 							if (l!=j) {
// 								if (加上 t[l+1] 后仍能匹配) {

// 								}
// 								else if (跳 nxt 会一直跳到 0) {
									
// 								}
// 							}
// 							else {
// 								if (c!=5) {
// 									if (加上 c 仍能匹配) {

// 									}
// 									else {
// 										j 不变,失配
// 									}
// 								}
// 								else {
// 									for (int t=0; t<5; ++t) {
// 										if (加上 c 仍能匹配) {

// 										}
// 										else {
// 											j 不变,失配
// 										}
// 									}
// 								}
// 							}
// 						}
// 					}
// 				}
// 			}
// 		}
// 	}
// }

namespace task1{
	int now;
	ll f[2][105][105][105][6], ans;
	int nxt[105][105][105], top[105][105][105][5];
	void kmp(char* t, int len, int* nxt, int (*top)[5]) {
		nxt[1]=0;
		for (int i=2,j=0; i<=m; ++i) {
			while (j&&t[j+1]!=t[i]) j=nxt[j];
			if (t[j+1]==t[i]) ++j;
			nxt[i]=j;
		}
		for (int c=0; c<5; ++c) top[1][c]=(len>=2&&t[2]==c);
		for (int i=2; i<=m; ++i) {
			for (int c=0; c<5; ++c) {
				if (t[i+1]!=c) top[i][c]=top[nxt[i]][c];
				else top[i][c]=i+1;
			}
		}
	}
	bool check(int l, int r, int pos, char c) {
		// cout<<"check: "<<l<<' '<<r<<' '<<pos<<endl;
		char* t=&s[l-1];
		pos=nxt[l][r][pos];
		// while (pos&&(t[pos+1]!=c)) pos=nxt[l][r][pos];
		pos=top[l][r][pos][c];
		// cout<<"pos: "<<pos<<' '<<char('a'+t[pos+1])<<' '<<char('a'+c)<<endl;
		if (pos) return 0;
		if (!pos&&t[pos+1]==c) return 0;
		return 1;
	}
	void solve() {
		// cout<<double(sizeof(f))/1000/1000<<endl;
		for (int l=1; l<=n; ++l)
			for (int r=l; r<=n; ++r)
				kmp((&s[l])-1, r-l+1, nxt[l][r], top[l][r]);
		f[now][0][0][0][5]=1;
		for (int i=0; i<n; ++i,now^=1) {
			memset(f[now^1], 0, sizeof(f[now^1]));
			for (int j=0; j<=min(i, m); ++j) {
				for (int k=j; k<=i; ++k) {
					for (int l=0; l<=j; ++l) {
						for (int c=0; c<=5; ++c) if (f[now][j][k][l][c]) {
							// cout<<"tr: "<<i<<' '<<j<<' '<<k<<' '<<l<<' '<<char('a'+c)<<' '<<f[now][j][k][l][c]<<endl;
							if (l!=j) {
								// if (加上 t[l+1] 后仍能匹配) {
								if (s[k-j+1+l]==s[i+1]) {
									md(f[now^1][j][k][l+1][c], f[now][j][k][l][c]);
								}
								// else if (跳 nxt 会一直跳到 0) {
								else if (check(k-j+1, k, l, s[i+1])) {
									md(f[now^1][j][k][0][c], f[now][j][k][l][c]);
								}
							}
							else {
								if (c!=5) {
									// if (加上 c 仍能匹配) {
									if (j!=m&&s[i+1]==c) {
										md(f[now^1][j+1][i+1][l+1][5], f[now][j][k][l][c]);
									}
									// j 不变,失配
									else if (!l||check(k-j+1, k, l, s[i+1])) {
										md(f[now^1][j][k][0][c], f[now][j][k][l][c]);
									}
								}
								else {
									// cout<<"pos1: "<<i<<' '<<j<<' '<<k<<' '<<l<<' '<<char('a'+c)<<' '<<f[now][j][k][l][c]<<endl;
									if (j!=m) {
										for (int t=0; t<5; ++t) {
											// cout<<"t: "<<t<<endl;
											// if (加上 t 仍能匹配) {
											if (j!=m&&s[i+1]==t) {
												// cout<<"tr1"<<endl;
												md(f[now^1][j+1][i+1][l+1][5], f[now][j][k][l][c]); //, cout<<1<<endl;
											}
											// j 不变,失配
											else if (!l||check(k-j+1, k, l, s[i+1])) {
												// cout<<"tr2"<<endl;
												md(f[now^1][j][k][0][t], f[now][j][k][l][c]);
											}
										}
									}
									else {
										assert(c==5);
										if (!l||check(k-j+1, k, l, s[i+1])) {
											md(f[now^1][j][k][0][c], f[now][j][k][l][c]);
										}
									}
								}
							}
						}
					}
				}
			}
		}
		for (int j=0; j<=m; ++j) {
			for (int k=j; k<=n; ++k) {
				for (int l=0; l<=j; ++l) {
					for (int c=0; c<=5; ++c) if (f[now][j][k][l][c]) {
						// printf("f[n][%d][%d][%d][%c]=%lld\n", j, k, l, 'a'+c, f[now][j][k][l][c]);
						if (c==5) ans=(ans+qpow(5, m-j))%mod;
						else ans=(ans+qpow(5, m-j-1))%mod;
					}
				}
			}
		}
		// printf("%lld\n", (ans%mod+mod)%mod);
		printf("%lld\n", (ans*qpow(qpow(5, m), mod-2)%mod+mod)%mod);
	}
}

namespace task{
	ll pw[N], ans;
	ull h[N][N];
	const ull base=13131;
	unordered_map<ull, int> mp;
	int tr[N*N][5], suf[N*N], nxt[N*N], cnt[N*N], now[N*N], len[N*N], tot, live;
	void dfs(int u) {
		// cout<<"u: "<<u<<endl;
		for (int i=0,v; i<5; ++i) if (v=tr[u][i]) {
			// cout<<"v: "<<v<<endl;
			nxt[v]=nxt[u];
			while (nxt[v]&&(suf[nxt[v]]!=i))
				nxt[v]=nxt[nxt[v]];
			if (u&&suf[nxt[v]]==i) nxt[v]=tr[nxt[v]][suf[nxt[v]]];
			// cout<<nxt[v]<<endl;
		}
		for (int i=0,v; i<5; ++i) if (v=tr[u][i]) {
			assert(!now[v]);
			now[v]+=cnt[v];
			live+=(bool)nxt[v];
		}
		for (int i=0,v; i<5; ++i) if (v=tr[u][i]) {
			if (!(now[v]-=cnt[v])) live-=(bool)nxt[v];
			for (int t=nxt[u]; t&&suf[t]!=i; t=nxt[t])
				if (tr[t][i] && !(now[tr[t][i]]-=cnt[v])) live-=(bool)nxt[v];
			suf[u]=i; dfs(v);
			for (int t=nxt[u]; t&&suf[t]!=i; t=nxt[t])
				if (tr[t][i]) live+=((now[tr[t][i]]+=cnt[v])==cnt[v] && (bool)nxt[v]);
			if ((now[v]+=cnt[v])==cnt[v]) live+=(bool)nxt[v];
		}
		for (int i=0,v; i<5; ++i) if (v=tr[u][i]) {
			now[v]-=cnt[v];
			live-=(bool)nxt[v];
			assert(!now[v]);
		}
		// cout<<"at "<<u<<" live="<<live<<endl;
		bool vis=!live; int deg=5;
		for (int i=0,v; i<5; ++i) if (v=tr[u][i])
			vis&=(nxt[tr[u][i]]==0), --deg;
		if (vis&&len[u]<=m) {
			// cout<<"at: "<<u<<" add "<<(len[u]==m?cnt[u]:cnt[u]*deg%mod*qpow(5, m-len[u]-1))<<endl;
			if (len[u]==m) ans=(ans+1)%mod;
			else ans=(ans+deg%mod*pw[m-len[u]-1])%mod;
		}
	}
	void solve() {
		pw[0]=1;
		for (int i=1; i<=m; ++i) pw[i]=pw[i-1]*5%mod;
		for (int i=1; i<=n; ++i) {
			ull tem=0;
			for (int j=i; j<=n; ++j)
				++mp[h[i][j]=tem=tem*base+'a'+s[j]];
		}
		for (int i=1; i<=n; ++i) {
			for (int j=i,p=0,*t; j<=n; ++j,p=*t) {
				t=&tr[p][s[j]];
				if (!*t) *t=++tot;
				// cout<<"ij: "<<i<<' '<<j<<' '<<h[i][j]<<endl;
				cnt[*t]=mp[h[i][j]];
				len[*t]=j-i+1;
			}
		}
		cnt[0]=1; dfs(0);
		// cout<<"nxt: "; for (int i=1; i<=tot; ++i) cout<<nxt[i]<<' '; cout<<endl;
		// cout<<"cnt: "; for (int i=1; i<=tot; ++i) cout<<cnt[i]<<' '; cout<<endl;
		// printf("%lld\n", ans);
		printf("%lld\n", (ans*qpow(qpow(5, m), mod-2)%mod+mod)%mod);
	}
}

signed main()
{
	freopen("match.in", "r", stdin);
	freopen("match.out", "w", stdout);

	scanf("%d%d%s", &n, &m, s+1);
	for (int i=1; i<=n; ++i) s[i]-='a';
	// force::solve();
	// task1::solve();
	task::solve();
	
	return 0;
}
posted @ 2022-06-18 17:27  Administrator-09  阅读(2)  评论(0编辑  收藏  举报