hdu5343 后缀自动机+dp

给定两个串,分别截取字串X和Y,连接组成X+Y,求不同的X+Y的方案数。

对于X+Y,如果重复的部分其实就是从同一个X+Y的某个地方断开弄成不同的X和Y,那么只要使得X和X+Y匹配得最长就行了。

因此,对两个字符串分别建立后缀自动机A和B,在A中找字串X,当X的末尾不能接某个字符c时,在B中找以c为开头的所有字串。

注意字串的是n^2个,所以不管怎样都不能以暴力遍历自动机的方式来统计,而由于SAM是DAG,所以实际上是在两个DAG上进行dp。

 

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#define REP(i,a,b) for(int i=a;i<=b;i++)
#define MS0(a) memset(a,0,sizeof(a))

using namespace std;

typedef unsigned long long ll;
const int maxn=1000100;
const int INF=1e9+10;

char s[maxn],t[maxn];
ll dp1[maxn],dp2[maxn];

struct SAM
{
    int ch[maxn][26];
    int pre[maxn],step[maxn];
    int last,tot;
    void init()
    {
        last=tot=0;
        memset(ch[0],-1,sizeof(ch[0]));
        pre[0]=-1;
        step[0]=0;
    }
    void add(int c)
    {
        c-='a';
        int p=last,np=++tot;
        step[np]=step[p]+1;
        memset(ch[np],-1,sizeof(ch[np]));
        while(~p&&ch[p][c]==-1) ch[p][c]=np,p=pre[p];
        if(p==-1) pre[np]=0;
        else{
            int q=ch[p][c];
            if(step[q]!=step[p]+1){
                int nq=++tot;
                step[nq]=step[p]+1;
                memcpy(ch[nq],ch[q],sizeof(ch[q]));
                pre[nq]=pre[q];
                pre[q]=pre[np]=nq;
                while(~p&&ch[p][c]==q) ch[p][c]=nq,p=pre[p];
            }
            else pre[np]=q;
        }
        last=np;
    }
};SAM A,B;

ll dfs2(int u)
{
    if(u==-1) return 0;
    ll &res=dp2[u];
    if(~res) return res;
    res=1;
    REP(c,0,25) res+=dfs2(B.ch[u][c]);
    return res;
}

ll dfs1(int u)
{
    ll &res=dp1[u];
    if(~res) return res;
    res=1;
    REP(c,0,25){
        if(~A.ch[u][c]) res+=dfs1(A.ch[u][c]);
        else res+=dfs2(B.ch[0][c]);
    }
    return res;
}

void solve()
{
    A.init();B.init();
    int ls=strlen(s),lt=strlen(t);
    REP(i,0,ls-1) A.add(s[i]);
    REP(i,0,lt-1) B.add(t[i]);
    memset(dp1,-1,sizeof(dp1));
    memset(dp2,-1,sizeof(dp2));
    printf("%I64u\n",dfs1(0));
}

int main()
{
    freopen("in.txt","r",stdin);
    int T;cin>>T;
    while(T--){
        scanf("%s%s",s,t);
        solve();
    }
    return 0;
}
View Code

 

posted @ 2016-05-04 11:27  __560  阅读(790)  评论(0编辑  收藏  举报