bzoj 4503

一道FFT好题,也是FFT作用于字符串匹配的一个典型题

首先我们解释为什么FFT可以用于字符串匹配:

我们发现:设两个字符串为$S,T$,且两个串长度均为$l$,那么两个字符串相等的充要条件是:

 $ S_i==T_i (i\in [1,l])$

那么我们变下形,也就是:

$S_i-T-i==0 (i\in[1,l])$

这也就是:

$\sum_{i=1}^{l}S_i-T_i==0$

但是注意:这样变形以后并不是充要条件,因为可能出现相反数之和为0的情况,所以我们保证值非负的情况下和为0,所以两个字符串相等的充要条件是:

$\sum_{i=1}^{l}(S_i-T_i)^2==0$

令下标从0开始,则:

$\sum_{i=0}^{l-1}(S_i-T_i)^2==0$

将平方式展开,得:

$\sum_{i=0}^{l-1}(S_i)^2+(T_i)^2-2*S_i*T_i==0$

可以发现,平方项之和是互不影响的,所以我们只研究这个表达式:

$\sum_{i=0}^{l-1}2*S_i*T_i$

按照常规套路,我们翻转T串,于是原式变为:

$\sum_{i=0}^{l-1}2*S_i*T_{l-i-1}$

我们将$l$减一,可将原式变为:

$\sum_{i=0}^{l}2*S_i*T_{l-i}$

这不就是一个非常典型的卷积了嘛

 那么,如果我们令S为大串,使T在S中匹配,那么我们同样只需将T翻转后与S整个做卷积,然后提取对应位置上的值进行比较就能得知是否匹配了

所以这是直接利用FFT进行字符串匹配的代码

#include <cstdio>
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#include <queue>
#include <stack>
#define ll long long
using namespace std;
const double pi=acos(-1.0);
struct cp
{
    double x,y;
};
cp operator + (cp a,cp b)
{
    return (cp){a.x+b.x,a.y+b.y};
}
cp operator - (cp a,cp b)
{
    return (cp){a.x-b.x,a.y-b.y};
}
cp operator * (cp a,cp b)
{
    return (cp){a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};
}
char sa[1000005];
char sb[1000005];
int to[(1<<21)+5];
int n,m,lim=1,l;
void FFT(cp *a,int len,int k)
{
    for(int i=0;i<len;i++)if(i<to[i])swap(a[i],a[to[i]]);
    for(int i=1;i<len;i<<=1)
    {
        cp w0=(cp){cos(pi/i),k*sin(pi/i)};
        for(int j=0;j<len;j+=(i<<1))
        {
            cp w=(cp){1,0};
            for(int o=0;o<i;o++,w=w*w0)
            {
                cp w1=a[j+o],w2=a[j+o+i]*w;
                a[j+o]=w1+w2;
                a[j+o+i]=w1-w2;
            }
        }
    }
}
ll s1[1000005],s2[1000005],f[1000005];
cp a[(1<<21)+5],b[(1<<21)+5],c[(1<<21)+5];
int main()
{
    scanf("%s%s",sa,sb);
    n=strlen(sa),m=strlen(sb);
    s1[0]=1ll*(sa[0]-'A'+1)*(sa[0]-'A'+1);
    for(int i=1;i<n;i++)s1[i]=s1[i-1]+1ll*(sa[i]-'A'+1)*(sa[i]-'A'+1);
    s2[0]=1ll*(sb[0]-'A'+1)*1ll*(sb[0]-'A'+1);
    for(int i=1;i<m;i++)s2[i]=s2[i-1]+1ll*(sb[i]-'A'+1)*(sb[i]-'A'+1);
    for(int i=0;i<n;i++)a[i].x=(double)(sa[i]-'A'+1);
    for(int i=0;i<m;i++)b[i].x=(double)(sb[m-i-1]-'A'+1);
    while(lim<=2*n)lim<<=1,l++;
    for(int i=1;i<lim;i++)to[i]=((to[i>>1]>>1)|((i&1)<<(l-1)));
    FFT(a,lim,1),FFT(b,lim,1);
    for(int i=0;i<lim;i++)c[i]=a[i]*b[i];
    FFT(c,lim,-1);
    for(int i=m-1;i<n;i++)f[i]=(ll)(c[i].x/lim+0.5);
    int ans=0;
    if(s1[m-1]+s2[m-1]-2*f[m-1]==0)ans++;
    for(int i=1;i+m-1<n;i++)if(s1[i+m-1]-s1[i-1]+s2[m-1]-2*f[i+m-1]==0)ans++;
    printf("%d\n",ans);
    return 0;
}

有了这些基础知识,做这道题就容易多了

这道题的问题在于,有一些通配符,所以直接比较是不正确的。

那么我们考虑改进上述表达式:

我们知道,上述表达式成立的核心在于将两个字符相等转化为一个表达式的值为0

那么我们只需让有通配符时上述表达式值恒为0即可

怎么办?

我们在将字符串转化为数值时,将通配符设成0,然后将表达式变成下面这个样子:

$\sum_{i=0}^{l}T_i*(S_i-T_i)^2==0$

也就是:

$\sum_{i=0}^{l}(T_i)^3+T_i*(S_i)^2-2*S_i*(T_i)^2==0$

把后两项按上面的方法转化为卷积计算即可

#include <cstdio>
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#include <queue>
#include <stack>
#define ll long long
using namespace std;
const double pi=acos(-1.0);
struct cp
{
    double x,y;
    friend cp operator + (cp a,cp b)
    {
        return (cp){a.x+b.x,a.y+b.y};
    }
    friend cp operator - (cp a,cp b)
    {
        return (cp){a.x-b.x,a.y-b.y};
    }
    friend cp operator * (cp a,cp b)
    {
        return (cp){a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};
    }
};
char ssa[1000005],ssb[1000005];
int sa[1000005],sb[1000005];
int to[(1<<21)+5];
int n,m,lim=1,l;
void FFT(cp *a,int len,int k)
{
    for(int i=0;i<len;i++)if(i<to[i])swap(a[i],a[to[i]]);
    for(int i=1;i<len;i<<=1)
    {
        cp w0=(cp){cos(pi/i),k*sin(pi/i)};
        for(int j=0;j<len;j+=(i<<1))
        {
            cp w=(cp){1,0};
            for(int o=0;o<i;o++,w=w*w0)
            {
                cp w1=a[j+o],w2=a[j+o+i]*w;
                a[j+o]=w1+w2,a[j+o+i]=w1-w2;
            }
        }
    }
}
cp a[(1<<21)+5],b[(1<<21)+5],c[(1<<21)+5];
ll f[2000005],g[2000005];
ll s3;
int ret[2000005];
int main()
{
    scanf("%s%s",ssa,ssb);
    n=strlen(ssa),m=strlen(ssb);    
    for(int i=0;i<n;i++)sa[i]=(ssa[i]=='?')?0:ssa[i]-'a'+1;
    for(int i=0;i<m;i++)sb[i]=(ssb[i]=='?')?0:ssb[i]-'a'+1;
    for(int i=0;i<m;i++)s3+=1ll*sb[i]*sb[i]*sb[i];
    for(int i=0;i<n;i++)a[i].x=(double)sa[i]*sa[i];
    for(int i=0;i<m;i++)b[i].x=(double)sb[m-i-1];
    while(lim<=2*n)lim<<=1,l++;
    for(int i=1;i<lim;i++)to[i]=((to[i>>1]>>1)|((i&1)<<(l-1)));
    FFT(a,lim,1),FFT(b,lim,1);
    for(int i=0;i<lim;i++)c[i]=a[i]*b[i];
    FFT(c,lim,-1);
    for(int i=0;i<lim;i++)f[i]=(ll)(c[i].x/lim+0.5);
    memset(a,0,sizeof(a)),memset(b,0,sizeof(b)),memset(c,0,sizeof(c));
    for(int i=0;i<n;i++)a[i].x=(double)sa[i];
    for(int i=0;i<m;i++)b[i].x=(double)sb[m-i-1]*sb[m-i-1];
    FFT(a,lim,1),FFT(b,lim,1);
    for(int i=0;i<lim;i++)c[i]=a[i]*b[i];
    FFT(c,lim,-1);
    for(int i=0;i<lim;i++)g[i]=2*(ll)(c[i].x/lim+0.5);
    int ans=0;
    for(int i=0;i+m-1<n;i++)if(f[i+m-1]+s3-g[i+m-1]==0)ans++,ret[ans]=i;
    printf("%d\n",ans);
    for(int i=1;i<=ans;i++)printf("%d\n",ret[i]);
    return 0;
}

 

posted @ 2019-05-04 20:18  lleozhang  Views(204)  Comments(0Edit  收藏  举报
levels of contents