[BZOJ4566] [HAOI2016]找相同字符
Description
给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两
个子串中有一个位置不同。
Input
两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母
Output
输出一个整数表示答案
Sample Input
aabb
bbaa
Sample Output
10
Solution
对第一个串建立后缀自动机,然后把第二个串丢上面跑。
预处理出自动机每个点的\(sz=|right|\),还有\(sum_i\)表示\(parent\)树上根到\(i\)的答案前缀和,也就是\(\sum_x sz_x\cdot (maxl_x-maxl_{par_x})\) 。
那么根据当前匹配的长度统计答案就好了。
时间复杂度\(O(n)\)。
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define min(x,y) ((x)<(y)?(x):(y))
#define max(x,y) ((x)>(y)?(x):(y))
void read(int &x) {
x=0;int f=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
void print(ll x) {
if(x<0) putchar('-'),x=-x;
if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(ll x) {if(!x) putchar('0');else print(x);putchar('\n');}
const int maxn = 4e5+10;
char s[maxn];
ll sum[maxn];
int n,qs=1,cnt=1,lstp=1;
int tr[maxn][26],ml[maxn],sz[maxn],par[maxn],t[maxn],r[maxn];
void append(int x) {
int p=lstp,np=++cnt;ml[np]=ml[p]+1,lstp=np;sz[np]=1;
for(;p&&tr[p][x]==0;p=par[p]) tr[p][x]=np;
if(!p) return par[np]=qs,void();
int q=tr[p][x];
if(ml[p]+1<ml[q]) {
int nq=++cnt;ml[nq]=ml[p]+1;
memcpy(tr[nq],tr[q],sizeof tr[nq]);
par[nq]=par[q],par[q]=par[np]=nq;
for(;p&&tr[p][x]==q;p=par[p]) tr[p][x]=nq;
} else par[np]=q;
}
void prepare() {
for(int i=1;i<=cnt;i++) t[ml[i]]++;
for(int i=1;i<=n;i++) t[i]+=t[i-1];
for(int i=1;i<=cnt;i++) r[t[ml[i]]--]=i;
for(int i=cnt;i;i--) sz[par[r[i]]]+=sz[r[i]];
for(int i=1,p=r[i];i<=cnt;i++,p=r[i])
sum[p]=sum[par[p]]+sz[p]*(ml[p]-ml[par[p]]);
}
ll solve() {
ll p=qs,ans=0,len=0;
for(int i=1,x;i<=n;i++) {
x=s[i]-'a';
while((!tr[p][x])&&p) p=par[p];
if(!p) p=qs,len=0;
else {
len=min(len,ml[p])+1;p=tr[p][x];
ans+=1ll*(len-ml[par[p]])*sz[p];
ans+=sum[par[p]];
}
}return ans;
}
int main() {
scanf("%s",s+1);n=strlen(s+1);
for(int i=1;i<=n;i++) append(s[i]-'a');
prepare();
scanf("%s",s+1);n=strlen(s+1);
write(solve());
return 0;
}