Luogu3346 [ZJOI2015]诸神眷顾的幻想乡

https://www.luogu.com.cn/problem/P3346

广义后缀自动机

我们需要把所有两两叶子节点之间的路径丢进广义后缀自动机中,然后计算不同子串个数

观察数据,叶子节点数为\(\le 20\),好像有点小

那么我们暴力枚举每个叶子节点为根的情况,然后处理根与其他叶子节点的路径

认真瞧一瞧,好像是一棵免费的\(Trie\)树,那么我们边\(dfs\)边往广义后缀自动机中丢点就可以了

\(C++ Code:\)

#include<cstdio>
#include<iostream>
#include<cstring>
#define N 100005
#define M 200005
#define SAM_N 4000005
using namespace std;
int n,C,x,y,tot,z,col[N];
int head[N],d[M],nxt[M],rd[N],f[N],one[N];
int cnt=1,last=1;
int t[SAM_N][10],pre[SAM_N],len[SAM_N];
int ins(int c,int last)
{
    if (t[last][c])
    {
        if (len[last]+1==len[t[last][c]])
            return t[last][c]; else
            {
                int p=last;
                int q=t[last][c];
                int g=++cnt;
                len[g]=len[p]+1;
                for (int i=0;i<C;i++)
                    t[g][i]=t[q][i];
                pre[g]=pre[q];
                for (;p&&t[p][c]==q;p=pre[p])
                    t[p][c]=g;
                pre[q]=g;
                return g;
            }
    }
    int p,q;
    int np=++cnt;
    len[np]=len[last]+1;
    for (p=last;p&&!t[p][c];p=pre[p])
        t[p][c]=np;
    if (!p)
        pre[np]=1; else
        {
            q=t[p][c];
            if (len[p]+1==len[q])
                pre[np]=q; else
                {
                    int g=++cnt;
                    len[g]=len[p]+1;
                    for (int i=0;i<C;i++)
                        t[g][i]=t[q][i];
                    pre[g]=pre[q];
                    for (;p&&t[p][c]==q;p=pre[p])
                        t[p][c]=g;
                    pre[q]=pre[np]=g;
                }
        }
    return np;
}
void add(int x,int y)
{
    tot++;
    d[tot]=y;
    nxt[tot]=head[x];
    head[x]=tot;
    rd[y]++;
}
void build(int u,int last)
{
    last=ins(col[u],last);
    for (int i=head[u];i;i=nxt[i])
    {
        int v=d[i];
        if (v==f[u])
            continue;
        f[v]=u;
        build(v,last);
    }
}
int main()
{
    scanf("%d%d",&n,&C);
    for (int i=1;i<=n;i++)
        scanf("%d",&col[i]);
    for (int i=1;i<n;i++)
    {
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
    }
    for (int i=1;i<=n;i++)
        if (rd[i]==1)
            one[++z]=i;
    for (int i=1;i<=z;i++)
    {
        f[one[i]]=0;
        build(one[i],1);
    }
    long long ans=0;
    for (int i=2;i<=cnt;i++)
        ans+=len[i]-len[pre[i]];
    cout << ans << endl;
    return 0;
}
posted @ 2020-07-23 17:33  GK0328  阅读(97)  评论(0编辑  收藏  举报