BZOJ 3926:问题转换 + 广义后缀自动机

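思路:关键在于叶子节点最多只有 20 个。对树上任意一条路径 u→v,总能找到一个叶子,使得以它为根时 u 是 v 的祖先(若 u 本身是叶子就取 u,否则从 u 沿远离 v 的方向一直走到某个叶子),因此这条路径就是以该叶子为根的 trie 树中一段自上而下的路径。于是依次以每个叶子为根做 DFS,把得到的(至多)20 棵 trie 树插入同一个广义后缀自动机,答案就是自动机中本质不同的子串个数,即对所有非根状态求和 Σ(len(v) − len(link(v)))。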

#include<bits/stdc++.h>
// Competitive-programming shorthands.
// Plain type shorthands are expressed as scoped aliases rather than text macros.
using LL = long long;
using ull = unsigned long long;
using PII = std::pair<int, int>;
using PLI = std::pair<LL, int>;
// Member / function shorthands have no alias form, so they remain macros.
#define fi first
#define se second
#define mk make_pair
using namespace std;

// Global capacity and conventional constants.
// N is the per-array capacity; the automaton below allocates 2*N states. It is
// presumably sized for up to 20 leaf-rooted tries of a ~1e5-node tree — confirm
// against the problem limits before changing it.
constexpr int N = 2000007;                 // 2e6 + 7
constexpr int inf = 0x3f3f3f3f;
constexpr LL INF = 0x3f3f3f3f3f3f3f3f;
constexpr int mod = 1000000007;            // 1e9 + 7
constexpr double eps = 1e-8;

// Input tree and its adjacency list (forward-star representation).
int n;        // number of tree nodes
int m;        // not referenced in this listing
int c;        // meant for the alphabet size
int tot;      // number of edges stored so far (next free slot in edge[])
int s[N];     // s[i]: colour of node i
int deg[N];   // deg[i]: degree of node i
int head[N];  // head[u]: index of u's latest edge; solve() fills it with -1 (= none)

// One directed edge: target node and index of the next edge from the same source.
struct Edge {
    int to;
    int nx;
};
Edge edge[N];

// Push the directed edge u -> v onto the front of u's adjacency list.
void add(int u, int v) {
    const int id = tot++;
    edge[id].to = v;
    edge[id].nx = head[u];
    head[u] = id;
}

// Generalized suffix automaton over the tries obtained by rooting the input tree
// at each leaf.
//
// State 1 is the root; state 0 means "no state". For a state v:
//   dis[v]   length of the longest string in v's class,
//   fa[v]    suffix link,
//   ch[v][x] transition on colour x (the alphabet is hard-coded to 10 symbols,
//            so colours must lie in [0, 10)).
// id[] and c[] are scratch arrays for getSize()'s counting sort by dis.
//
// extend() is the plain SAM extension applied to trie nodes. When ch[p][x]
// already exists it still allocates a new state, but that state always gets
// dis equal to the dis of its suffix link and no incoming transition, so it adds
// 0 to sum(dis[v] - dis[fa[v]]) and the distinct-substring count is unaffected.
// Such states still get sz[v] = 1, so sz[] is not an occurrence count as-is.
struct SuffixAutomaton {
    int last, cur, cnt, ch[N<<1][10], id[N<<1], fa[N<<1], dis[N<<1], sz[N<<1], c[N];
    SuffixAutomaton() {cur = cnt = 1;}

    // Reset to an automaton that contains only the root state.
    void init() {
        for(int i = 1; i <= cnt; i++) {
            memset(ch[i], 0, sizeof(ch[i]));
            sz[i] = dis[i] = fa[i] = 0;
        }
        // c[] is indexed by string length (N slots), not by state id (up to 2N),
        // so it is cleared separately; this also clears c[0].
        memset(c, 0, sizeof(c));
        cur = cnt = 1;
    }

    // Append colour x after state p and return the state created for the new
    // longest string. Standard suffix-automaton extension with cloning.
    int extend(int p, int x) {
        cur = ++cnt; dis[cur] = dis[p]+1;
        for(; p && !ch[p][x]; p = fa[p]) ch[p][x] = cur;
        if(!p) fa[cur] = 1;
        else {
            int q = ch[p][x];
            if(dis[q] == dis[p]+1) fa[cur] = q;
            else {
                // Split q: the clone takes the shorter strings (length <= dis[p]+1).
                int nt = ++cnt; dis[nt] = dis[p]+1;
                memcpy(ch[nt], ch[q], sizeof(ch[q]));
                fa[nt] = fa[q]; fa[q] = fa[cur] = nt;
                for(; ch[p][x] == q; p = fa[p]) ch[p][x] = nt;
            }
        }
        sz[cur] = 1;
        return cur;
    }

    // Counting-sort states by dis into id[1..cnt] (ascending dis).
    // n is kept for call compatibility only: the range is taken from the actual
    // maximum dis, so a too-small n cannot corrupt the order, and c[] is cleared
    // first so repeated calls (or calls after init()) are safe.
    void getSize(int n) {
        (void)n;
        int maxLen = 0;
        for(int i = 1; i <= cnt; i++)
            if(dis[i] > maxLen) maxLen = dis[i];
        for(int i = 0; i <= maxLen; i++) c[i] = 0;
        for(int i = 1; i <= cnt; i++) c[dis[i]]++;
        for(int i = 1; i <= maxLen; i++) c[i] += c[i-1];
        for(int i = cnt; i >= 1; i--) id[c[dis[i]]--] = i;
    }

    // Insert the trie obtained by walking the tree from u (coming from `parent`,
    // 0 for a root) into the automaton, starting at automaton state `state`.
    // Nodes are extended in exactly the pre-order a recursive DFS over the
    // adjacency lists would use, but with an explicit stack so that a path-shaped
    // tree cannot overflow the call stack.
    void dfs(int u, int parent, int state) {
        struct Frame { int node, parent, state, e; };
        std::vector<Frame> stk;
        stk.push_back(Frame{u, parent, extend(state, s[u]), head[u]});
        while(!stk.empty()) {
            Frame &top = stk.back();
            if(top.e == -1) { stk.pop_back(); continue; }
            const int v = edge[top.e].to;
            top.e = edge[top.e].nx;
            if(v == top.parent) continue;
            // Copy what we need: push_back below may invalidate `top`.
            const int from = top.node;
            const int fromState = top.state;
            stk.push_back(Frame{v, from, extend(fromState, s[v]), head[v]});
        }
    }

    // Read the tree (n, alphabet size, colours s[1..n], then n-1 edges), insert
    // the trie rooted at every leaf, and print the number of distinct strings,
    // i.e. the sum over non-root states of dis[v] - dis[fa[v]].
    void solve() {
        memset(head, -1, sizeof(head));
        // The second number is the alphabet size. A bare `c` here would name this
        // struct's counting array c[N], making scanf write an int through an
        // int(*)[N]. ch[][] is sized for 10 colours, so the value is only read.
        int sigma = 0;
        scanf("%d%d", &n, &sigma);
        for(int i = 1; i <= n; i++) scanf("%d", &s[i]);
        for(int i = 1; i < n; i++) {
            int u, v; scanf("%d%d", &u, &v);
            add(u, v); add(v, u);
            deg[u]++; deg[v]++;
        }
        // Leaves have degree 1. A single-node tree has degree 0 and must still be
        // inserted; otherwise the answer would be 0 instead of 1.
        for(int i = 1; i <= n; i++)
            if(deg[i] <= 1) dfs(i, 0, 1);
        LL ans = 0;
        for(int i = 2; i <= cnt; i++)
            ans += dis[i] - dis[fa[i]];
        printf("%lld\n", ans);
    }
} sam;

// Entry point: parsing, automaton construction and output all live in sam.solve().
int main() {
    sam.solve();
}

/*
*/

 

posted @ 2018-10-21 16:21  NotNight  阅读(103)  评论(0)  编辑  收藏  举报