BZOJ4566: [Haoi2016]找相同字符

 

4566: [Haoi2016]找相同字符

Time Limit: 20 Sec  Memory Limit: 256 MB
Submit: 1354  Solved: 795
[Submit][Status][Discuss]

Description

给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两
个子串中有一个位置不同。

 

Input

两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母

 

Output

输出一个整数表示答案

 

Sample Input

aabb
bbaa

Sample Output

10

HINT

 

Source

 
[Submit][Status][Discuss]

 

 

题解:这种题首先很套路,只有两个串肯定只给一个建后缀自动机一个去在上面跑。跑的方式就是跟AC自动机那样,有就走否则就跳后缀Link。然后对每个匹配到的位置g[now]++。由于这题也需要拓扑排序,这里就是SAM的性质,sr[sr[a].link].len<sr[a].len,就直接用类似于SA中的基数排序就可以排出拓扑排序。然后就从后往前就可以得到每个状态的出现次数,不需要建图后树dp。这里要注意的是,我们在extend的时候,将新增的点的size标位1而len(p)+1<len(q)的size仍是0(这一点我也不是完全理解了)。然而思路还是用dp的思路,那必然是要算贡献,怎么算呢?分为两种,首先,我们记录当前这个点匹配到的长度,因为一个点对应的集合中有多个子串,所以就可以算出有多少个是符合要求的,匹配长度-min+1(套路),然后乘以这个点在A串中的出现次数。然而这样答案少了,因为如果一个地方可以匹配,那么和它的link也一定能匹配。因此再做一下树dp,从后往前加一加,f[i]表示i的所有儿子中的匹配数量,然后f[i]*size[i]*(max-min+1)。两部分加起来就是答案。

#include<bits/stdc++.h>
#define pb push_back
#define mp make_pair
#define ll long long
using namespace std;
const int maxn=4e5+100;
char s[maxn],s1[maxn];
struct str
{
    int son[26];
    int ff,len;
}sr[maxn];
int cnt=1,last=1;
int siz[maxn],c[maxn],a[maxn];
ll dp[maxn];
ll f[maxn];
ll g[maxn];
void extend(int c)
{
   int p=last,np=++cnt;last=np;
   sr[np].len=sr[p].len+1;siz[np]=1;
   while(p&&!sr[p].son[c])sr[p].son[c]=np,p=sr[p].ff;
   if(!p)sr[np].ff=1;
   else
   {
       int q=sr[p].son[c];
       if(sr[p].len+1==sr[q].len)sr[np].ff=q;
       else
       {
           int nq=++cnt;
           sr[nq]=sr[q];
           sr[nq].len=sr[p].len+1;
           sr[q].ff=sr[np].ff=nq;
           while(p&&sr[p].son[c]==q)sr[p].son[c]=nq,p=sr[p].ff;
       }
   }
    //siz[np]=1;
}
int main()
{
    scanf("%s",s+1);
    scanf("%s",s1+1);
    int len=strlen(s+1),len1=strlen(s1+1);
    for(int i=1;i<=len;i++)extend(s[i]-'a');
    for(int i=1;i<=cnt;i++)c[sr[i].len]++;
    for(int i=1;i<=cnt;i++)c[i]+=c[i-1];
    for(int i=1;i<=cnt;i++)a[c[sr[i].len]--]=i;
    for(int i=cnt;i;i--)siz[sr[a[i]].ff]+=siz[a[i]];
    int tmp=0,now=1;
    ll ans=0;
    for(int i=1;i<=len1;i++)
    {
        int v=s1[i]-'a';
        if(sr[now].son[v])++tmp,now=sr[now].son[v];
        else
        {
            while(now&&!sr[now].son[v])now=sr[now].ff;
            if(!now)now=1,tmp=0;
            else tmp=sr[now].len+1,now=sr[now].son[v];
        }
        g[now]++;
        ans+=1ll*(tmp-sr[sr[now].ff].len)*siz[now];
    }
    for(int i=cnt;i;i--)f[sr[a[i]].ff]+=f[a[i]]+g[a[i]];
    for(int i=1;i<=cnt;i++)ans+=1ll*(sr[i].len-sr[sr[i].ff].len)*siz[i]*f[i];
    cout<<ans<<"\n";
}

  

posted @ 2019-02-21 21:35  Twilight7  阅读(177)  评论(0编辑  收藏  举报