BZOJ 4566 [Haoi2016]找相同字符 ——广义后缀自动机

建立广义后缀自动机。

然后统计子树中的siz,需要分开统计

然后对(l[i]-l[fa[i]])*siz[i][0]*siz[i][1]求和即可。

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
 
#define F(i,j,k) for (int i=j;i<=k;++i)
#define D(i,j,k) for (int i=j;i>=k;--i)
#define ll long long
#define maxn 800005
 
struct Suffix_Auto{
    int go[maxn][26],l[maxn],siz[maxn],fa[maxn];
    int last,cnt,v[maxn],q[maxn],rit[maxn][2];
    char s[maxn];
    void init()
    {
        last=cnt=1;
        memset(go,0,sizeof go);
    }
    void add(int x,int id)
    {
        int p=last,q;
        if (q=go[p][x])
        {
            if (l[q]==l[p]+1) last=q;
            else
            {
                int nq=++cnt;
                l[nq]=l[p]+1;
                memcpy(go[nq],go[q],sizeof go[q]);
                fa[nq]=fa[q];
                fa[q]=nq;
                for (;p&&go[p][x]==q;p=fa[p]) go[p][x]=nq;
                last=nq;
            }
        }
        else
        {
            int np=++cnt; l[np]=l[p]+1;
            for (;p&&!go[p][x];p=fa[p]) go[p][x]=np;
            if (!p) fa[np]=1;
            else
            {
                q=go[p][x];
                if (l[q]==l[p]+1) fa[np]=q;
                else
                {
                    int nq=++cnt;
                    l[nq]=l[p]+1;
                    memcpy(go[nq],go[q],sizeof go[q]);
                    fa[nq]=fa[q];
                    fa[q]=fa[np]=nq;
                    for (;p&&go[p][x]==q;p=fa[p]) go[p][x]=nq;
                }
            }
            last=np;
        }
        rit[last][id]++;
    }
    void build()
    {
        int n1;
        init();
        scanf("%s",s+1);
        n1=strlen(s+1);
        F(i,1,n1) add(s[i]-'a',0);
        last=1;
        scanf("%s",s+1);
        n1=strlen(s+1);
        F(i,1,n1) add(s[i]-'a',1);
    }
    void radix()
    {
        F(i,1,cnt) v[l[i]]++;
        F(i,1,cnt) v[i]+=v[i-1];
        D(i,cnt,1) q[v[l[i]]--]=i;
        D(i,cnt,1)
        {
            rit[fa[q[i]]][0]+=rit[q[i]][0];
            rit[fa[q[i]]][1]+=rit[q[i]][1];
        }
        ll ans=0;
        F(i,1,cnt)
            ans+=(ll)rit[i][0]*rit[i][1]*(l[i]-l[fa[i]]);
        printf("%lld\n",ans);
    }
    void solve()
    {
        build();
        radix();
    }
}sam;
 
int main(){sam.solve();}

  

posted @ 2017-03-01 23:08  SfailSth  阅读(251)  评论(0编辑  收藏  举报