[HAOI2016]找相同子串
这题感觉有点坑啊。
题目还是不难想的,先对一个字符串建后缀自动机,然后拿另一个字符串在上面跑。
假设当前跑到了p点,匹配长度为len。
那么当前会对答案产生贡献的串是哪些呢?
显然当前会对p及p到根的链产生贡献。这样显然可以用树形dp优化。
同时需要差分(我也不知道这是否是必须的)。
下面有update。
#include<iostream> #include<cstdio> #include<cstdlib> #include<string> #include<cstring> #include<cmath> #include<ctime> #include<algorithm> #include<iomanip> #include<set> #include<map> #include<queue> using namespace std; #define mem1(i,j) memset(i,j,sizeof(i)) #define mem2(i,j) memcpy(i,j,sizeof(i)) #define LL long long #define up(i,j,n) for(int i=(j);i<=(n);i++) #define FILE "dealing" #define poi vec #define eps 1e-10 #define db double const int maxn=401000,inf=1000000000,mod=1000000007; int read(){ int x=0,f=1,ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch<='9'&&ch>='0'){x=(x<<1)+(x<<3)+ch-'0',ch=getchar();} return f*x; } bool cmax(int& a,int b){return a<b?a=b,true:false;} bool cmin(int& a,int b){return a>b?a=b,true:false;} struct SAM{ int pre[maxn],c[maxn][27],step[maxn],sa[maxn],cou[maxn],val[maxn],cnt,now,Len; SAM(){mem1(pre,0);mem1(c,0);mem1(step,0);cnt=now=1;} int extend(int x){ int np,nq,q,p; p=now;now=np=++cnt;step[np]=step[p]+1;val[np]++; while(p&&!c[p][x])c[p][x]=np,p=pre[p]; if(!p)pre[np]=1; else { q=c[p][x]; if(step[q]==step[p]+1)pre[np]=q; else { step[nq=++cnt]=step[p]+1; mem2(c[nq],c[q]); pre[nq]=pre[q]; pre[q]=pre[np]=nq; while(p&&c[p][x]==q)c[p][x]=nq,p=pre[p]; } } } int getsort(){ up(i,1,cnt)cou[step[i]]++; up(i,1,cnt)cou[i]+=cou[i-1]; for(int i=cnt;i>=1;i--)sa[cou[step[i]]--]=i; for(int i=cnt;i>=1;i--)val[pre[sa[i]]]+=val[sa[i]]; } int walkprepare(){now=1,Len=0;} int walk(int x){ while(pre[now]&&!c[now][x])now=pre[now],Len=step[now]; if(!c[now][x])return 0; Len++;now=c[now][x];return Len; } int build(char* s){ int n=strlen(s+1); up(i,1,n)extend(s[i]-'a'); getsort();walkprepare(); } }a; char s[maxn]; LL ans[maxn],w[maxn],e[maxn]; int main(){ freopen(FILE".in","r",stdin); freopen(FILE".out","w",stdout); scanf("%s",s+1); a.build(s); scanf("%s",s+1); int n=strlen(s+1); up(i,1,n){ int m=a.walk(s[i]-'a'); LL p=a.now; ans[min(m,a.step[p])]+=a.val[p]; w[p]+=a.val[p];e[p]++; } for(int i=a.cnt;i>=1;i--){ ans[a.step[a.sa[i]]]+=e[a.sa[i]]*a.val[a.sa[i]]-w[a.sa[i]]; e[a.pre[a.sa[i]]]+=e[a.sa[i]]; w[a.pre[a.sa[i]]]+=e[a.sa[i]]*a.val[a.sa[i]];//写的丑丑的dp } LL sum=0; for(int i=a.cnt;i>=1;i--)ans[i]+=ans[i+1],sum+=ans[i]; cout<<sum<<endl; return 0; }
update:
发现自己的做法麻烦了,我们又不用求长度为i的子串的方案,ans[]完全可以省略,直接预处理到这个节点的方案数,最后加上去即可。
code:修改版本
#include<iostream> #include<cstdio> #include<cstdlib> #include<string> #include<cstring> #include<cmath> #include<ctime> #include<algorithm> #include<iomanip> #include<set> #include<map> #include<queue> using namespace std; #define mem1(i,j) memset(i,j,sizeof(i)) #define mem2(i,j) memcpy(i,j,sizeof(i)) #define LL long long #define up(i,j,n) for(int i=(j);i<=(n);i++) #define FILE "find_2016" #define poi vec #define eps 1e-10 #define db double const int maxn=401000,inf=1000000000,mod=1000000007; int read(){ int x=0,f=1,ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch<='9'&&ch>='0'){x=(x<<1)+(x<<3)+ch-'0',ch=getchar();} return f*x; } bool cmax(int& a,int b){return a<b?a=b,true:false;} bool cmin(int& a,int b){return a>b?a=b,true:false;} struct SAM{ int pre[maxn],c[maxn][26],step[maxn],sa[maxn],cou[maxn],val[maxn],cnt,now,Len; LL sum[maxn]; SAM(){mem1(pre,0);mem1(c,0);mem1(step,0);mem1(val,0);mem1(sum,0);cnt=now=1;} int extend(int x){ int np,nq,q,p; p=now;now=np=++cnt;step[np]=step[p]+1;val[np]++; while(p&&!c[p][x])c[p][x]=np,p=pre[p]; if(!p)pre[np]=1; else { q=c[p][x]; if(step[q]==step[p]+1)pre[np]=q; else { step[nq=++cnt]=step[p]+1; mem2(c[nq],c[q]); pre[nq]=pre[q]; pre[q]=pre[np]=nq; while(p&&c[p][x]==q)c[p][x]=nq,p=pre[p]; } } } int getsort(){ up(i,1,cnt)cou[step[i]]++; up(i,1,cnt)cou[i]+=cou[i-1]; for(int i=cnt;i>=1;i--)sa[cou[step[i]]--]=i; for(int i=cnt;i>=1;i--)val[pre[sa[i]]]+=val[sa[i]]; up(i,1,cnt)sum[sa[i]]+=sum[pre[sa[i]]]+(step[sa[i]]-step[pre[sa[i]]])*val[sa[i]]; } int walkprepare(){now=1,Len=0;} int walk(int x){ while(pre[now]&&!c[now][x])now=pre[now],Len=step[now]; if(!c[now][x])return 0; Len++;now=c[now][x];return Len; } int build(char* s){ int n=strlen(s+1); up(i,1,n)extend(s[i]-'a'); getsort();walkprepare(); } }a; char s[maxn]; LL ans=0; int main(){ freopen(FILE".in","r",stdin); freopen(FILE".out","w",stdout); scanf("%s",s+1); a.build(s); scanf("%s",s+1); int n=strlen(s+1); up(i,1,n){ int m=a.walk(s[i]-'a'); LL p=a.now; ans+=a.sum[a.pre[p]]+a.val[p]*(m-a.step[a.pre[p]]); } cout<<ans<<endl; return 0; }