BZOJ4259 残缺的字符串 【fft】

题目

很久很久以前,在你刚刚学习字符串匹配的时候,有两个仅包含小写字母的字符串A和B,其中A串长度为m,B串长度为n。可当你现在再次碰到这两个串时,这两个串已经老化了,每个串都有不同程度的残缺。
你想对这两个串重新进行匹配,其中A为模板串,那么现在问题来了,请回答,对于B的每一个位置i,从这个位置开始连续m个字符形成的子串是否可能与A串完全匹配?

输入格式

第一行包含两个正整数m,n(1<=m<=n<=300000),分别表示A串和B串的长度。
第二行为一个长度为m的字符串A。
第三行为一个长度为n的字符串B。
两个串均仅由小写字母和号组成,其中号表示相应位置已经残缺。

输出格式

第一行包含一个整数k,表示B串中可以完全匹配A串的位置个数。
若k>0,则第二行输出k个正整数,从小到大依次输出每个可以匹配的开头位置(下标从1开始)。

输入样例

3 7

a*b

aebr*ob

输出样例

2

1 5

题解

似乎字符串相关数据结构很难解决这个问题,我们考虑量化匹配关系

我们尝试构造两个多项式,
如果我们能让匹配的位置通过某种运算为0,似乎就能匹配了

如何构造?
如果两个位置匹配,要么这两个位置相等,要么为\(*\)

相等为0,容易想到减法
存在\(*\)就为0,容易想到*就表示0,然后乘起来

就成了这个样子:

\[\sum\limits_{i=1}^{len} (A_i - B_i)^2 * A_i * B_i \]

将其展开成为三个乘积相加的形式
我们将A串翻转,就可以三次fft计算了

#include<iostream>
#include<cstdio>
#include<cmath>
#include<complex>
#include<cstring>
#include<algorithm>
#define eps 1e-9
#define LL long long int
#define Redge(u) for (int k = h[u],to; k; k = ed[k].nxt)
#define REP(i,n) for (int i = 1; i <= (n); i++)
#define BUG(s,n) for (int i = 1; i <= (n); i++) cout<<s[i]<<' '; puts("");
using namespace std;
const int maxn = 1200005,maxm = 100005,INF = 1000000000;
inline int read(){
	int out = 0,flag = 1; char c = getchar();
	while (c < 48 || c > 57){if (c == '-') flag = -1; c = getchar();}
	while (c >= 48 && c <= 57){out = (out << 3) + (out << 1) + c - 48; c = getchar();}
	return out * flag;
}
const double pi = acos(-1);
struct E{
	double r,i;
	E(){}
	E(double a,double b):r(a),i(b){}
	E operator =(const int& b){
		r = b; i = 0;
		return *this;
	}
};
inline E operator +(const E& a,const E& b){
	return E(a.r + b.r,a.i + b.i);
}
inline E operator -(const E& a,const E& b){
	return E(a.r - b.r,a.i - b.i);
}
inline E operator *(const E& a,const E& b){
	return E(a.r * b.r - a.i * b.i,a.r * b.i + b.r * a.i);
}
inline E operator *=(E& a,const E& b){
	return (a = a * b);
}
inline E operator /(E& a,const double& b){
	return E(a.r / b,a.i / b);
}
inline E operator /=(E& a,const double& b){
	return (a = a / b);
}
int n,m,L,R[maxn];
E A[maxn],B[maxn];
void fft(E* a,int f){
	for (int i = 0; i < n; i++) if (i < R[i]) swap(a[i],a[R[i]]);
	for (int i = 1; i < n; i <<= 1){
		E wn(cos(pi / i),f * sin(pi / i));
		for (int j = 0; j < n; j += (i << 1)){
			E w(1,0);
			for (int k = 0; k < i; k++,w *= wn){
				E x = a[j + k],y = w * a[j + k + i];
				a[j + k] = x + y; a[j + k + i] = x - y;
			}
		}
	}
	if (f == -1) for (int i = 0; i < n; i++) a[i] /= n;
}
char P[maxn],T[maxn];
int len,lm,ans[maxn],ansi;
double C[maxn];
void solve(){
	int tmp;
	for (int i = 0; i < n; i++) A[i] = B[i] = 0;
	for (int i = 0; i < len; i++){
		if (P[i] != '*'){
			tmp = (P[i] - 'a' + 1);
			A[i] = tmp * tmp * tmp;
		}
	}
	for (int i = 0; i < lm; i++){
		if (T[i] != '*'){
			tmp = (T[i] - 'a' + 1);
			B[i] = tmp;
		}
	}
	fft(A,1); fft(B,1);
	for (int i = 0; i < n; i++) A[i] *= B[i];
	fft(A,-1);
	for (int i = 0; i < n; i++) C[i] = floor(A[i].r + 0.5);
	
	for (int i = 0; i < n; i++) A[i] = B[i] = 0;
	for (int i = 0; i < len; i++){
		if (P[i] != '*'){
			tmp = (P[i] - 'a' + 1);
			A[i] = tmp * tmp;
		}
	}
	for (int i = 0; i < lm; i++){
		if (T[i] != '*'){
			tmp = (T[i] - 'a' + 1);
			B[i] = tmp * tmp;
		}
	}
	fft(A,1); fft(B,1);
	for (int i = 0; i < n; i++) A[i] *= B[i];
	fft(A,-1);
	for (int i = 0; i < n; i++) C[i] -= 2 * floor(A[i].r + 0.5);
	
	for (int i = 0; i < n; i++) A[i] = B[i] = 0;
	for (int i = 0; i < len; i++){
		if (P[i] != '*'){
			tmp = (P[i] - 'a' + 1);
			A[i] = tmp;
		}
	}
	for (int i = 0; i < lm; i++){
		if (T[i] != '*'){
			tmp = (T[i] - 'a' + 1);
			B[i] = tmp * tmp * tmp;
		}
	}
	fft(A,1); fft(B,1);
	for (int i = 0; i < n; i++) A[i] *= B[i];
	fft(A,-1);
	for (int i = 0; i < n; i++) C[i] += floor(A[i].r + 0.5);
}
int main(){
	len = read(); lm = read();
	scanf("%s",P);
	scanf("%s",T);
	for (int i = 0; i < (len >> 1); i++) swap(P[i],P[len - i - 1]);
	m = len + lm - 2;
	for (n = 1; n <= m; n <<= 1) L++;
	for (int i = 0; i < n; i++) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (L - 1));
	solve();
	for (int i = 0; i <= m; i++){
		if (fabs(C[i]) < eps && i - len + 2 > 0 && i - len + 2 <= lm - len + 1){
			ans[++ansi] = i - len + 2;
		}
	}
	printf("%d\n",ansi);
	for (int i = 1; i <= ansi; i++) printf("%d ",ans[i]);
	return 0;
}

posted @ 2018-04-09 14:08  Mychael  阅读(238)  评论(0编辑  收藏  举报