NTT/FFT与字符串匹配

字符串匹配问题，除了可以用\(KMP\)、\(AC\)自动机等算法解决外，有的还能利用\(NTT/FFT\)实现。
\(Link\)
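本题中，设模式串\(A\)长为\(n\)、文本串\(B\)长为\(m\)（下标均从\(0\)开始，与下面代码中的\(n,m\)一致）。对于每个在\(B\)中能与\(A\)匹配的起始位置\(i\)，都有\(\forall{j\in[0,n-1]},a[j]=b[j+i]\lor{a[j]='*'}\lor{b[j+i]='*'}\)
注意到，若把每个条件都写成“某个非负量等于\(0\)”的形式，则\(\lor\)可以用乘法实现（乘积为\(0\)当且仅当至少有一个因子为\(0\)），\(\forall\)可以用加法实现（若干非负数之和为\(0\)当且仅当每一项都为\(0\)）。
具体地，在本题中，我们定义\(F('*')=0,F(ch)=ch-'a'+1\)，并用\(x\leftrightarrow{y}\)表示字符\(x\)可以与\(y\)匹配。
那么\(x\leftrightarrow{y}=[F(x)*F(y)*(F(x)-F(y))^2=0]\)（三个因子中只要有一个为\(0\)，等式就成立；差取平方是为了让每一项都非负，这样下面才能用求和来判断）。
即\(\forall{j\in[0,n-1]},F(a_j)*F(b_{j+i})*(F(a_j)-F(b_{j+i}))^2=0\)
由于每一项都非负，这等价于\(\sum\limits_{j=0}^{n-1}F(a_j)*F(b_{j+i})*(F(a_j)-F(b_{j+i}))^2=0\)
令\(a'_i=a_{n-i-1}\)（即把\(A\)翻转），则\(a_j=a'_{n-j-1}\)，有\(\sum\limits_{j=0}^{n-1}F(a'_{n-j-1})*F(b_{j+i})*(F(a'_{n-j-1})-F(b_{j+i}))^2=0\)
此时两个下标之和\((n-j-1)+(j+i)=n+i-1\)与\(j\)无关，于是可以转化成卷积形式。把平方展开后依次拆开（下面简记\(a'_p=F(a'_p),b_q=F(b_q)\)），有：\(res_{n+i-1}=\sum\limits_{p+q=n+i-1}{a'_p}^3{b_q}+\sum\limits_{p+q=n+i-1}{a'_p}{b_q}^3-2\sum\limits_{p+q=n+i-1}{a'_p}^2{b_q}^2=0\)
那么当\(res_x=0\)（\(x\in[n-1,m-1]\)）时，起始位置\(x-n+1\)匹配成功，输出\(x+2-n\)（\(+1\)是因为题目下标从\(1\)开始）。注意\(NTT\)是在模\(998244353\)意义下计算的，若真实的和非零但恰好是模数的倍数，理论上会产生误判。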

Code

#include <bits/stdc++.h>

using namespace std;

// NTT modulus p = 119 * 2^23 + 1 with generator g = 3; G = g^{-1} mod p
// (3 * 332748118 = 998244354 = p + 1), used by the inverse transform.
constexpr int mod = 998244353;
constexpr int g = 3;
constexpr int G = (mod + 1) / 3;

// n = length of s1 (the pattern), m = length of s2 (the text).
int n, m;
// t = current transform length (a power of two), k = log2(t); set by g_l().
int t, k;
// tot = number of matches found, out[1..tot] = their 1-based positions.
int tot, out[300005];
// p = bit-reversal permutation for the current transform length t.
int p[1200005];
// a0 / b0 = mapped values of the reversed pattern / the text,
// a / b = scratch buffers transformed in place, res = accumulated spectrum.
int a0[1200005], b0[1200005], a[1200005], b[1200005], res[1200005];

// Raw input strings: s1 = pattern, s2 = text.
char s1[300005], s2[300005];

// Reads the next non-negative decimal integer from stdin, skipping every
// non-digit character in front of it (a '-' sign is skipped, not honoured).
// Returns 0 if EOF is reached before any digit. Previously the skip loop spun
// forever at EOF, because EOF (< '0') never satisfies the exit condition.
int read()
{
	// int, not char: getchar() returns an int, and EOF is only reliably
	// distinguishable from real characters before narrowing to char.
	int x = 0, ch = getchar();
	while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
	while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0'; ch = getchar();}
	return x;
}

// Binary exponentiation: returns base^pw modulo `mod`, scanning the exponent
// from its lowest bit upwards (qpow(x, 0) == 1).
int qpow(int base, int pw)
{
	long long acc = 1, cur = base;
	for (; pw != 0; pw >>= 1, cur = cur * cur % mod)
	{
		if (pw & 1) acc = acc * cur % mod;
	}
	return static_cast<int>(acc);
}

// Prepares an NTT of length t. Sets t to the smallest power of two strictly
// greater than x and k = log2(t), and fills p[0..t-1] with the k-bit
// bit-reversal permutation that NTT() uses for its initial reordering.
void g_l(int x)
{
	t = 1, k = 0;
	while (t <= x) t <<= 1, k ++ ;
	// p[0] is always 0. Filling from i = 1 means the (k - 1) shift below only
	// runs when t >= 2, i.e. k >= 1. Previously x <= 0 gave k == 0, and the
	// loop evaluated `<< -1`, which is undefined behaviour.
	p[0] = 0;
	for (int i = 1; i < t; i ++ ) p[i] = (p[i >> 1] >> 1) | ((i & 1) << (k - 1));
	return;
}

// In-place iterative number-theoretic transform of a[0..t-1] modulo `mod`,
// using the global length t and the bit-reversal table p from g_l().
// o == 1 runs the stages with root g. Any other o runs them with root
// G = g^{-1}. For o == -1 the result is additionally scaled by t^{-1},
// which gives the inverse transform.
void NTT(int *a, int o)
{
	// Move every element to its bit-reversed slot (each pair swapped once).
	for (int idx = 0; idx < t; idx ++ )
		if (idx < p[idx]) std::swap(a[idx], a[p[idx]]);
	// Butterfly stages; `half` is half the length of the blocks merged here.
	for (int half = 1; half < t; half <<= 1)
	{
		const int len = half << 1;
		const int step = qpow(o == 1 ? g : G, (mod - 1) / len);
		for (int base = 0; base < t; base += len)
		{
			int w = 1;
			for (int off = 0; off < half; off ++ )
			{
				int &lo = a[base + off];
				int &hi = a[base + off + half];
				const int u = lo;
				const int v = 1ll * w * hi % mod;
				lo = (u + v) % mod;
				hi = (u - v + mod) % mod;
				w = 1ll * w * step % mod;
			}
		}
	}
	if (o != -1) return;
	// Inverse transform: divide by t via its modular inverse (Fermat).
	const int inv_t = qpow(t, mod - 2);
	for (int idx = 0; idx < t; idx ++ ) a[idx] = 1ll * a[idx] * inv_t % mod;
}

// Reads n, m, the pattern s1 (length n, '*' = wildcard) and the text s2
// (length m). Prints the number of offsets where the pattern matches the
// text, then those 1-based starting positions in increasing order.
//
// With F('*') = 0 and F(c) = c - 'a' + 1, offset i matches iff
//   sum_j F(A_j) * F(B_{i+j}) * (F(A_j) - F(B_{i+j}))^2 == 0.
// Reversing the pattern turns this into A'^3*B - 2*A'^2*B^2 + A'*B^3, whose
// coefficient n - 1 + i is the sum for offset i. The three products are
// combined in the NTT domain, so only one inverse transform is needed.
//
// NOTE: every term is non-negative, but the zero test is only done modulo
// 998244353. A non-zero sum that is an exact multiple of the modulus would be
// reported as a (false) match.
int main()
{
	n = read(), m = read();
	// Width limits keep an over-long token from overflowing s1 / s2, which
	// hold at most 300004 characters plus the terminating NUL.
	scanf("%300004s", s1), scanf("%300004s", s2);
	// a0 = reversed pattern, b0 = text, both mapped through F.
	for (int i = 0; i < n; i ++ ) a0[i] = (s1[n - i - 1] == '*') ? 0 : s1[n - i - 1] - 'a' + 1;
	for (int i = 0; i < m; i ++ ) b0[i] = (s2[i] == '*') ? 0 : s2[i] - 'a' + 1;
	// The linear convolution has n + m - 1 coefficients and g_l(x) picks
	// t > x, so g_l(n + m - 1) gives t >= n + m with no cyclic wrap-around.
	// The former g_l(n + m) doubled t when n + m was a power of two.
	g_l(n + m - 1);
	// res += NTT(a0^3) * NTT(b0)
	for (int i = 0; i < t; i ++ ) a[i] = 1ll * a0[i] * a0[i] % mod * a0[i] % mod, b[i] = b0[i];
	NTT(a, 1), NTT(b, 1);
	for (int i = 0; i < t; i ++ ) res[i] = (res[i] + 1ll * a[i] * b[i] % mod) % mod;
	// res -= 2 * NTT(a0^2) * NTT(b0^2)
	for (int i = 0; i < t; i ++ ) a[i] = 1ll * a0[i] * a0[i] % mod, b[i] = 1ll * b0[i] * b0[i] % mod;
	NTT(a, 1), NTT(b, 1);
	for (int i = 0; i < t; i ++ ) res[i] = (res[i] - 2ll * a[i] * b[i] % mod + mod * 2ll) % mod;
	// res += NTT(a0) * NTT(b0^3)
	for (int i = 0; i < t; i ++ ) a[i] = a0[i], b[i] = 1ll * b0[i] * b0[i] % mod * b0[i] % mod;
	NTT(a, 1), NTT(b, 1);
	for (int i = 0; i < t; i ++ ) res[i] = (res[i] + 1ll * a[i] * b[i] % mod) % mod;
	NTT(res, -1);
	// Coefficient x in [n - 1, m - 1] belongs to 0-based offset x - n + 1.
	// A zero coefficient means a match at 1-based position x - n + 2.
	for (int i = n - 1; i <= m - 1; i ++ ) if (!res[i]) out[ ++ tot] = i + 2 - n;
	printf("%d\n", tot); for (int i = 1; i <= tot; i ++ ) printf("%d ", out[i]); putchar('\n');
	return 0;
}
posted @ 2021-05-10 20:51  andysj  阅读(117)  评论(0编辑  收藏  举报