FFT解决字符串匹配问题

FFT解决字符串匹配问题

朴素的字符串匹配

设有两个字符串 \(a,b\) 长度分别为 \(n,m(n\geq m)\) ,询问字符串 \(b\)\(a\) 中出现了几次。

这是一个最基础的字符串匹配问题,显然我们容易想到 \(O(nm)\) 的暴力匹配方法,同时也可以使用 \(KMP\) 算法在 \(O(n+m)\) 的复杂度下解决该问题。实际上,我们还可以利用 \(FFT\) 来解决这一经典问题,时间复杂度为 \(O(n\log n)\)

构造两个多项式 \(A(x)=a_0+a_1x^1+a_2x^2+\cdots + a_{n-1}x^{n-1}\)\(B(x)=b_0+b_1x^1+b_2x^2+\cdots + b_{m-1}x^{m-1}\) ,其中 \(a_i\) 表示字符串 \(a\) 中下标为 \(i\) 的字符对应的数字(例如我们可以将 \(a\) 映射为 \(1\)\(b\) 映射为 \(2\) ……)。然后我们定义一个匹配多项式 \(C(x)\) ,该函数的第 \(j\)\(c_jx^j\) 的系数 \(c_j\)\(c_j = \sum_{i=0}^{m-1}(a_{j+i}-b_{i})^2\) 。我们发现当这个式子的值为零时,字符串 \(a\) 的子串 \(a_ja_{j+1}\cdots a_{j+m-1}\) 与字符串 \(b\) 匹配成功。

然后我们考虑如何计算这个匹配多项式 \(C(x)\) ,我们需要将其系数转换为卷积形式,因此我们翻转字符串 \(b\) ,设字符串 \(r=reverse(b)\) ,即 \(r_i = b_{m-i-1}\) 。于是:

\[\begin{aligned} c_j &= \sum_{i=0}^{m-1}(a_{j+i}-b_{i})^2\\ &= \sum_{i=0}^{m-1}(a_{j+i}-r_{m-i-1})^2 \\ &= \sum_{i=0}^{m-1}a_{j+i}^2+r_{m-i-1}^2-2a_{j+i}r_{m-i-1} \\ &= \sum_{i=0}^{m-1}a_{j+i}^2 + \sum_{i=0}^{m-1}r_{m-i-1}^2 - 2\sum_{i=0}^{m-1}a_{j+i}r_{m-i-1} \end{aligned} \]

因此 \(c_j\) 被我们转化成了 \(3\) 个和式,我们依次分析。首先 \(\sum_{i=0}^{m-1}a_{j+i}^2\) 这一项显然我们可以 \(O(n)\) 预处理前缀和,然后 \(O(1)\) 查询;然后 \(\sum_{i=0}^{m-1}r_{m-i-1}^2\) 这一项是一个定值;最后 \(\sum_{i=0}^{m-1}a_{j+i}r_{m-i-1}\) 是一个卷积式,我们可以直接构造多项式 \(P(x)=A(x)*B(x)\) ,那么就能够得到该多项式的系数 \(p_j = \sum_{i=0}^ja_ib_{j-i}\) ,进行简单的下标偏移后即可求出 \(c_j\) ,然后通过 \(FFT\) 加速计算,就能在 \(O(n\log n)\) 的复杂度下计算这个卷积。

下面给出一个 \(NTT\) 实现:

#include <bits/stdc++.h>
using namespace std;

constexpr int mod = 998244353;
std::vector<int> rev, roots{ 0, 1 };
int powmod(int a, long long b) {
	int res = 1;
	for (; b; b >>= 1, a = 1ll * a * a % mod)
		if (b & 1)
			res = 1ll * res * a % mod;
	return res;
}
void dft(std::vector<int>& a) {
	int n = a.size();
	if (int(rev.size()) != n) {
		int k = __builtin_ctz(n) - 1;
		rev.resize(n);
		for (int i = 0; i < n; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
	}
	for (int i = 0; i < n; ++i)
		if (rev[i] < i)
			std::swap(a[i], a[rev[i]]);
	if (int(roots.size()) < n) {
		int k = __builtin_ctz(roots.size());
		roots.resize(n);
		while ((1 << k) < n) {
			int e = powmod(3, (mod - 1) >> (k + 1));
			for (int i = 1 << (k - 1); i < (1 << k); ++i) {
				roots[2 * i] = roots[i];
				roots[2 * i + 1] = 1ll * roots[i] * e % mod;
			}
			++k;
		}
	}
	for (int k = 1; k < n; k *= 2) {
		for (int i = 0; i < n; i += 2 * k) {
			for (int j = 0; j < k; ++j) {
				int u = a[i + j];
				int v = 1ll * a[i + j + k] * roots[k + j] % mod;
				int x = u + v;
				if (x >= mod)
					x -= mod;
				a[i + j] = x;
				x = u - v;
				if (x < 0)
					x += mod;
				a[i + j + k] = x;
			}
		}
	}
}
void idft(std::vector<int>& a) {
	int n = a.size();
	std::reverse(a.begin() + 1, a.end());
	dft(a);
	int inv = powmod(n, mod - 2);
	for (int i = 0; i < n; ++i) a[i] = 1ll * a[i] * inv % mod;
}
struct Poly {
	std::vector<int> a;
	Poly() {}
	Poly(int a0) {
		if (a0)
			a = { a0 };
	}
	Poly(const std::vector<int>& a1) : a(a1) {
		while (!a.empty() && !a.back()) a.pop_back();
	}
	int size() const { return a.size(); }
	int operator[](int idx) const {
		if (idx < 0 || idx >= size())
			return 0;
		return a[idx];
	}
	Poly mulxk(int k) const {
		auto b = a;
		b.insert(b.begin(), k, 0);
		return Poly(b);
	}
	Poly modxk(int k) const {
		k = std::min(k, size());
		return Poly(std::vector<int>(a.begin(), a.begin() + k));
	}
	Poly divxk(int k) const {
		if (size() <= k)
			return Poly();
		return Poly(std::vector<int>(a.begin() + k, a.end()));
	}
	friend Poly operator+(const Poly a, const Poly& b) {
		std::vector<int> res(std::max(a.size(), b.size()));
		for (int i = 0; i < int(res.size()); ++i) {
			res[i] = a[i] + b[i];
			if (res[i] >= mod)
				res[i] -= mod;
		}
		return Poly(res);
	}
	friend Poly operator-(const Poly a, const Poly& b) {
		std::vector<int> res(std::max(a.size(), b.size()));
		for (int i = 0; i < int(res.size()); ++i) {
			res[i] = a[i] - b[i];
			if (res[i] < 0)
				res[i] += mod;
		}
		return Poly(res);
	}
	friend Poly operator*(Poly a, Poly b) {
		int sz = 1, tot = a.size() + b.size() - 1;
		while (sz < tot) sz *= 2;
		a.a.resize(sz);
		b.a.resize(sz);
		dft(a.a);
		dft(b.a);
		for (int i = 0; i < sz; ++i) a.a[i] = 1ll * a[i] * b[i] % mod;
		idft(a.a);
		return Poly(a.a);
	}
	Poly& operator+=(Poly b) { return (*this) = (*this) + b; }
	Poly& operator-=(Poly b) { return (*this) = (*this) - b; }
	Poly& operator*=(Poly b) { return (*this) = (*this) * b; }
	Poly deriv() const {  // 求导
		if (a.empty())
			return Poly();
		std::vector<int> res(size() - 1);
		for (int i = 0; i < size() - 1; ++i) res[i] = 1ll * (i + 1) * a[i + 1] % mod;
		return Poly(res);
	}
	Poly integr() const {  // 积分
		if (a.empty())
			return Poly();
		std::vector<int> res(size() + 1);
		for (int i = 0; i < size(); ++i) res[i + 1] = 1ll * a[i] * powmod(i + 1, mod - 2) % mod;
		return Poly(res);
	}
	Poly inv(int m) const {  // 逆
		Poly x(powmod(a[0], mod - 2));
		int k = 1;
		while (k < m) {
			k *= 2;
			x = (x * (2 - modxk(k) * x)).modxk(k);
		}
		return x.modxk(m);
	}
	Poly log(int m) const { return (deriv() * inv(m)).integr().modxk(m); }
	Poly exp(int m) const {
		Poly x(1);
		int k = 1;
		while (k < m) {
			k *= 2;
			x = (x * (1 - x.log(k) + modxk(k))).modxk(k);
		}
		return x.modxk(m);
	}
	Poly mulT(Poly b) const {  // 卷积
		if (b.size() == 0)
			return Poly();
		int n = b.size();
		std::reverse(b.a.begin(), b.a.end());
		return ((*this) * b).divxk(n - 1);
	}
	std::vector<int> eval(std::vector<int> x) const {  // 求值
		if (size() == 0)
			return std::vector<int>(x.size(), 0);
		const int n = std::max(int(x.size()), size());
		std::vector<Poly> q(4 * n);
		std::vector<int> ans(x.size());
		x.resize(n);
		std::function<void(int, int, int)> build = [&](int p, int l, int r) {
			if (r - l == 1) {
				q[p] = std::vector<int>{ 1, (mod - x[l]) % mod };
			}
			else {
				int m = (l + r) / 2;
				build(2 * p, l, m);
				build(2 * p + 1, m, r);
				q[p] = q[2 * p] * q[2 * p + 1];
			}
		};
		build(1, 0, n);
		std::function<void(int, int, int, const Poly&)> work = [&](int p, int l, int r, const Poly& num) {
			if (r - l == 1) {
				if (l < int(ans.size()))
					ans[l] = num[0];
			}
			else {
				int m = (l + r) / 2;
				work(2 * p, l, m, num.mulT(q[2 * p + 1]).modxk(m - l));
				work(2 * p + 1, m, r, num.mulT(q[2 * p]).modxk(r - m));
			}
		};
		work(1, 0, n, mulT(q[1].inv(n)));
		return ans;
	}
};

int main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	cout.tie(nullptr);
	string s1, s2;
	cin >> s1 >> s2;
	reverse(s2.begin(), s2.end());
	int n = s1.length();
	int m = s2.length();
	vector<int> va(n);
	vector<int> vb(m);
	vector<int> pre(n);
	for (int i = 0; i < n; i++)
		va[i] = s1[i] - 'a' + 1;
	for (int i = 0; i < m; i++)
		vb[i] = s2[i] - 'a' + 1;
	int b = 0;
	pre[0] = va[0];
	for (int i = 1; i < n; i++)
		pre[i] = pre[i - 1] + va[i] * va[i];
	for (int i = 0; i < m; i++)
		b += vb[i] * vb[i];
	Poly pa(va);
	Poly pb(vb);
	pa = pa * pb;

	vector<int> res;
	for (int i = 0; i <= n - m; i++) {
		int val = pre[i + m - 1] - (i == 0 ? 0 : pre[i - 1]) + b;
		val -= 2 * pa[i + m - 1];
		if (val == 0)
			res.push_back(i + 1);
	}
	cout << (int)res.size() << '\n'; // 数量
	for (int i = 0; i < (int)res.size(); i++)
		cout << res[i] << " \n"[i == (int)res.size() - 1]; // 匹配位置
	return 0;
}

带通配符的字符串匹配

使用 \(FFT\) 求解字符串匹配问题虽然时间复杂度与 \(KMP\) 算法差距较大,但是具有更强的扩展性,例如下面这题:

设有两个字符串 \(a,b\) 长度分别为 \(n,m(n\geq m)\) ,询问字符串 \(b\)\(a\) 中出现了几次。并且这两个字符串中均含有通配符,可以与任意字符完成匹配。

题目来源:洛谷 \(P4173\) 残缺的字符串 https://www.luogu.com.cn/problem/P4173

我们考虑到,通配符可以匹配任意一个字符,因此如果我们继续采用上方给出的朴素匹配算法,那么通配符这一位的计算结果必然始终为零,因此构造一个新的匹配多项式 \(C(x)\) ,第 \(j\) 项的系数 \(c_j = \sum_{i=0}^{m-1}(a_{j+i}-b_{i})^2a_{j+i}b_i\) ;并且预处理多项式时,通配符的对应数值为零。同样地,我们翻转字符串 \(b\)\(r\) ,然后展开:

\[\begin{aligned} c_j &= \sum_{i=0}^{m-1}(a_{j+i}-b_{i})^2a_{j+i}b_i \\ &= \sum_{i=0}^{m-1}(a_{j+i}-r_{m-i-1})^2a_{j+i}r_{m-i-1} \\ &= \sum_{i=0}^{m-1}a_{j+i}^3r_{m-i-1} - 2a_{j+i}^2r_{m-i-1}^2+a_{j+i}r_{m-i-1}^3 \end{aligned} \]

显然,这是 \(3\) 个卷积式,进行 \(3\) 次卷积即可解决这一问题。

#include <bits/stdc++.h>
using namespace std;

constexpr int mod = 998244353;
std::vector<int> rev, roots{ 0, 1 };
int powmod(int a, long long b) {
	int res = 1;
	for (; b; b >>= 1, a = 1ll * a * a % mod)
		if (b & 1)
			res = 1ll * res * a % mod;
	return res;
}
void dft(std::vector<int>& a) {
	int n = a.size();
	if (int(rev.size()) != n) {
		int k = __builtin_ctz(n) - 1;
		rev.resize(n);
		for (int i = 0; i < n; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
	}
	for (int i = 0; i < n; ++i)
		if (rev[i] < i)
			std::swap(a[i], a[rev[i]]);
	if (int(roots.size()) < n) {
		int k = __builtin_ctz(roots.size());
		roots.resize(n);
		while ((1 << k) < n) {
			int e = powmod(3, (mod - 1) >> (k + 1));
			for (int i = 1 << (k - 1); i < (1 << k); ++i) {
				roots[2 * i] = roots[i];
				roots[2 * i + 1] = 1ll * roots[i] * e % mod;
			}
			++k;
		}
	}
	for (int k = 1; k < n; k *= 2) {
		for (int i = 0; i < n; i += 2 * k) {
			for (int j = 0; j < k; ++j) {
				int u = a[i + j];
				int v = 1ll * a[i + j + k] * roots[k + j] % mod;
				int x = u + v;
				if (x >= mod)
					x -= mod;
				a[i + j] = x;
				x = u - v;
				if (x < 0)
					x += mod;
				a[i + j + k] = x;
			}
		}
	}
}
void idft(std::vector<int>& a) {
	int n = a.size();
	std::reverse(a.begin() + 1, a.end());
	dft(a);
	int inv = powmod(n, mod - 2);
	for (int i = 0; i < n; ++i) a[i] = 1ll * a[i] * inv % mod;
}
struct Poly {
	std::vector<int> a;
	Poly() {}
	Poly(int a0) {
		if (a0)
			a = { a0 };
	}
	Poly(const std::vector<int>& a1) : a(a1) {
		while (!a.empty() && !a.back()) a.pop_back();
	}
	int size() const { return a.size(); }
	int operator[](int idx) const {
		if (idx < 0 || idx >= size())
			return 0;
		return a[idx];
	}
	Poly mulxk(int k) const {
		auto b = a;
		b.insert(b.begin(), k, 0);
		return Poly(b);
	}
	Poly modxk(int k) const {
		k = std::min(k, size());
		return Poly(std::vector<int>(a.begin(), a.begin() + k));
	}
	Poly divxk(int k) const {
		if (size() <= k)
			return Poly();
		return Poly(std::vector<int>(a.begin() + k, a.end()));
	}
	friend Poly operator+(const Poly a, const Poly& b) {
		std::vector<int> res(std::max(a.size(), b.size()));
		for (int i = 0; i < int(res.size()); ++i) {
			res[i] = a[i] + b[i];
			if (res[i] >= mod)
				res[i] -= mod;
		}
		return Poly(res);
	}
	friend Poly operator-(const Poly a, const Poly& b) {
		std::vector<int> res(std::max(a.size(), b.size()));
		for (int i = 0; i < int(res.size()); ++i) {
			res[i] = a[i] - b[i];
			if (res[i] < 0)
				res[i] += mod;
		}
		return Poly(res);
	}
	friend Poly operator*(Poly a, Poly b) {
		int sz = 1, tot = a.size() + b.size() - 1;
		while (sz < tot) sz *= 2;
		a.a.resize(sz);
		b.a.resize(sz);
		dft(a.a);
		dft(b.a);
		for (int i = 0; i < sz; ++i) a.a[i] = 1ll * a[i] * b[i] % mod;
		idft(a.a);
		return Poly(a.a);
	}
	Poly& operator+=(Poly b) { return (*this) = (*this) + b; }
	Poly& operator-=(Poly b) { return (*this) = (*this) - b; }
	Poly& operator*=(Poly b) { return (*this) = (*this) * b; }
	Poly deriv() const {  // 求导
		if (a.empty())
			return Poly();
		std::vector<int> res(size() - 1);
		for (int i = 0; i < size() - 1; ++i) res[i] = 1ll * (i + 1) * a[i + 1] % mod;
		return Poly(res);
	}
	Poly integr() const {  // 积分
		if (a.empty())
			return Poly();
		std::vector<int> res(size() + 1);
		for (int i = 0; i < size(); ++i) res[i + 1] = 1ll * a[i] * powmod(i + 1, mod - 2) % mod;
		return Poly(res);
	}
	Poly inv(int m) const {  // 逆
		Poly x(powmod(a[0], mod - 2));
		int k = 1;
		while (k < m) {
			k *= 2;
			x = (x * (2 - modxk(k) * x)).modxk(k);
		}
		return x.modxk(m);
	}
	Poly log(int m) const { return (deriv() * inv(m)).integr().modxk(m); }
	Poly exp(int m) const {
		Poly x(1);
		int k = 1;
		while (k < m) {
			k *= 2;
			x = (x * (1 - x.log(k) + modxk(k))).modxk(k);
		}
		return x.modxk(m);
	}
	Poly mulT(Poly b) const {  // 卷积
		if (b.size() == 0)
			return Poly();
		int n = b.size();
		std::reverse(b.a.begin(), b.a.end());
		return ((*this) * b).divxk(n - 1);
	}
	std::vector<int> eval(std::vector<int> x) const {  // 求值
		if (size() == 0)
			return std::vector<int>(x.size(), 0);
		const int n = std::max(int(x.size()), size());
		std::vector<Poly> q(4 * n);
		std::vector<int> ans(x.size());
		x.resize(n);
		std::function<void(int, int, int)> build = [&](int p, int l, int r) {
			if (r - l == 1) {
				q[p] = std::vector<int>{ 1, (mod - x[l]) % mod };
			}
			else {
				int m = (l + r) / 2;
				build(2 * p, l, m);
				build(2 * p + 1, m, r);
				q[p] = q[2 * p] * q[2 * p + 1];
			}
		};
		build(1, 0, n);
		std::function<void(int, int, int, const Poly&)> work = [&](int p, int l, int r, const Poly& num) {
			if (r - l == 1) {
				if (l < int(ans.size()))
					ans[l] = num[0];
			}
			else {
				int m = (l + r) / 2;
				work(2 * p, l, m, num.mulT(q[2 * p + 1]).modxk(m - l));
				work(2 * p + 1, m, r, num.mulT(q[2 * p]).modxk(r - m));
			}
		};
		work(1, 0, n, mulT(q[1].inv(n)));
		return ans;
	}
};

int main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	cout.tie(nullptr);
	int n, m;
	string s1, s2;
	cin >> m >> n >> s2 >> s1;
	reverse(s2.begin(), s2.end());
	vector<int> va1(n);
	vector<int> va2(n);
	vector<int> va3(n);
	vector<int> vb1(m);
	vector<int> vb2(m);
	vector<int> vb3(m);
	for (int i = 0; i < n; i++) {
		if (s1[i] == '*') {
			va1[i] = va2[i] = va3[i] = 0;
		} else {
			va1[i] = s1[i] - 'a' + 1;
			va2[i] = (s1[i] - 'a' + 1) * va1[i];
			va3[i] = (s1[i] - 'a' + 1) * va2[i];
		}
	}
	for (int i = 0; i < m; i++) {
		if (s2[i] == '*') {
			vb1[i] = vb2[i] = vb3[i] = 0;
		} else {
			vb1[i] = s2[i] - 'a' + 1;
			vb2[i] = (s2[i] - 'a' + 1) * vb1[i];
			vb3[i] = (s2[i] - 'a' + 1) * vb2[i];
		}
	}
	vector<int> ans(n - m + 1, 0);
	vector<int> res;
	Poly pa, pb;
	pa = Poly(va1), pb = Poly(vb3);
	pa *= pb;
	for (int i = 0; i <= n - m; i++)
		ans[i] += pa[i + m - 1];
	pa = Poly(va3), pb = Poly(vb1);
	pa *= pb;
	for (int i = 0; i <= n - m; i++)
		ans[i] += pa[i + m - 1];
	pa = Poly(va2), pb = Poly(vb2);
	pa *= pb;
	for (int i = 0; i <= n - m; i++)
		ans[i] -= 2 * pa[i + m - 1];
	for (int i = 0; i <= n - m; i++) {
		if (ans[i] == 0) {
			res.push_back(i + 1);
		}
	}
	cout << (int)res.size() << '\n';
	for (int i = 0; i < (int)res.size(); i++)
		cout << res[i] << " \n"[i == (int)res.size() - 1];
	return 0;
}

字符集较小时的字符串匹配

设有两个字符串 \(a,b\) 长度分别为 \(n,m(n\geq m)\) ,询问字符串 \(b\)\(a\) 中出现了几次。并且,这两个字符串均由字符集 \(S\) 构成。

当这个字符集 \(S\) 很小时,我们利用 \(FFT\) 解决匹配问题又会多出一种新的方法。我们换一个角度重新思考字符串匹配问题,构造匹配多项式 \(C(x)\) ,该多项式的第 \(j\) 项系数 \(c_j = \sum^m_{i=0}[a_{i+j}=b_i]\)

那么,两个字符串 \(a,b\) 在位置 \(j\) 成功匹配的充要条件就是 \(c_j=m\) 。但是这个式子的计算在字符较多时比较困难,因为难以控制布尔运算式的状态。但是如果字符集较小时,我们就能用 \(0,1\) 两种状态来实现这个思路。不妨假设这个字符集仅含有两个字母 \(A,B\) ,我们通过这个例子来介绍一种利用 \(01\) 多项式的方法。

思路非常简单,我们首先将两个字符串中的字符 \(A\) 置为 \(1\)\(B\) 置为 \(0\) ,计算出匹配多项式 \(c_j = \sum^m_{i=0}[a_{i+j}=b_i]\) 每一项的值。由于只有 \(0,1\) 两种状态,因此 \(c_j=\sum^m_{i=0}a_{i+j}b_i\) ,将 \(b\) 翻转后即可卷积,此处略过推导。

然后将两个字符串中的字符 \(B\) 置为 \(1\) ,字符 \(A\) 置为 \(0\) ,再进行卷积,并将结果加到匹配多项式的系数 \(c_j\) 上。最后遍历一次多项式 \(C\) ,若 \(c_j=m\) 则成功匹配。因此我们可以通过枚举字符集中的每一个元素实现这一思路,时间复杂度 \(O(|S|n\log n)\)

下面同样给出一个实现:

#include <bits/stdc++.h>
using namespace std;

constexpr int mod = 998244353;
std::vector<int> rev, roots{ 0, 1 };
int powmod(int a, long long b) {
	int res = 1;
	for (; b; b >>= 1, a = 1ll * a * a % mod)
		if (b & 1)
			res = 1ll * res * a % mod;
	return res;
}
void dft(std::vector<int>& a) {
	int n = a.size();
	if (int(rev.size()) != n) {
		int k = __builtin_ctz(n) - 1;
		rev.resize(n);
		for (int i = 0; i < n; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
	}
	for (int i = 0; i < n; ++i)
		if (rev[i] < i)
			std::swap(a[i], a[rev[i]]);
	if (int(roots.size()) < n) {
		int k = __builtin_ctz(roots.size());
		roots.resize(n);
		while ((1 << k) < n) {
			int e = powmod(3, (mod - 1) >> (k + 1));
			for (int i = 1 << (k - 1); i < (1 << k); ++i) {
				roots[2 * i] = roots[i];
				roots[2 * i + 1] = 1ll * roots[i] * e % mod;
			}
			++k;
		}
	}
	for (int k = 1; k < n; k *= 2) {
		for (int i = 0; i < n; i += 2 * k) {
			for (int j = 0; j < k; ++j) {
				int u = a[i + j];
				int v = 1ll * a[i + j + k] * roots[k + j] % mod;
				int x = u + v;
				if (x >= mod)
					x -= mod;
				a[i + j] = x;
				x = u - v;
				if (x < 0)
					x += mod;
				a[i + j + k] = x;
			}
		}
	}
}
void idft(std::vector<int>& a) {
	int n = a.size();
	std::reverse(a.begin() + 1, a.end());
	dft(a);
	int inv = powmod(n, mod - 2);
	for (int i = 0; i < n; ++i) a[i] = 1ll * a[i] * inv % mod;
}
struct Poly {
	std::vector<int> a;
	Poly() {}
	Poly(int a0) {
		if (a0)
			a = { a0 };
	}
	Poly(const std::vector<int>& a1) : a(a1) {
		while (!a.empty() && !a.back()) a.pop_back();
	}
	int size() const { return a.size(); }
	int operator[](int idx) const {
		if (idx < 0 || idx >= size())
			return 0;
		return a[idx];
	}
	Poly mulxk(int k) const {
		auto b = a;
		b.insert(b.begin(), k, 0);
		return Poly(b);
	}
	Poly modxk(int k) const {
		k = std::min(k, size());
		return Poly(std::vector<int>(a.begin(), a.begin() + k));
	}
	Poly divxk(int k) const {
		if (size() <= k)
			return Poly();
		return Poly(std::vector<int>(a.begin() + k, a.end()));
	}
	friend Poly operator+(const Poly a, const Poly& b) {
		std::vector<int> res(std::max(a.size(), b.size()));
		for (int i = 0; i < int(res.size()); ++i) {
			res[i] = a[i] + b[i];
			if (res[i] >= mod)
				res[i] -= mod;
		}
		return Poly(res);
	}
	friend Poly operator-(const Poly a, const Poly& b) {
		std::vector<int> res(std::max(a.size(), b.size()));
		for (int i = 0; i < int(res.size()); ++i) {
			res[i] = a[i] - b[i];
			if (res[i] < 0)
				res[i] += mod;
		}
		return Poly(res);
	}
	friend Poly operator*(Poly a, Poly b) {
		int sz = 1, tot = a.size() + b.size() - 1;
		while (sz < tot) sz *= 2;
		a.a.resize(sz);
		b.a.resize(sz);
		dft(a.a);
		dft(b.a);
		for (int i = 0; i < sz; ++i) a.a[i] = 1ll * a[i] * b[i] % mod;
		idft(a.a);
		return Poly(a.a);
	}
	Poly& operator+=(Poly b) { return (*this) = (*this) + b; }
	Poly& operator-=(Poly b) { return (*this) = (*this) - b; }
	Poly& operator*=(Poly b) { return (*this) = (*this) * b; }
	Poly deriv() const {  // 求导
		if (a.empty())
			return Poly();
		std::vector<int> res(size() - 1);
		for (int i = 0; i < size() - 1; ++i) res[i] = 1ll * (i + 1) * a[i + 1] % mod;
		return Poly(res);
	}
	Poly integr() const {  // 积分
		if (a.empty())
			return Poly();
		std::vector<int> res(size() + 1);
		for (int i = 0; i < size(); ++i) res[i + 1] = 1ll * a[i] * powmod(i + 1, mod - 2) % mod;
		return Poly(res);
	}
	Poly inv(int m) const {  // 逆
		Poly x(powmod(a[0], mod - 2));
		int k = 1;
		while (k < m) {
			k *= 2;
			x = (x * (2 - modxk(k) * x)).modxk(k);
		}
		return x.modxk(m);
	}
	Poly log(int m) const { return (deriv() * inv(m)).integr().modxk(m); }
	Poly exp(int m) const {
		Poly x(1);
		int k = 1;
		while (k < m) {
			k *= 2;
			x = (x * (1 - x.log(k) + modxk(k))).modxk(k);
		}
		return x.modxk(m);
	}
	Poly mulT(Poly b) const {  // 卷积
		if (b.size() == 0)
			return Poly();
		int n = b.size();
		std::reverse(b.a.begin(), b.a.end());
		return ((*this) * b).divxk(n - 1);
	}
	std::vector<int> eval(std::vector<int> x) const {  // 求值
		if (size() == 0)
			return std::vector<int>(x.size(), 0);
		const int n = std::max(int(x.size()), size());
		std::vector<Poly> q(4 * n);
		std::vector<int> ans(x.size());
		x.resize(n);
		std::function<void(int, int, int)> build = [&](int p, int l, int r) {
			if (r - l == 1) {
				q[p] = std::vector<int>{ 1, (mod - x[l]) % mod };
			}
			else {
				int m = (l + r) / 2;
				build(2 * p, l, m);
				build(2 * p + 1, m, r);
				q[p] = q[2 * p] * q[2 * p + 1];
			}
		};
		build(1, 0, n);
		std::function<void(int, int, int, const Poly&)> work = [&](int p, int l, int r, const Poly& num) {
			if (r - l == 1) {
				if (l < int(ans.size()))
					ans[l] = num[0];
			}
			else {
				int m = (l + r) / 2;
				work(2 * p, l, m, num.mulT(q[2 * p + 1]).modxk(m - l));
				work(2 * p + 1, m, r, num.mulT(q[2 * p]).modxk(r - m));
			}
		};
		work(1, 0, n, mulT(q[1].inv(n)));
		return ans;
	}
};

int main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	cout.tie(nullptr);
	string s1, s2;
	cin >> s1 >> s2;
	reverse(s2.begin(), s2.end());
	int n = s1.length();
	int m = s2.length();
	vector<char> st;
	vector<int> cnt(n - m + 1, 0);
	for (auto c : s1)
		st.push_back(c);
	for (auto c : s2)
		st.push_back(c);
	sort(st.begin(), st.end());
	st.erase(unique(st.begin(), st.end()), st.end());

	auto sol = [&](string s1, string s2, char ch) {
		vector<int> va(n, 0);
		vector<int> vb(m, 0);
		for (int i = 0; i < n; i++) {
			if (s1[i] == ch) {
				va[i] = 1;
			}
		}
		for (int i = 0; i < m; i++) {
			if (s2[i] == ch) {
				vb[i] = 1;
			}
		}
		Poly pa(va);
		Poly pb(vb);
		pa *= pb;
		for (int i = m - 1; i <= n - 1; i++)
			cnt[i - m + 1] += pa[i];
	};

	for (auto c : st)
		sol(s1, s2, c);

	vector<int> res;
	for (int i = 0; i <= n - m; i++) {
		if (cnt[i] == m) {
			res.push_back(i + 1);
		}
	}
	cout << (int)res.size() << '\n';
	for (int i = 0; i < res.size(); i++)
		cout << res[i] << " \n"[i == (int)res.size() - 1];
	return 0;
}

题意:给定两个只由 \(A,C,G,T\) 构成的字符串 \(S,T\) 和一个门限值 \(k\) ,询问字符串 \(T\)\(S\) 中出现的次数(即匹配次数)。我们定义两个字符串在位置 \(j\) 是匹配的当且仅当,对于 \(T\) 中的任意一个字符 \(T_i\) 都至少有一个字符 \(S_k\) 满足 \(S_k=T_i\),其中 \(j+i-l\leq k\leq j+i+k\)

分析:这题由于字符集较小(大小为 \(4\) ),因此直接考虑使用上方给出的方法,枚举字符集中的元素求解。但是本题引入了一个门限值的概念,基础的字符串匹配问题实际上询问的是本题中 \(k=0\) 时的特例,那么对于这个拓展问题我们如何解决?

本题中的门限值实际上就是将一个字符的可匹配范围向两端均延申了 \(k\) 个单位,于是我们自然可以想到:只需要将主串中的待匹配字符均向两端延申 \(k\) 个单位,然后再用模式串进行匹配即可。以本题样例为例,假设我们正在枚举字符 \(A\)

本题的两个字符串 AGCAATTCATACAT 先预处理成 \(01\) 多项式:10011000101010

然后将主串中的待匹配字符向外延申 \(k(k=1)\) 个单位:1111110111 ,然后与 1010 进行卷积即可。

AC代码https://codeforces.com/contest/528/submission/94462497

与此题类似的还有 Gym101667H Rock Paper Scissors。

CF827E Rusty String

题意:给定一个由 \(3\) 种字符 \(V,K,?\) 构成的字符串 \(s\),其中 \(?\) 可以表示 \(V,K\) 中的任意一种,询问该字符串所有可能的循环节长度。

分析:我们考虑求出所有不符合要求的循环节长度,如果一个长度为 \(d\) 的循环节不满足题意,那么必然存在 \(s_i\neq s_{i+d}\) ,因此构造一个多项式 \(C(x)\) ,其系数 \(c_j = \sum_{i=0}^{n-1}[s_i\neq s_{i+j}]\) ,如果 \(c_j>0\) 那就说明长度为 \(j\) 的循环节不满足题意。

由于字符集较小,我们仍然考虑使用 \(01\) 多项式的方法解决本题,需要注意的一点是本题中的 \(?\) 可以近似地看作通配符,因此直接置为零。

\[\begin{aligned} c_j&=\sum_{i=0}^{n-1}[s_i\neq s_{i+j}] \\ &= \sum_{i=0}^{n-1}s_is'_{n-1-i-j}\\ \end{aligned} \]

上式中,\(s_i\) 表示字符串 \(s\) 中的 \(V\) 全部置 \(1\) 的多项式系数,\(s'_i\) 表示字符串 \(s\) 翻转后将 \(K\) 全部置 \(1\) 的多项式系数。

然后你就会写出一个样例都过不去的代码,这是因为本题中的 \(?\) 并不等价于通配符,例如样例中长度为 \(2\) 的循环节会出现以下情况:

V ? ? V K

N N V ? ? V K (N表示空位)

如果 \(?\) 为通配符,那么本题确实可以成功匹配,但是本题中 \(?\) 并不能同时为 \(V\)\(K\) 。本例中出现了 \(s_0=s_2=s_4\) 的情况,因此存在矛盾。

考虑一下为什么会出现这个问题,这是因为我们只判断了 \(s_i\)\(s_{i+d}\) 的关系,但是这个关系并不具有传递性,\(s_i=s_{i+d}=s_{i+2d}=\cdots = s_{i+kd}\) 不能递推得到 \(d\) 是合法的循环节长度。利用这个性质来完善我们的方法:对于一个长度为 \(d\) 的循环节,如果它是合法的,那么所有长度为 \(kd(k>1)\) 的循环节也必须合法。因此我们可以利用埃氏筛对每个位置进行可行性检测。

这里的实现还需要注意一个 \(wa\) 点,长度为 \(d\) 的循环节不仅对应卷积后的 \(c_{n-1-d}\) ,同时也对应了 \(c_{n-1+d}\) 。举个例子进行说明:

设字符串 \(s=K?V\) ,那么经过处理后得到:\(a=0,0,1;\ b=0,0,1\) ;卷积后的结果为:\(c=0,0,0,0,1\) ,如果你在判断 \(d=1\) 时漏判了 \(c_{n-1+d}\) 就会导致答案错误,因为 \(c_0\)\(c_4\) 实际上都是 \(d=2\) 的卷积系数。

AC代码https://codeforces.com/contest/827/submission/94561373

posted @ 2020-10-03 18:41  st1vdy  阅读(775)  评论(0编辑  收藏  举报