【BZOJ3160】万径人踪灭

题面

http://darkbzoj.tk/problem/3160

题解

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
// Loop-counter shorthand used throughout this file. It used to expand to
// `register int`, but the `register` storage class was removed in C++17
// (and compilers ignored the hint long before that), so it is a plain `int`.
#define ri int
// Base capacity for the global buffers (they are sized N, N<<1 or N<<3).
#define N 500050
// Modulus applied to the final answer.
#define mod 1000000007
using namespace std;

// Circle constant; FFT divides it by the butterfly half-width to get the
// angle of each stage's root of unity.
const double pi = std::acos(-1.0);
// Minimal complex number for the FFT: x is the real part, y the imaginary part.
// A and B are global buffers of N<<3 elements each.
struct fushu{double x,y;} A[N<<3],B[N<<3];

// The operators below build their result with standard brace initialization.
// The previous `(fushu){...}` form is a C99 compound literal, which is only a
// compiler extension in C++ and is rejected by strictly conforming compilers.

// Complex addition.
fushu operator + (fushu a,fushu b) {return fushu{a.x+b.x,a.y+b.y};}
// Complex subtraction.
fushu operator - (fushu a,fushu b) {return fushu{a.x-b.x,a.y-b.y};}
// Complex multiplication: (a.x + a.y*i) * (b.x + b.y*i).
fushu operator * (fushu a,fushu b) {return fushu{a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};}

int n;          // length of the input string (set in main)
int cnt;        // number of doublings init() applied to limit, i.e. log2(limit)
int limit;      // FFT size, a power of two chosen by init()
int r[N<<3];    // bit-reversal permutation filled by init()
int c[N<<3];    // convolution counts accumulated by work()
int hw[N<<1];   // Manacher palindrome radii over s1
int tw[N<<3];   // tw[k] = 2^k modulo mod
char s0[N];     // input string, stored starting at index 1
char s1[N<<1];  // input interleaved with '#', stored starting at index 1

// In-place iterative radix-2 FFT over the first `limit` entries of `a`.
//   a   : buffer holding at least `limit` complex values.
//   opt : +1 for the forward transform, -1 for the inverse one. The inverse
//         result is NOT scaled here; the caller divides by `limit`.
// Reads the globals `limit`, `r` (bit-reversal table built by init()) and `pi`.
void FFT(fushu *a,int opt){
  // Bit-reversal reordering; the i < r[i] guard swaps every pair exactly once.
  for (int i = 0; i < limit; ++i) {
    if (i < r[i]) std::swap(a[i], a[r[i]]);
  }

  // Butterfly stages: `half` is the half-width of the blocks being merged.
  for (int half = 1; half < limit; half <<= 1) {
    const double angle = pi / half;
    const fushu step = {std::cos(angle), std::sin(angle) * opt};
    const int span = 2 * half;

    for (int base = 0; base < limit; base += span) {
      fushu w = {1.0, 0.0};  // current twiddle factor, advanced by `step`
      for (int k = 0; k < half; ++k) {
        const fushu even = a[base + k];
        const fushu odd = w * a[base + k + half];
        a[base + k] = even + odd;
        a[base + k + half] = even - odd;
        w = w * step;
      }
    }
  }
}

// Picks the FFT size and builds the bit-reversal table.
// `limit` becomes the smallest power of two strictly greater than 4*n+2 (the
// largest index the squared sequence can reach), `cnt` its base-2 logarithm,
// and r[i] the value of i with its lowest `cnt` bits reversed.
void init() {
  limit = 1;
  cnt = 0;
  while (limit <= 4 * n + 2) {
    limit <<= 1;
    ++cnt;
  }

  const int top = cnt - 1;  // position of the highest bit of an index
  for (int i = 0; i < limit; ++i) {
    // Reverse of i = reverse of (i >> 1) shifted down by one, with the lowest
    // bit of i moved to the top position.
    const int low = i & 1;
    r[i] = (r[i >> 1] >> 1) | (low << top);
  }
}

void work(int ch) {
  for (ri i=0;i<limit;i++) A[i]=B[i]=(fushu){0.0,0.0};
  for (ri i=1;i<=2*n+1;i++) if (s1[i]==ch) A[i]=(fushu){1.0,0.0}; else A[i]=(fushu){0.0,0.0};
  for (ri i=1;i<=2*n+1;i++) B[i]=A[i];
  FFT(A,1); FFT(B,1);
  for (ri i=0;i<limit;i++) A[i]=A[i]*B[i];
  FFT(A,-1);
  for (ri i=0;i<limit;i++) c[i]+=(int)(A[i].x/limit+0.5);
}

// Reads one string from stdin and prints, modulo `mod`,
//   sum over all centres of (2^(mirrored equal pairs) - 1)
//   minus the number of palindromic substrings.
// Centres are the 2*n+1 positions of the '#'-interleaved string s1.
int main(){
  // s0 has N bytes and the text starts at s0[1], so at most N-2 characters fit
  // next to the terminating NUL. The field width keeps scanf from overflowing.
  static_assert(N == 500050, "keep the scanf field width below equal to N-2");
  if (scanf("%500048s", s0 + 1) != 1) s0[1] = '\0';  // no input: empty string
  n = static_cast<int>(strlen(s0 + 1));

  // tw[k] = 2^k mod `mod`. Only the n character positions of s1 can take part
  // in a mirrored pair, so no index above n is ever looked up.
  tw[0] = 1;
  for (int i = 1; i <= n; ++i) tw[i] = (tw[i - 1] + tw[i - 1]) % mod;

  // s1: '#' at odd positions, the input characters at even positions.
  for (int i = 1; i <= 2 * n + 1; ++i) s1[i] = (i % 2 == 0) ? s0[i / 2] : '#';

  init();
  work('a');
  work('b');

  long long ans = 0;
  for (int i = 1; i <= 2 * n + 1; ++i) {
    // c[2*i] counts ordered pairs mirrored around centre i, with the centre
    // itself counted once when it is a character. Rounding the half up gives
    // the number of unordered pairs.
    const int pairs = (c[2 * i] + 1) / 2;
    ans = (ans + tw[pairs] - 1) % mod;
  }

  // Manacher: hw[i] is the largest radius such that s1[i-hw[i] .. i+hw[i]] is
  // a palindrome; (mid, maxr) track the palindrome reaching furthest right.
  int mid = 0, maxr = 0;
  for (int i = 1; i <= 2 * n + 1; ++i) {
    if (i <= maxr) hw[i] = std::min(hw[2 * mid - i], maxr - i);
    while (i + hw[i] + 1 <= 2 * n + 1 && i - hw[i] - 1 >= 1 &&
           s1[i + hw[i] + 1] == s1[i - hw[i] - 1]) {
      ++hw[i];
    }
    if (i + hw[i] > maxr) {
      maxr = i + hw[i];
      mid = i;
    }
  }

  // Remove the palindromic substrings centred at each position: a '#' centre
  // contributes hw/2 of them, a character centre (hw+1)/2.
  for (int i = 1; i <= 2 * n + 1; ++i) {
    const int substrings = (s1[i] == '#') ? hw[i] / 2 : (hw[i] + 1) / 2;
    ans = ((ans - substrings) % mod + mod) % mod;
  }

  std::cout << ans << '\n';
  return 0;
}

 

posted @ 2019-07-31 22:51  HellPix  阅读(176)  评论(0)  编辑  收藏  举报