【多项式】[LGP4173] 残缺的字符串

【多项式】[LGP4173] 残缺的字符串

题意

给定两个有通配符的字符串,跑字符串匹配。

思路

肯定不能用 kmp (不要问为什么)。

\(A_{1:m}\) 为模式串,\(B_{1:n}\) 为文本串。

定义一个函数 \(d(x,y)\geq 0\),且 \(d(x,y)=0\) 当且仅当 \(A_x=B_y\),即 \(A\) 的第 \(x\) 位和 \(B\) 的第 \(y\) 位匹配。

定义关于 \(x\)函数 \(P(x)\geq 0\) ,且 \(P(x)=0\) 当且仅当 \(A_{1:m}=B_{x:x+m-1}\),即 \(A\)\(B\)\(B\) 的第 \(x\) 位匹配。

那么应该 \(\forall i \in[1,m]\cap\Z, d(i,x+i-1)=0\) ,这样看上去没有优化前途。

既然他的值域都是正数,那么上面条件可以等价于 \(\displaystyle \sum_{i=1}^m\bigg[d(i,x+i-1)\bigg]=0\)

可以定义 \(d(x,y)=A_x\times B_y\times(A_x-B_y)^2\) ,当字符为通配符时其值为 \(0\) ,否则为其 ASCII 码的值。

那么 函数

\[\begin{aligned} P(x)&=\sum_{i=1}^mA_iB_{x+i-1}(A_i-B_{x+i-1})^2\\ &=\sum_{i=1}^m\left(A_i^3B_{x+i-1}+A_iB_{x+i-1}^3-2A_i^2B_{x+i-1}^2\right)\\ &=\sum_{i=1}^mA_i^3B_{x+i-1}+\sum_{i=1}^mA_iB_{x+i-1}^3-2\left(\sum_{i=1}^mA_i^2B_{x+i-1}^2\right) \end{aligned} \]

还是不好弄,考虑翻转字符串 \(A\)\(S\),那么 \(A_i=S_{m-i+1}\)

代入 函数

\[P(x)=\sum_{i=1}^mS_{m-i+1}^3B_{x+i-1}+\sum_{i=1}^mS_{m-i+1}B_{x+i-1}^3-2\left(\sum_{i=1}^mS_{m-i+1}^2B_{x+i-1}^2\right) \]

可以发现每一个加项都是关于 \(S_{m-i+1}\)\(B_{x+i-1}\) 的单项式,而他们下标加起来是一个常数 \(m+x\)

因此,关于 \(x\)函数 \(P(x)\) 就可以化成一个多项式计算。\(P(x)=\displaystyle\sum_{i+j=m+x}\bigg(S_i^3B_j+S_iB_j^3-2S_i^2B_j^2\bigg)\)

求和可以拆开,三个式子加起来。三个式子处理方式相同。比如第一个式子 \(\displaystyle\sum_{i+j=m+x}S_i^3B_j\) ,其实就是取多项式 \(S^{(3)}B\) 的第 \(m+x\) 项系数,其中 \(S^{(3)}\) 表示 \(S\) 中各项系数都变成 \(3\) 次方后的多项式。

所以暴力做一些多项式乘法就可以得出关于 \(x\)函数 \(P(x)\) 的值。

看上去要做 \(3\) 次多项式乘法,常数爆炸。

专业来说,这道题的思路就是构造 \(P(i)\) 的生成函数 \(F(x)=\displaystyle\sum_{i=1}^nP(i)x^i=S^{(3)}B+SB^{(3)}-2S^{(2)}B^{(2)}\)。算出后面多项式的系数就可以得到所有的 \(P(i)\)

点击查看代码
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <algorithm>
#include <cstring>

const int N=1e5+1;
int n,m;
char s[N],t[N];
int rplcment[N];
int bty[N<<1],mx; // bty[i]: the max beauty in the front i characters.
int ans[N<<1],mn; // ans[i]:the minimum of replacement when catch the beauty bty[i].

const int mod=998244353,PHI=998244352,G=3,invG=332748118;
inline int fastpow(long long a,int k) {
	int res=1;
	while(k) {
		if(k&1) res=a*res%mod;
		a=a*a%mod; k>>=1;
	}
	return res;
}
inline int plus(int x,int y) { x+=y-mod; return x+((x>>31)&mod); }
inline int minus(int x, int y) { x-=y; return x+((x>>31)&mod); }

typedef std::vector<int> poly;

inline void NTT(poly& a,const int limit,const int B[],const int I=1) {
	for(int i=1;i<limit;++i) if(i<B[i]) a[i]^=a[B[i]]^=a[i]^=a[B[i]];
	for(int slen=1;slen<limit;slen<<=1) {
		const int g = fastpow(I==1? G:invG,PHI/(slen<<1));
		for(int j=0;j<limit;j+=slen<<1) {
			long long rt=1;
			for(int opt=0;opt<slen;++opt) {
				const int x=a[j+opt],y=a[j+opt+slen]*rt%mod;
				a[j+opt]=plus(x,y); a[j+opt+slen]=minus(x,y);
				rt=rt*g%mod;
			}
		}
	}
}

inline poly operator*(poly a,poly b) {
	const int deg=a.size()+b.size()-2;
	int n,k=0;
	while((1<<++k)<=deg);
	static int *B=(int*)malloc(sizeof(int));
	if((n=1<<k)!=B[0]) {
		B=(int*)realloc(B,sizeof(int)*n);
		B[0]=0;
		for(int i=1;i<n;++i) B[i]=(B[i>>1]>>1)|((i&1)<<(k-1));
		B[0]=n;
	}
	a.resize(n+1); b.resize(n+1);
	NTT(a,n,B); NTT(b,n,B);
	for(int i=0;i<=n;++i) a[i]=1ll*a[i]*b[i]%mod;
	NTT(a,n,B,-1);
	a.resize(deg+1);
	const long long inv=fastpow(n,mod-2);
	for(int i=0;i<=deg;++i) a[i]=inv*a[i]%mod;
	return a;
}

inline int hsh(char s) {
	if(s=='?') return 0;
	return s;
}

inline void work() {
	poly a,b,c,d;
	a.resize(n); b.resize(m);
	for(int i=0;i<n;++i) a[i]=hsh(s[i]);
	for(int j=0;j<m;++j) b[j]=hsh(t[j]);
	std::reverse(b.begin(),b.end());
	c=a; d=b;
	for(int i=0;i<n;++i) c[i]=c[i]*c[i]*c[i];
	for(int j=0;j<m;++j) d[j]=d[j]*d[j]*d[j];
	c=b*c; d=a*d;
	for(int i=0;i<n;++i) a[i]=a[i]*a[i];
	for(int j=0;j<m;++j) b[j]=b[j]*b[j];
	a=a*b;
	memset(rplcment,0x80,sizeof(rplcment));
	for(int i=0;i<=n-m;++i) if(c[i+m-1]+d[i+m-1]-(a[i+m-1]<<1)==0) rplcment[i+m-1]=0;
	int tmp=0; for(int i=0;i<m;++i) tmp+=s[i]=='?';
	for(int i=m-1;i<n;++i) {
		if(rplcment[i]==0) rplcment[i]=tmp;
		tmp+=(s[i+1]=='?')-(s[i-m+1]=='?');
	}
	return;
}

int main() {
	scanf("%d%s%d",&n,s,&m);
	for(int i=0;i<m;++i) t[i]=(i&1? 'b':'a');
	work();
	int i=m-1;
	while(i<n+m) {
		if(rplcment[i]>=0) bty[i]=mx+1, ans[i]=mn+rplcment[i];
		else bty[i]=mx, ans[i]=mn;
		if(bty[++i-m]>mx) {
			mx=bty[i-m];
			mn=ans[i-m];
		}else if(bty[i-m]==mx) mn=std::min(mn,ans[i-m]);
	}
	printf("%d\n",mn);
	return 0;
}
posted @ 2022-10-08 16:43  IdanSuce  阅读(25)  评论(0编辑  收藏  举报