NTT/FFT与字符串匹配
字符串匹配问题,除了可以用\(KMP\)、\(AC\)自动机等,有的还能利用\(NTT/FFT\)实现。
\(Link\)
本题中,设\(A\)的长度为\(n\)、\(B\)的长度为\(m\)(与下面代码的读入顺序一致),对于每个在\(B\)中符合的位置\(i\),都有\(\forall{j\in[0,n-1]},a[j]=b[j+i]\lor{a[j]='*'}\lor{b[j+i]='*'}\)
注意到,\(\forall\)一般用加法实现,\(\lor\)一般用乘法实现。
具体的,在本题中,我们定义\(F('*')=0,F(ch)=ch-'a'+1\),\(x\leftrightarrow{y}\)意为字符\(x\)可以与\(y\)匹配
那么\(x\leftrightarrow{y}=[F(x)*F(y)*(F(x)-F(y))^2=0]\)(一个为\(0\),则等式成立)
则\(\forall{j\in[0,n-1]},F(a_j)*F(b_{j+i})*(F(a_j)-F(b_{j+i}))^2=0\)
由于每一项都非负,各项全为\(0\)等价于它们的和为\(0\),即\(\sum\limits_{j=0}^{n-1}F(a_j)*F(b_{j+i})*(F(a_j)-F(b_{j+i}))^2=0\)
令\(a'_i=a_{n-i-1}\),有\(\sum\limits_{j=0}^{n-1}F(a'_{n-j-1})*F(b_{j+i})*(F(a'_{n-j-1})-F(b_{j+i}))^2=0\)
转化成卷积形式,依次拆开(以下把\(F(a'_p)\)、\(F(b_q)\)简记为\(a'_p\)、\(b_q\)),有:\(res_{n+i-1}=\sum\limits_{p+q=n+i-1}{a'_p}^3{b_q}+\sum\limits_{p+q=n+i-1}{a'_p}{b_q}^3-2\sum\limits_{p+q=n+i-1}{a'_p}^2{b_q}^2=0\)
那么当\(res_x=0\)时,输出\(x+2-n\)(\(+1\)是因为题目下标从\(1\)开始)
Code
#include <bits/stdc++.h>
using namespace std;
// mod : prime modulus used for all NTT arithmetic.
// g   : base from which NTT derives the roots of the forward transform.
// G   : (mod + 1) / 3, so 3 * G = mod + 1, i.e. G is the inverse of g modulo mod;
//       NTT uses it for the inverse transform.
const int mod = 998244353, g = 3, G = (mod + 1) / 3;
// n, m   : lengths read for s1 (pattern) and s2 (text) respectively.
// t, k   : transform length (a power of two) and its log2, both set by g_l.
// tot    : number of matching positions collected in out[1..tot].
// p      : bit-reversal permutation table filled by g_l.
// a0, b0 : encoded characters ('*' -> 0, letter -> 1..26); a0 holds s1 reversed.
// a, b   : scratch buffers transformed in place by NTT.
// res    : accumulated convolution; after the inverse transform, 0 marks a match.
int n, m, t, k, tot, out[300005], p[1200005], a0[1200005], b0[1200005], a[1200005], b[1200005], res[1200005];
// Raw input strings: s1 is the pattern, s2 the text.
char s1[300005], s2[300005];
// Reads the next non-negative decimal integer from stdin.
//
// Every non-digit character before the number is skipped, and the single
// character that terminates the digit run is consumed as well.
// Returns the parsed value, or 0 if end of input is reached before any digit.
// No overflow check is performed on the accumulated value.
int read()
{
    int x = 0;
    // getchar() returns int; keeping it in an int lets EOF be told apart from
    // every real character, so the skip loop below cannot spin forever.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x;
}
// Binary exponentiation: returns base^pw reduced modulo `mod`.
// The exponent is consumed bit by bit from the least significant end, squaring
// the running power each step; pw == 0 yields 1.
int qpow(int base, int pw)
{
    long long acc = 1;      // product of the powers selected so far
    long long cur = base;   // base^(2^step), reduced after the first squaring
    for (; pw; pw >>= 1)
    {
        if (pw & 1) acc = acc * cur % mod;
        cur = cur * cur % mod;
    }
    return static_cast<int>(acc);
}
// Chooses the transform size for a result of x coefficients and builds the
// bit-reversal table.
//
// Sets the globals t (smallest power of two strictly greater than x) and
// k (log2 of t), then fills p[0..t-1] so that p[i] is i with its k low bits
// reversed. The recurrence reuses p[i / 2]: drop its lowest bit and, when i is
// odd, set the top bit (value t / 2).
void g_l(int x)
{
    k = 0;
    for (t = 1; t <= x; t *= 2) ++k;
    for (int idx = 0; idx < t; ++idx)
    {
        int rev = p[idx / 2] / 2;
        if (idx % 2) rev |= 1 << (k - 1);
        p[idx] = rev;
    }
}
// In-place iterative number-theoretic transform of a[0..t-1] modulo `mod`.
//
// a : buffer of length t (t and the permutation table p come from g_l).
// o : 1 for the forward transform (roots derived from g); any other value uses
//     G instead, and o == -1 additionally multiplies every entry by t^(mod-2),
//     the inverse of t, to complete the inverse transform.
void NTT(int *a, int o)
{
    // Reorder into bit-reversed index order; the idx < rev test swaps each pair once.
    for (int idx = 0; idx < t; ++idx)
    {
        const int rev = p[idx];
        if (idx < rev) std::swap(a[idx], a[rev]);
    }

    const int root = (o == 1) ? g : G;
    for (int half = 1; half < t; half *= 2)
    {
        const int len = half * 2;                       // size of the blocks being merged
        const int step = qpow(root, (mod - 1) / len);   // root of unity of order len
        for (int start = 0; start < t; start += len)
        {
            int *lo = a + start;
            int *hi = lo + half;
            int tw = 1;                                 // current twiddle factor, step^off
            for (int off = 0; off < half; ++off)
            {
                const int u = lo[off];
                const int v = 1ll * tw * hi[off] % mod;
                lo[off] = (u + v) % mod;
                hi[off] = (u - v + mod) % mod;
                tw = 1ll * tw * step % mod;
            }
        }
    }

    if (o != -1) return;
    const int inv_len = qpow(t, mod - 2);
    for (int idx = 0; idx < t; ++idx) a[idx] = 1ll * a[idx] * inv_len % mod;
}
int main()
{
n = read(), m = read(), scanf("%s", s1), scanf("%s", s2);
for (int i = 0; i < n; i ++ ) a0[i] = (s1[n - i - 1] == '*') ? 0 : s1[n - i - 1] - 'a' + 1;
for (int i = 0; i < m; i ++ ) b0[i] = (s2[i] == '*') ? 0 : s2[i] - 'a' + 1;
g_l(n + m);
for (int i = 0; i < t; i ++ ) a[i] = 1ll * a0[i] * a0[i] % mod * a0[i] % mod, b[i] = b0[i];
NTT(a, 1), NTT(b, 1);
for (int i = 0; i < t; i ++ ) res[i] = (res[i] + 1ll * a[i] * b[i] % mod) % mod;
for (int i = 0; i < t; i ++ ) a[i] = 1ll * a0[i] * a0[i] % mod, b[i] = 1ll * b0[i] * b0[i] % mod;
NTT(a, 1), NTT(b, 1);
for (int i = 0; i < t; i ++ ) res[i] = (res[i] - 2ll * a[i] * b[i] % mod + mod * 2ll) % mod;
for (int i = 0; i < t; i ++ ) a[i] = a0[i], b[i] = 1ll * b0[i] * b0[i] % mod * b0[i] % mod;
NTT(a, 1), NTT(b, 1);
for (int i = 0; i < t; i ++ ) res[i] = (res[i] + 1ll * a[i] * b[i] % mod) % mod;
NTT(res, -1);
for (int i = n - 1; i <= m - 1; i ++ ) if (!res[i]) out[ ++ tot] = i + 2 - n;
printf("%d\n", tot); for (int i = 1; i <= tot; i ++ ) printf("%d ", out[i]); putchar('\n');
return 0;
}