CF613E Puzzle Lover 思考--zhengjun
题很简单,一遍写对却比较困难。
犯的错误:
-
预处理 \({base}^i\) 时应该要处理到 \(\max\{n,m\}\);
-
去重的时候(reduce 函数)特判 \(m=1,2\)。
代码
#include<bits/stdc++.h>
using namespace std;
using ll=long long;
const int N=2e3+10,mod=1e9+7,base=23333;
int n,m;
char a[2][N],b[N];
int pw[N],f[2][N],g[2][N],pre[N],suf[N];
int Hash1(int t,int l,int r){
return (f[t][r]+1ll*(mod-pw[r-l+1])*f[t][l-1])%mod;
}
int Hash2(int t,int l,int r){
return (g[t][l]+1ll*(mod-pw[r-l+1])*g[t][r+1])%mod;
}
int merge(int x,int y,int len){
return (1ll*x*pw[len]+y)%mod;
}
int ans,dp[N][N][2];
void solve(){
memset(dp,0,sizeof dp);
for(int i=1;i<=m;i++){
pre[i]=(1ll*pre[i-1]*base+b[i])%mod;
}
for(int i=m;i>=1;i--){
suf[i]=(suf[i+1]+1ll*pw[m-i]*b[i])%mod;
}
for(int i=0;i<=n;i++)dp[i][0][0]=dp[i][0][1]=1;
for(int i=1;i<=n;i++){
for(int j=1;j<=m;j++){
for(int t=0;t<2;t++){
if(b[j]!=a[t][i])continue;
(dp[i][j][t]+=dp[i-1][j-1][t])%=mod;
if(j>1&&b[j-1]==a[!t][i])(dp[i][j][t]+=dp[i-1][j-2][!t])%=mod;
}
}
for(int t=0;t<2;t++){
for(int len=2;len<=i;len++){
int l=i-len+1,r=i,j=len*2;
if(len*2>m)continue;
if(merge(Hash2(!t,l,r),Hash1(t,l,r),len)==pre[j])
++dp[i][j][t]%=mod;
}
}
}
for(int i=1;i<=n;i++){
for(int t=0;t<2;t++){
(ans+=dp[i][m][t])%=mod;
for(int len=2;len<=n-i+1;len++){
int l=i,r=i+len-1;
if(len*2>m)continue;
if(merge(Hash1(t,l,r),Hash2(!t,l,r),len)==suf[m-len*2+1])
(ans+=dp[i-1][m-len*2][t])%=mod;
}
}
}
}
void reduce(){
if(m>1){
if(m&1)return;
for(int i=1;i<=n;i++){
for(int t=0;t<2;t++){
int len=m/2,l=i,r=i+len-1;
if(r>n)continue;
if(merge(Hash1(t,l,r),Hash2(!t,l,r),len)==pre[m])ans--;
}
}
if(m==2)return;
for(int i=1;i<=n;i++){
for(int t=0;t<2;t++){
int len=m/2,l=i-len+1,r=i;
if(l<1)continue;
if(merge(Hash2(!t,l,r),Hash1(t,l,r),len)==pre[m])ans--;
}
}
}else{
for(int i=1;i<=n;i++){
for(int t=0;t<2;t++){
ans-=b[m]==a[t][i];
}
}
}
(ans+=mod)%=mod;
}
int main(){
freopen(".in","r",stdin);
//freopen(".out","w",stdout);
scanf("%s%s%s",a[0]+1,a[1]+1,b+1);
n=strlen(a[0]+1),m=strlen(b+1);
for(int i=pw[0]=1;i<=max(n,m);i++)pw[i]=1ll*pw[i-1]*base%mod;
for(int t=0;t<2;t++){
for(int i=1;i<=n;i++){
f[t][i]=(1ll*f[t][i-1]*base+a[t][i])%mod;
}
for(int i=n;i>=1;i--){
g[t][i]=(1ll*g[t][i+1]*base+a[t][i])%mod;
}
}
solve();
reverse(b+1,b+1+m);
solve();
reduce();
cout<<ans;
return 0;
}