BZOJ4566:[Haoi2016]找相同字符
4566: [Haoi2016]找相同字符
Time Limit: 20 Sec Memory Limit: 256 MBSubmit: 545 Solved: 302
[Submit][Status][Discuss]
Description
给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两
个子串中有一个位置不同。
Input
两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母
Output
输出一个整数表示答案
Sample Input
aabb
bbaa
bbaa
Sample Output
10
思路{
后缀数组套路题。
LCP长度大的能对长度小的做出贡献。
那把两个串连成一个。sa[i],sa[i-1],height[i]分组,用并查集维护集合在1串和二串中的个数,乘法原理统计答案。
}
#include <algorithm> #include <iostream> #include <cstring> #include <cstdlib> #include <cstdio> #include <vector> #include <cmath> #include <queue> #include <stack> #include <map> #include <set> #define inf (1<<30) #define il inline #define RG register #define LL long long #define maxx 200010*2 using namespace std; char s[maxx],t[maxx];int tong[maxx]; int len1,len2,X[maxx],Y[maxx],rnk[maxx],sa[maxx],height[maxx],LEN;LL A; il bool comp(int *r,int i,int j,int len){return r[i+len]==r[j+len]&&r[i]==r[j];} il void build_sa(int n){ int *x=X,*y=Y,*t,Max=88988; for(int i=0;i<n;++i)tong[x[i]=s[i]]++; for(int i=1;i<Max;++i)tong[i]+=tong[i-1]; for(int i=n-1;i!=-1;i--)sa[--tong[x[i]]]=i; for(int j=1,p=0,i;p<n;j<<=1,Max=p){ for(i=n-1,p=0;i>=n-j;--i)y[p++]=i; for(i=0;i<n;++i)if(sa[i]>=j)y[p++]=sa[i]-j; memset(tong,0,sizeof(tong)); for(i=0;i<n;++i)tong[x[y[i]]]++; for(i=1;i<=Max;++i)tong[i]+=tong[i-1]; for(i=n-1;i!=-1;i--)sa[--tong[x[y[i]]]]=y[i]; for(t=x,x=y,y=t,i=1,p=1,x[sa[0]]=0;i<n;++i) x[sa[i]]=comp(y,sa[i],sa[i-1],j)?p-1:p++; } } il void geth(){ int i,j,k=0; for(i=1;i<=LEN;++i)rnk[sa[i]]=i; for(i=0;i<LEN;height[rnk[i++]]=k) for((k?k--:0),j=sa[rnk[i]-1];s[j+k]==s[i+k];k++); } struct segment{ int l,r,len; segment() {} segment(int _l,int _r,int L):l(_l),r(_r),len(L) {}; }w[maxx];int cnt,fa[maxx],sz[maxx][3],c[maxx];LL ans[maxx]; bool Comp(const segment & a,const segment & b){return a.len>b.len;} il int find(int x){if(fa[x]!=x)fa[x]=find(fa[x]);return fa[x];} il void Insert(int x,int y,int H){fa[x]=y;ans[H]+=sz[y][1]*sz[x][2]+sz[x][1]*sz[y][2],sz[y][1]+=sz[x][1],sz[y][2]+=sz[x][2];} il void work(){ scanf("%s%s",s,t);len1=strlen(s),len2=strlen(t); s[len1]='#';for(int i=0;i<len1;++i)c[i]=1;for(int i=0;i<len2;++i)s[len1+i+1]=t[i],c[len1+i+1]=2; LEN=len1+len2+1;s[LEN]=0;build_sa(LEN+1);geth(); for(int i=2;i<=LEN;++i)w[i-1]=segment(sa[i],sa[i-1],height[i]); sort(w+1,w+LEN,Comp);for(int i=0;i<LEN;++i)fa[i]=i,sz[i][c[i]]++; for(int i=1;i<LEN;++i){ int x=find(w[i].l),y=find(w[i].r); if(x!=y)Insert(x,y,w[i].len); }for(int i=LEN;i;--i)A+=ans[i]*i;printf("%lld",A); } int main(){ freopen("1.in","r",stdin); freopen("1.out","w",stdout); work();return 0; }