HDU5343 MZL's Circle Zhou(SAM+记忆化搜索)

Problem Description
MZL's Circle Zhou is good at solving some counting problems. One day, he comes up with a counting problem:
You are given two strings a,b which consist of only lowercase English letters. You can subtract a substring x (maybe empty) from string a and a substring y (also maybe empty) from string b, and then connect them as x+y with x at the front and y at the back. In this way, a series of new strings can be obtained.
The question is how many different new strings can be obtained in this way.
Two strings are different, if and only if they have different lengths or there exists an integer i such that the two strings have different characters at position i.
 

 

Input
The first line of the input is a single integer T (T5), indicating the number of testcases. 
For each test case, there are two lines, the first line is string a, and the second line is string b1<=|a|,|b|<=90000.
 

 

Output
For each test case, output one line, a single integer indicating the answer.
 

 

Sample Input
2 acbcc cccabc bbbabbababbababbaaaabbbbabbaaaabaabbabbabbbaaabaab abbaabbabbaaaabbbaababbabbabababaaaaabbaabbaabbaab
 

 

Sample Output
135 557539
 

 

Author
SXYZ
 

 

Source
2015 Multi-University Training Contest 5
 
题解:题目意思给你两个字符串,然后从第一个字符串里面取出一个子串X,从第二个字符串里面取出一个子串Y,两个拼接在一起组成新的字符串,其中X、Y都可以是空串,问有多少个这样不同的串。
思路:考虑X+Y为什么会有重复的。举个例子ababc,可以由ab+abc或则a+babc组成。重复的原因在于在第一个串中可以找到ab,也可以找到a,而如果a可以构成这个拼接串,那么ab也构成这个拼接串。所以说,为了避免重复,我们可以在第一个串中找最长的点,即走到某个点x,然后这个点不能走到字符'a',那么对于字符'a'来说,x这个点就是最长的,在另一个串中找'a'开头的子串个数,这些就是点x的可以匹配到的个数。(用记忆化搜索)

 

参考代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
#define RI register int
const int maxn=1e5+10;
char s1[maxn],s2[maxn];
struct SAM{
    int last,tot,nxt[maxn<<1][27],fa[maxn<<1],l[maxn<<1];
    inline void Init()
    {
        last=tot=1;
        memset(nxt[tot],0,sizeof(nxt[tot]));
        l[tot]=fa[tot]=0;
    }
    inline int NewNode()
    {
        ++tot;
        memset(nxt[tot],0,sizeof(nxt[tot]));
        l[tot]=fa[tot]=0;
        return tot;
    }
    inline void Add(int c)
    {
        int np=NewNode(),p=last;
        last=np;l[np]=l[p]+1;
        while(p&&!nxt[p][c]) nxt[p][c]=np,p=fa[p];
        if(!p) fa[np]=1;
        else
        {
            int q=nxt[p][c];
            if(l[q]==l[p]+1) fa[np]=q;
            else
            {
                int nq=NewNode();
                memcpy(nxt[nq],nxt[q],sizeof(nxt[q]));
                fa[nq]=fa[q];
                l[nq]=l[p]+1;
                fa[q]=fa[np]=nq;
                while(p&&nxt[p][c]==q) nxt[p][c]=nq,p=fa[p];
            }
        }
    }
} sam1,sam2;

int T;
ull dp1[maxn<<1],dp2[maxn<<1];
inline ull dfs2(int u)
{
    if(!u) return 0;
    if(dp2[u]) return dp2[u];
    ull res=1;
    for(int i=0;i<26;++i)
    {
        int nt=sam2.nxt[u][i];
        if(nt) res+=dfs2(nt);    
    }
    return dp2[u]=res;    
}
inline ull dfs(int u)
{
    if(dp1[u]) return dp1[u];
    ull res=1;
    for(int i=0;i<26;++i)
    {
        int nt=sam1.nxt[u][i];
        if(nt) res+=dfs(nt);
        else res+=dfs2(sam2.nxt[1][i]);
    }
    return dp1[u]=res;
}

int main()
{
    scanf("%d",&T);
    while(T--)
    {
        sam1.Init();sam2.Init();
        memset(dp1,0,sizeof dp1);
        memset(dp2,0,sizeof dp2);
        
        scanf("%s%s",s1,s2);
        for(int i=0;s1[i];++i) sam1.Add(s1[i]-'a');
        for(int i=0;s2[i];++i) sam2.Add(s2[i]-'a');
        
        printf("%I64u\n",dfs(1));
    }

    return 0;    
}
View Code

 

posted @ 2019-08-26 23:50  StarHai  阅读(305)  评论(0编辑  收藏  举报