【ABC196F】Substring 2(多项式乘法)

我竟然能在 AT 当场做出 F 题!

哦,是 ABC 啊,没事了。


以下的字符串均从 \(1\) 开始记位。以下设 \(S_i\) 表示字符串 \(S\) 的第 \(i\) 位,\(S(l,r)\) 表示字符串 \(S\) 的第 \(l\) 位到第 \(r\) 位组成的子串,也可以表示字符串 \(S\) 的第 \(l\) 位到第 \(r\) 位组成的序列

\(f_{i,j}\) 表示 \(S\) 串从位置 \(i\) 往后 \(j\) 个字符所组成的串 和 \(T\) 串从位置 \(1\) 往后 \(j\) 个字符所组成的串 有多少位不相等,即 \(S(i,i+j-1)\)\(T(1,j)\) 的相差字符个数。

\(n=|S|,m=|T|\),那么我们要求的即为:

\[\min\limits_{i=1}^{n-m+1} f_{i,m} \]

显然有转移:

\[\begin{aligned} &f_{i,0}=0\\ &f_{i,j}=f_{i,j-1}+\big[S_{i+j-1}\neq T_{j}\big] \end{aligned} \]

以下定义两个长度相等的序列 \(A,B\) 相加表示 \(A\)\(B\) 各位相加后组成的序列。即如果 \(C=A+B\),那么 \(C_i=A_i+B_i\)

两个序列相减的定义类似。

一个常数乘一个序列的定义类似。

设序列 \(F_j\) 表示长度为 \(n-m+1\) 的序列 \(f_{1,j},f_{2,j},\cdots,f_{n-m+1,j}\),那么我们要求的就是序列 \(F_m\) 每一位上的值的最小值。

设序列 \(G_j\) 表示长度为 \(n-m+1\) 的序列 \([S_{1+j-1}\neq T_j],[S_{2+j-1}\neq T_j],\cdots,[S_{(n-m+1)+j-1}\neq T_j]\)

那么根据之前得到的状态转移方程,有:

\[F_{j}=F_{j-1}+G_{j} \]

考虑转化 \(G_j\)

首先,在 \(a,b\) 都是 \(0\)\(1\) 时,有 \([a\neq b]=a \oplus b\)。(\(\oplus\) 指异或运算)

那么 \(G_j\) 就可以变成序列 \(S_{1+j-1}\oplus T_j,S_{2+j-1}\oplus T_j,\cdots,S_{(n-m+1)+j-1}\oplus T_j\)

再进一步,当 \(T_j=0\) 时,\(G_j\) 就是 \(S(j,j+n-m)\);当 \(T_j=1\) 时,\(G_j\) 就是 \(S(j,j+n-m)\) 每一位都取反后形成的序列。

更进一步,注意到 \(S(j,j+n-m)\) 每一位都取反后形成的序列其实就是序列 \((11...1)_{n-m+1}-S(j,j+n-m)\)。(这里的 \((11...1)_{n-m+1}\) 指的是一个长度为 \(n-m+1\) 的全为 \(1\) 的序列)。

那么有:

\[F_j=F_{j-1}+ \begin{cases} S(j,j+n-m)&\operatorname{if }T_j=0\\ (11...1)_{n-m+1}-S(j,j+n-m)&\operatorname{if }T_j=1 \end{cases} \]

那么:

\[F_m=\sum_{j=1}^m\begin{cases} S(j,j+n-m)&\operatorname{if }T_j=0\\ (11...1)_{n-m+1}-S(j,j+n-m)&\operatorname{if }T_j=1 \end{cases} \]

我们可以把 \((11...1)_{n-m+1}\) 抽出来最后算:假设 \(T\) 里面有 \(tot\) 位是 \(1\),那么:

\[F_m=tot\times (11...1)_{n-m+1}+\sum_{j=1}^m(-1)^{T_j}S(j,j+n-m) \]

现在我们只需要计算 \(\sum\limits_{j=1}^m(-1)^{T_j}S(j,j+n-m)\) 即可。

这个式子可以理解成为:有一个长度为 \(n-m+1\) 的滑动窗口在 \(S\) 上从左往右移动,每移到某一个位置就把当前记录的答案序列加上/减去滑动窗口内框住的序列。

由于这个滑动窗口是一直在 \(S\) 上的,所以我们可以用多项式来优化这个操作:

\(A(x)=S_1x+S_2x^2+\cdots+S_nx^n\)

我们考虑求出一个多项式 \(C(x)\),然后用 \(C(x)\) 的第 \(m\) 位到第 \(n\) 位(即 \(x^m\)\(x^n\) 的系数)这 \(n-m+1\) 位来表示 \(\sum\limits_{j=1}^m(-1)^{T_j}S(j,j+n-m)\) 计算后得到的长度为 \(n-m+1\) 的序列。(注意,\(C(x)\) 的其他位算出来是啥我们不用管,只用保证这 \(n-m+1\) 位是对的就行)

容易得到:

\[\begin{aligned} C(x)=&\sum_{j=1}^m (-1)^{T_j} \left(S_jx^j+\cdots+S_{j+n-m}x^{j+n-m}\right)x^{m-j}\\ =&\sum_{j=1}^m (-1)^{T_j} A(x) x^{m-j}\\ =&A(x)\times \left(\sum_{j=1}^m (-1)^{T_j} x^{m-j}\right) \end{aligned} \]

乘号两边都是多项式,用 FFT 或 NTT 优化多项式乘法即可。

代码如下:

#include<bits/stdc++.h>

#define LN 22
#define N 1000010
#define INF 0x7fffffff

using namespace std;

namespace modular
{
	const int mod=998244353;
	inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
	inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
	inline int mul(int x,int y){return 1ll*x*y%mod;}
}using namespace modular;

inline int read()
{
	int x=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9')
	{
		if(ch=='-') f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		x=(x<<1)+(x<<3)+(ch^'0');
		ch=getchar();
	}
	return x*f;
}

inline int poww(int a,int b)
{
    int ans=1;
    while(b)
    {
        if(b&1) ans=mul(ans,a);
        a=mul(a,a);
        b>>=1;
    }
    return ans;
}

int n,m;
int a[N<<3],b[N<<3];
int rev[N<<3],w[LN][N<<3][2];
char s[N],t[N];

void init(int limit)
{
    for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
    {
        int len=mid<<1;
        int gn=poww(3,(mod-1)/len);
        int ign=poww(gn,mod-2);
        int g=1,ig=1;
        for(int j=0;j<mid;g=mul(g,gn),ig=mul(ig,ign),j++)
            w[bit][j][0]=g,w[bit][j][1]=ig;
    }
}
 
void NTT(int *a,int limit,int opt)
{
    opt=(opt<0);
    for(int i=0;i<limit;i++)
        rev[i]=(rev[i>>1]>>1)|((i&1)*(limit>>1));
    for(int i=0;i<limit;i++)
        if(i<rev[i]) swap(a[i],a[rev[i]]);
    for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
    {
        for(int i=0,len=mid<<1;i<limit;i+=len)
        {
            for(int j=0;j<mid;j++)
            {
                int x=a[i+j],y=mul(w[bit][j][opt],a[i+mid+j]);
                a[i+j]=add(x,y),a[i+mid+j]=dec(x,y);
            }
        }
    }
    if(opt)
    {
        int tmp=poww(limit,mod-2);
        for(int i=0;i<limit;i++)
            a[i]=mul(a[i],tmp);
    }
}

int main()
{
	scanf("%s%s",s+1,t+1);
	n=strlen(s+1),m=strlen(t+1);
	for(int i=1;i<=n;i++) a[i]=s[i]-'0';
	int tot=0;
	for(int i=1,j=m-1;i<=m;i++,j--)
	{
		if(t[i]=='0') b[j]=1;
		else b[j]=-1,tot++;
	}
	int limit=1;
	while(limit<=(n<<1)) limit<<=1;
	init(limit);
	NTT(a,limit,1),NTT(b,limit,1);
	for(int i=0;i<limit;i++) a[i]=mul(a[i],b[i]);
	NTT(a,limit,-1);
	int ans=INF;
	for(int i=m;i<=n;i++)
		ans=min(ans,add(a[i],tot));
	printf("%d\n",ans);
	return 0;
}
/*
0101010
1010101
*/
posted @ 2022-10-28 18:24  ez_lcw  阅读(25)  评论(0)    收藏  举报