BZOJ4566: [Haoi2016]找相同字符
4566: [Haoi2016]找相同字符
Time Limit: 20 Sec Memory Limit: 256 MBSubmit: 1354 Solved: 795
[Submit][Status][Discuss]
Description
给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两
个子串中有一个位置不同。
Input
两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母
Output
输出一个整数表示答案
Sample Input
aabb
bbaa
bbaa
Sample Output
10
HINT
Source
题解:这种题首先很套路,只有两个串肯定只给一个建后缀自动机一个去在上面跑。跑的方式就是跟AC自动机那样,有就走否则就跳后缀Link。然后对每个匹配到的位置g[now]++。由于这题也需要拓扑排序,这里就是SAM的性质,sr[sr[a].link].len<sr[a].len,就直接用类似于SA中的基数排序就可以排出拓扑排序。然后就从后往前就可以得到每个状态的出现次数,不需要建图后树dp。这里要注意的是,我们在extend的时候,将新增的点的size标位1而len(p)+1<len(q)的size仍是0(这一点我也不是完全理解了)。然而思路还是用dp的思路,那必然是要算贡献,怎么算呢?分为两种,首先,我们记录当前这个点匹配到的长度,因为一个点对应的集合中有多个子串,所以就可以算出有多少个是符合要求的,匹配长度-min+1(套路),然后乘以这个点在A串中的出现次数。然而这样答案少了,因为如果一个地方可以匹配,那么和它的link也一定能匹配。因此再做一下树dp,从后往前加一加,f[i]表示i的所有儿子中的匹配数量,然后f[i]*size[i]*(max-min+1)。两部分加起来就是答案。
#include<bits/stdc++.h> #define pb push_back #define mp make_pair #define ll long long using namespace std; const int maxn=4e5+100; char s[maxn],s1[maxn]; struct str { int son[26]; int ff,len; }sr[maxn]; int cnt=1,last=1; int siz[maxn],c[maxn],a[maxn]; ll dp[maxn]; ll f[maxn]; ll g[maxn]; void extend(int c) { int p=last,np=++cnt;last=np; sr[np].len=sr[p].len+1;siz[np]=1; while(p&&!sr[p].son[c])sr[p].son[c]=np,p=sr[p].ff; if(!p)sr[np].ff=1; else { int q=sr[p].son[c]; if(sr[p].len+1==sr[q].len)sr[np].ff=q; else { int nq=++cnt; sr[nq]=sr[q]; sr[nq].len=sr[p].len+1; sr[q].ff=sr[np].ff=nq; while(p&&sr[p].son[c]==q)sr[p].son[c]=nq,p=sr[p].ff; } } //siz[np]=1; } int main() { scanf("%s",s+1); scanf("%s",s1+1); int len=strlen(s+1),len1=strlen(s1+1); for(int i=1;i<=len;i++)extend(s[i]-'a'); for(int i=1;i<=cnt;i++)c[sr[i].len]++; for(int i=1;i<=cnt;i++)c[i]+=c[i-1]; for(int i=1;i<=cnt;i++)a[c[sr[i].len]--]=i; for(int i=cnt;i;i--)siz[sr[a[i]].ff]+=siz[a[i]]; int tmp=0,now=1; ll ans=0; for(int i=1;i<=len1;i++) { int v=s1[i]-'a'; if(sr[now].son[v])++tmp,now=sr[now].son[v]; else { while(now&&!sr[now].son[v])now=sr[now].ff; if(!now)now=1,tmp=0; else tmp=sr[now].len+1,now=sr[now].son[v]; } g[now]++; ans+=1ll*(tmp-sr[sr[now].ff].len)*siz[now]; } for(int i=cnt;i;i--)f[sr[a[i]].ff]+=f[a[i]]+g[a[i]]; for(int i=1;i<=cnt;i++)ans+=1ll*(sr[i].len-sr[sr[i].ff].len)*siz[i]*f[i]; cout<<ans<<"\n"; }