FFT/NTT字符串模糊匹配
因为FFT的浮点精度误差问题太严重了,所以强烈推荐用NTT。
首先考虑精确匹配:https://www.acwing.com/problem/content/833/
假设我们有短串\(s1\)(长度为\(n\)),长串\(s2\)(长度为\(m\))
我们定义字符差(下标均从0开始)
\[c(x,y)=s1(x)-s2(y)\]
若\(c(x,y) = 0\),表明\(s1\)的第\(x\)个字符与\(s2\)的第\(y\)个字符匹配,再定义
\[F(x)=\sum_{i=0}^{n-1}c(i,\,x-n+i+1)\]
为\(s2\)子串的字符差之和,这个子串长为\(n\)并且以下标\(x\)为结尾。若\(F(x) = 0\),似乎表明这个子串与\(s1\)完全匹配,但正负项会相互抵消,可能会将\(ab\)与\(ba\)也算为完全匹配,因此我们考虑将\(F(x)\)换个表达式
\[F(x)=\sum_{i=0}^{n-1}\bigl(s1(i)-s2(x-n+i+1)\bigr)^2\]
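由于每一项都非负,若\(F(x) = 0\),则每一项都为0,表明这个子串与\(s1\)完全匹配。将其暴力展开:
\[F(x)=\sum_{i=0}^{n-1}s1(i)^2+\sum_{i=0}^{n-1}s2(x-n+i+1)^2-\sum_{i=0}^{n-1}2\,s1(i)\,s2(x-n+i+1)\]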
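其中\(\sum_{i = 0}^{n - 1}s1(i)^2\)和\(\sum_{i = 0}^{n - 1}s2(x-n+i+1)^2\)都可以用前缀和解决,关键是\(\sum_{i = 0}^{n - 1}2s1(i)s2(x-n+i+1)\)。我们将\(s1\)翻转得到\(s1'\),即\(s1'(n-i-1)=s1(i)\);注意到\((n-i-1)+(x-n+i+1)=x\),即
\[\sum_{i=0}^{n-1}2\,s1(i)\,s2(x-n+i+1)=2\sum_{i=0}^{n-1}s1'(n-i-1)\,s2(x-n+i+1)=2\sum_{j+k=x}s1'(j)\,s2(k)\]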
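可以发现这正是卷积的形式,能用NTT计算!因此
\[F(x)=\sum_{i=0}^{n-1}s1(i)^2+\sum_{i=0}^{n-1}s2(x-n+i+1)^2-2\,(s1'*s2)(x)\]
其中\((s1'*s2)(x)=\sum_{j+k=x}s1'(j)\,s2(k)\)。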
当\(F(x)=0\)时,表明完全匹配
AC代码:
(注意:不开O2优化会超时)
#include <unordered_map>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <string>
#include <stack>
#include <deque>
#include <queue>
#include <cmath>
#include <map>
#include <set>
using namespace std;
#pragma GCC optimize(2)
// GCC-specific: functions defined after this pragma are optimized as at -O2.

// Type aliases and constants. Several of them are contest-template
// boilerplate that this solution never references.
using PII = pair<int, int>;
using PDI = pair<double, int>;
//using int128 = __int128;
using ll = long long;
using ull = unsigned long long;
constexpr int INF = 0x3f3f3f3f;
constexpr int N = 1e7 + 10;
constexpr int M = 4e7 + 10;
constexpr int base = 1e9;
constexpr int P = 131;
constexpr int mod = 998244353;
constexpr double eps = 1e-9;
const double PI = acos(-1.0);

int n, m;           // pattern length, text length
int tot, bit;       // NTT length (power of two) and log2 of it, set by inif()
char s1[N];         // pattern
char s2[N];         // text
ll S[N];            // prefix sums of squared text codes
ll a[N];            // pattern codes (reversed), later the convolution
ll b[N];            // text codes
int R[N];           // bit-reversal permutation used by NTT()
// Fast modular exponentiation: returns a^b mod `mod` for b >= 0
// (a non-positive exponent returns 1 % mod instead of looping forever).
// The base is first reduced into [0, mod), so the first squaring cannot
// overflow `ll` even for large or negative bases.
ll ksm(ll a, ll b)
{
    a %= mod;
    if (a < 0)
        a += mod;
    ll res = 1 % mod;
    while (b > 0)
    {
        if (b & 1)
            res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}
// Prepare an NTT: pick the smallest power of two `tot` greater than n
// (bit = log2(tot)) and fill R[0..tot) with the bit-reversal permutation
// that NTT() uses. Passing n = n1 + n2 for two sequences of lengths n1 and
// n2 makes tot >= n1 + n2 - 1, so their linear convolution does not wrap.
void inif(int n)
{
    tot = 1, bit = 0;
    while (tot <= n)
        tot *= 2, ++bit;
    // R[0] is always 0. Starting at i = 1 keeps the shift count (bit - 1)
    // non-negative: when tot == 1 the loop is empty, otherwise bit >= 1.
    // The loop stops at tot - 1 because NTT() only reads R[0..total).
    R[0] = 0;
    for (int i = 1; i < tot; ++i)
        R[i] = (R[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}
// In-place iterative number-theoretic transform modulo `mod` (998244353)
// with primitive root 3.
//   f     : coefficient array of length `total`
//   total : transform length; must be the power of two prepared by inif(),
//           whose bit-reversal table R is used here
//   type  : 1 = forward transform, -1 = inverse transform (including the
//           final division by `total`)
void NTT(ll f[], int total, int type)
{
    for (int i = 0; i < total; ++i)
        if (i < R[i])
            swap(f[i], f[R[i]]);
    for (int len = 2; len <= total; len *= 2)
    {
        // 332748118 is the inverse of 3 modulo 998244353.
        ll w1 = ksm(type == 1 ? 3 : 332748118, (mod - 1) / len);
        int half = len / 2;
        for (int pos = 0; pos < total; pos += len)
        {
            ll w = 1;
            for (int i = pos; i < pos + half; ++i, w = w * w1 % mod)
            {
                // Keep the butterfly operands in ll: no narrowing, and no
                // int overflow even if an input is outside [0, mod).
                ll x = f[i];
                ll y = w * f[i + half] % mod;
                f[i] = (x + y) % mod;
                f[i + half] = (x - y + mod) % mod;
            }
        }
    }
    if (type == -1)
    {
        // Divide the array that was actually transformed by its length.
        ll inv = ksm(total, mod - 2);
        for (int i = 0; i < total; ++i)
            f[i] = f[i] * inv % mod;
    }
}
int main()
{
scanf("%d%s%d%s", &n, &s1, &m, &s2);
for (int i = 0; i < n; ++i)
a[i] = s1[i] - 'a' + 1;
for (int i = 0; i < m; ++i)
b[i] = s2[i] - 'a' + 1;
reverse(a, a + n);
ll sum = 0;
for (int i = 0; i < n; ++i)
sum = (sum + a[i] * a[i] % mod) % mod;
S[0] = b[0] * b[0];
for (int i = 1; i < m; ++i)
S[i] = (S[i - 1] + b[i] * b[i] % mod) % mod;
inif(n + m);
NTT(a, tot, 1), NTT(b, tot, 1);
for (int i = 0; i < tot; ++i)
a[i] = a[i] * b[i] % mod;
NTT(a, tot, -1);
for (int x = n - 1; x < m; ++x)
{
double P = (sum + S[x] - S[x - n] - 2 * a[x]) % mod;
if (P == 0)
printf("%d ", x - n + 1);
}
return 0;
}
接着我们考虑模糊匹配,即有通配符的情况:https://www.luogu.com.cn/problem/P4173
设通配符的值为0,重新定义字符差
\[c(x,y)=\bigl(s1(x)-s2(y)\bigr)^2\,s1(x)\,s2(y),\qquad F(x)=\sum_{i=0}^{n-1}c(i,\,x-n+i+1)\]
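只要有一方是通配符(值为0),该位置的字符差就为0;而且每一项都非负,所以\(F(x)=0\)当且仅当每个位置都匹配,完美解决问题。依然暴力展开(\(s1'\)仍为翻转后的\(s1\)):
\[F(x)=\sum_{j+k=x}s1'(j)^3\,s2(k)+\sum_{j+k=x}s1'(j)\,s2(k)^3-2\sum_{j+k=x}s1'(j)^2\,s2(k)^2\]
三项都是卷积,各用NTT求出即可。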
当\(F(x)=0\)时,表明完全匹配
AC代码:
#include <unordered_map>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <string>
#include <stack>
#include <deque>
#include <queue>
#include <cmath>
#include <map>
#include <set>
using namespace std;
// Type aliases and constants. Several of them are contest-template
// boilerplate that this solution never references.
using PII = pair<int, int>;
using PDI = pair<double, int>;
//using int128 = __int128;
using ll = long long;
using ull = unsigned long long;
constexpr int INF = 0x3f3f3f3f;
constexpr int N = 1e7 + 10;
constexpr int M = 4e7 + 10;
constexpr int base = 1e9;
constexpr int P = 131;
constexpr int mod = 998244353;
constexpr double eps = 1e-9;
const double PI = acos(-1.0);

int n, m;               // pattern length, text length
int A[N], B[N];         // character codes of pattern / text ('*' -> 0)
char s1[N], s2[N];      // pattern, text
int R[N], ans[N];       // bit-reversal table; 1-based match positions
int tot, bit, pos;      // NTT length, log2 of it, number of matches found
ll a[N], b[N], p[N];    // NTT work buffers and the accumulated spectrum
// Fast modular exponentiation: returns a^b mod `mod` for b >= 0
// (a non-positive exponent returns 1 % mod instead of looping forever).
// The base is first reduced into [0, mod), so the first squaring cannot
// overflow `ll` even for large or negative bases.
ll ksm(ll a, ll b)
{
    a %= mod;
    if (a < 0)
        a += mod;
    ll res = 1 % mod;
    while (b > 0)
    {
        if (b & 1)
            res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}
// Prepare an NTT: pick the smallest power of two `tot` greater than n
// (bit = log2(tot)) and fill R[0..tot) with the bit-reversal permutation
// that NTT() uses. Passing n = n1 + n2 for two sequences of lengths n1 and
// n2 makes tot >= n1 + n2 - 1, so their linear convolution does not wrap.
void inif(int n)
{
    tot = 1, bit = 0;
    while (tot <= n)
        tot *= 2, ++bit;
    // R[0] is always 0. Starting at i = 1 keeps the shift count (bit - 1)
    // non-negative: when tot == 1 the loop is empty, otherwise bit >= 1.
    // The loop stops at tot - 1 because NTT() only reads R[0..total).
    R[0] = 0;
    for (int i = 1; i < tot; ++i)
        R[i] = (R[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}
// In-place iterative number-theoretic transform modulo `mod` (998244353)
// with primitive root 3.
//   f     : coefficient array of length `total`
//   total : transform length; must be the power of two prepared by inif(),
//           whose bit-reversal table R is used here
//   type  : 1 = forward transform, -1 = inverse transform (including the
//           final division by `total`)
void NTT(ll f[], int total, int type)
{
    for (int i = 0; i < total; ++i)
        if (i < R[i])
            swap(f[i], f[R[i]]);
    for (int len = 2; len <= total; len *= 2)
    {
        // 332748118 is the inverse of 3 modulo 998244353.
        ll w1 = ksm(type == 1 ? 3 : 332748118, (mod - 1) / len);
        int half = len / 2;
        for (int pos = 0; pos < total; pos += len)
        {
            ll w = 1;
            for (int i = pos; i < pos + half; ++i, w = w * w1 % mod)
            {
                // Keep the butterfly operands in ll: no narrowing, and no
                // int overflow even if an input is outside [0, mod).
                ll x = f[i];
                ll y = w * f[i + half] % mod;
                f[i] = (x + y) % mod;
                f[i + half] = (x - y + mod) % mod;
            }
        }
    }
    if (type == -1)
    {
        // Divide the array that was actually transformed (e.g. p in main)
        // by its length, rather than some fixed global buffer.
        ll inv = ksm(total, mod - 2);
        for (int i = 0; i < total; ++i)
            f[i] = f[i] * inv % mod;
    }
}
int main()
{
scanf("%d%d%s%s", &n, &m, &s1, &s2);
reverse(s1, s1 + n);
for (int i = 0; i < n; ++i)
A[i] = s1[i] == '*' ? 0 : s1[i] - 'a' + 1;
for (int i = 0; i < m; ++i)
B[i] = s2[i] == '*' ? 0 : s2[i] - 'a' + 1;
inif(n + m);
//A[i]^3 B[i]
for (int i = 0; i < tot; ++i)
a[i] = A[i] * A[i] * A[i];
for (int i = 0; i < tot; ++i)
b[i] = B[i];
NTT(a, tot, 1), NTT(b, tot, 1);
for (int i = 0; i < tot; ++i)
p[i] = (p[i] + a[i] * b[i]) % mod;
//A[i] B[i]^3
for (int i = 0; i < tot; ++i)
a[i] = A[i];
for (int i = 0; i < tot; ++i)
b[i] = B[i] * B[i] * B[i];
NTT(a, tot, 1), NTT(b, tot, 1);
for (int i = 0; i < tot; ++i)
p[i] = (p[i] + a[i] * b[i]) % mod;
//A[i]^2 B[i]^2
for (int i = 0; i < tot; ++i)
a[i] = A[i] * A[i];
for (int i = 0; i < tot; ++i)
b[i] = B[i] * B[i];
NTT(a, tot, 1), NTT(b, tot, 1);
for (int i = 0; i < tot; ++i)
p[i] = (p[i] - 2 * a[i] * b[i] + mod) % mod;
NTT(p, tot, -1);
for (int i = n - 1; i < m; ++i)
if (p[i] == 0)
ans[++pos] = i - n + 2;
printf("%d\n", pos);
for (int i = 1; i <= pos; ++i)
printf("%d ", ans[i]);
return 0;
}
后来杭电多校的一道题让我了解到了这个知识点:
HDU6975:https://acm.hdu.edu.cn/showproblem.php?pid=6975
因为字符只包含0-9和通配符*,首先不考虑通配符,我们可以枚举0-9,分别算出每个子串在各个字符上的匹配数。以8为例,将所有为8的地方值设为1,其他地方值设为0,则对字符8的匹配数有
\[G_8(x)=\sum_{i=0}^{n-1}[s1(i)=8]\,[s2(x-n+i+1)=8]\]
将\(s1\)翻转后这同样是一个卷积,对10个数字求和即得不考虑通配符时的匹配数。
求出每个子串的匹配数后再考虑通配符。由容斥原理,因通配符而匹配的位置数 = \(s1\)的通配符数 + \(s2\)子串的通配符数 − \(s1\)与\(s2\)子串相同位置上同为通配符的位置数。前两项用前缀和求出,最后一项用一次卷积求出。
AC代码:
#include <unordered_map>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <string>
#include <stack>
#include <deque>
#include <queue>
#include <cmath>
#include <map>
#include <set>
using namespace std;
// Type aliases, constants and global buffers for the HDU 6975 solution.
// Several aliases/constants are contest-template boilerplate that this
// solution never references.
using PII = pair<int, int>;
using PDI = pair<double, int>;
//using int128 = __int128;
using ll = long long;
using ull = unsigned long long;
constexpr int INF = 0x3f3f3f3f;
constexpr int N = 1e6 + 10;
constexpr int M = 4e7 + 10;
constexpr int base = 1e9;
constexpr int P = 131;
constexpr int mod = 998244353;
constexpr double eps = 1e-9;
const double PI = acos(-1.0);
FILE *fp;                   // never used by this solution
int n, m;                   // pattern length, text length
int tot, bit;               // NTT length and log2 of it, set by inif()
char s1[N], s2[N];          // pattern (reversed in main), text
int R[N], ans[N];           // bit-reversal table; counts per mismatch number
ll a[N], b[N], f[N], S[N];  // NTT buffers, match spectrum, '*' prefix sums
// Fast modular exponentiation: returns a^b mod `mod` for b >= 0
// (a non-positive exponent returns 1 % mod instead of looping forever).
// The base is first reduced into [0, mod), so the first squaring cannot
// overflow `ll` even for large or negative bases.
ll ksm(ll a, ll b)
{
    a %= mod;
    if (a < 0)
        a += mod;
    ll res = 1 % mod;
    while (b > 0)
    {
        if (b & 1)
            res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}
// Per-test-case setup. Clears s1, s2, ans and f, then prepares an NTT:
// picks the smallest power of two `tot` greater than n (bit = log2(tot)) and
// fills R[0..tot) with the bit-reversal permutation that NTT() uses.
// NOTE(review): tot must not exceed N (the buffer size); confirm against the
// problem's length limits.
void inif(int n)
{
    memset(s1, 0, sizeof(s1));
    memset(s2, 0, sizeof(s2));
    memset(ans, 0, sizeof(ans));
    memset(f, 0, sizeof(f));
    tot = 1, bit = 0;
    while (tot <= n)
        tot *= 2, ++bit;
    // R[0] is always 0. Starting at i = 1 keeps the shift count (bit - 1)
    // non-negative: when tot == 1 the loop is empty, otherwise bit >= 1.
    // The loop stops at tot - 1 because NTT() only reads R[0..total).
    R[0] = 0;
    for (int i = 1; i < tot; ++i)
        R[i] = (R[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}
// In-place iterative number-theoretic transform modulo `mod` (998244353)
// with primitive root 3.
//   f     : coefficient array of length `total`
//   total : transform length; must be the power of two prepared by inif(),
//           whose bit-reversal table R is used here
//   type  : 1 = forward transform, -1 = inverse transform (including the
//           final division by `total`)
void NTT(ll f[], int total, int type)
{
    for (int i = 0; i < total; ++i)
        if (i < R[i])
            swap(f[i], f[R[i]]);
    for (int len = 2; len <= total; len *= 2)
    {
        // 332748118 is the inverse of 3 modulo 998244353.
        ll w1 = ksm(type == 1 ? 3 : 332748118, (mod - 1) / len);
        int half = len / 2;
        for (int pos = 0; pos < total; pos += len)
        {
            ll w = 1;
            for (int i = pos; i < pos + half; ++i, w = w * w1 % mod)
            {
                // Keep the butterfly operands in ll: no narrowing, and no
                // int overflow even if an input is outside [0, mod).
                ll x = f[i];
                ll y = w * f[i + half] % mod;
                f[i] = (x + y) % mod;
                f[i + half] = (x - y + mod) % mod;
            }
        }
    }
    if (type == -1)
    {
        // Scale over the full transform length, independent of the global
        // string lengths n and m.
        ll inv = ksm(total, mod - 2);
        for (int i = 0; i < total; ++i)
            f[i] = f[i] * inv % mod;
    }
}
void get(char c, int type)
{
for (int i = 0; i < tot; ++i)
a[i] = s1[i] == c;
for (int i = 0; i < tot; ++i)
b[i] = s2[i] == c;
NTT(a, tot, 1), NTT(b, tot, 1);
for (int i = 0; i < tot; ++i)
{
if (type == 1)
f[i] = (f[i] + a[i] * b[i] % mod) % mod;
else
f[i] = (f[i] - a[i] * b[i] % mod + mod) % mod;
}
}
int main()
{
int T;
scanf("%d", &T);
while (T--)
{
scanf("%d%d", &m, &n);
inif(n + m);
scanf("%s%s", s2, s1);
reverse(s1, s1 + n);
for (char c = '0'; c <= '9'; ++c)
get(c, 1);
get('*', -1);
NTT(f, tot, -1);
ll sum = 0;
for (int i = 0; i < n; ++i)
sum += s1[i] == '*';
S[0] = s2[0] == '*';
for (int i = 1; i < m; ++i)
S[i] = (S[i - 1] + (s2[i] == '*')) % mod;
for (int i = 0; i < tot; ++i)
{
if (i >= n)
f[i] = (f[i] + sum + S[i] - S[i - n] + mod) % mod;
else
f[i] = (f[i] + sum + S[i]) % mod;
}
for (int i = n - 1; i < m; ++i)
++ans[n - f[i]];
for (int i = 0; i <= n; ++i)
{
if (i)
ans[i] += ans[i - 1];
printf("%d\n", ans[i]);
}
}
return 0;
}