【bzoj3926】[Zjoi2015]诸神眷顾的幻想乡

*题目描述:
幽香是全幻想乡里最受人欢迎的萌妹子,这天,是幽香的2600岁生日,无数幽香的粉丝到了幽香家门前的太阳花田上来为幽香庆祝生日。
粉丝们非常热情,自发组织表演了一系列节目给幽香看。幽香当然也非常高兴啦。
这时幽香发现了一件非常有趣的事情,太阳花田有n块空地。在过去,幽香为了方便,在这n块空地之间修建了n-1条边将它们连通起来。也就是说,这n块空地形成了一个树的结构。
有n个粉丝们来到了太阳花田上。为了表达对幽香生日的祝贺,他们选择了c中颜色的衣服,每种颜色恰好可以用一个0到c-1之间的整数来表示。并且每个人都站在一个空地上,每个空地上也只有一个人。这样整个太阳花田就花花绿绿了。幽香看到了,感觉也非常开心。
粉丝们策划的一个节目是这样的,选中两个粉丝A和B(A和B可以相同),然后A所在的空地到B所在的空地的路径上的粉丝依次跳起来(包括端点),幽香就能看到一个长度为A到B之间路径上的所有粉丝的数目(包括A和B)的颜色序列。一开始大家打算让人一两个粉丝(注意:A,B和B,A是不同的,他们形成的序列刚好相反,比如红绿蓝和蓝绿红)都来一次,但是有人指出这样可能会出现一些一模一样的颜色序列,会导致审美疲劳。
于是他们想要问题,在这个树上,一共有多少可能的不同的颜色序列(子串)幽香可以看到呢?
太阳花田的结构比较特殊,只与一个空地相邻的空地数量不超过20个。
*输入:
第一行两个正整数n,c。表示空地数量和颜色数量。
第二行有n个0到c-1之间,由空格隔开的整数,依次表示第i块空地上的粉丝的衣服颜色。(这里我们按照节点标号从小到大的顺序依次给出每块空地上粉丝的衣服颜色)。
接下来n-1行,每行两个正整数u,v,表示有一条连接空地u和空地v的边。
*输出:
一行,输出一个整数,表示答案。
*样例输入:
7 3
0 2 1 2 1 0 0
1 2
3 4
3 5
4 6
5 7
2 5
*样例输出:
30
*提示:
对于所有数据,1<=n<=100000, 1<=c<=10。
对于15%的数据,n<=2000。
另有5%的数据,所有空地都至多与两个空地相邻。
另有5%的数据,除一块空地与三个空地相邻外,其他空地都分别至多与两个空地相邻。
另有5%的数据,除某两块空地与三个空地相邻外,其他空地都分别至多与两个空地相邻
*题解:
广义后缀树(实际上是后缀自动机)。由于叶子节点不会很多,所以我们从每个叶子节点开始dfs,将路径上的字符串插入后缀自动机。这样每个两点间的路径串必然是后缀自动机的某个子串。于是乎问题就转化为了求广义后缀自动机上的不同子串的个数。
*代码:

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>

#ifdef WIN32
    #define LL "%I64d"
#else
    #define LL "%lld"
#endif

#ifdef CT
    #define debug(...) printf(__VA_ARGS__)
    #define setfile() 
#else
    #define debug(...)
    #define filename ""
    #define setfile() freopen(filename".in", "r", stdin); freopen(filename".out", "w", stdout);
#endif

#define R register
#define getc() (S == T && (T = (S = B) + fread(B, 1, 1 << 15, stdin), S == T) ? EOF : *S++)
#define dmax(_a, _b) ((_a) > (_b) ? (_a) : (_b))
#define dmin(_a, _b) ((_a) < (_b) ? (_a) : (_b))
#define cmax(_a, _b) (_a < (_b) ? _a = (_b) : 0)
#define cmin(_a, _b) (_a > (_b) ? _a = (_b) : 0)
char B[1 << 15], *S = B, *T = B;
inline int FastIn()
{
    R char ch; R int cnt = 0; R bool minus = 0;
    while (ch = getc(), (ch < '0' || ch > '9') && ch != '-') ;
    ch == '-' ? minus = 1 : cnt = ch - '0';
    while (ch = getc(), ch >= '0' && ch <= '9') cnt = cnt * 10 + ch - '0';
    return minus ? -cnt : cnt;
}
#define maxn 200010
#define maxm 4000010
struct sam
{
    sam *fa, *next[10];
    int val;
}mem[maxm], *tot = mem, *pos[maxn];
inline sam *extend(R sam *p, R int c)
{
    //if (p -> next[c] && p -> next[c] -> val == p -> val + 1)
    //  return p -> next[c];
    if(p->next[c])
    {
        R sam *q = p->next[c];
        if (q -> val == p -> val + 1)
            return q;
        else
        {
            R sam *nq = ++tot;
            memcpy(nq -> next, q -> next, sizeof nq -> next);
            nq -> val = p -> val + 1;
            nq -> fa = q -> fa;
            q -> fa = nq;
            for ( ; p && p -> next[c] == q; p = p -> fa)
                p -> next[c] = nq;
            return nq;
        }
    }
    R sam *np = ++tot;
    np -> val = p -> val + 1;
    for ( ; p && !p -> next[c]; p = p -> fa) p -> next[c] = np;
    if (!p)
        np -> fa = mem;
    else
    {
        R sam *q = p -> next[c];
        if (q -> val == p -> val + 1)
            np -> fa = q;
        else
        {
            R sam *nq = ++tot;
            memcpy(nq -> next, q -> next, sizeof nq -> next);
            nq -> val = p -> val + 1;
            nq -> fa = q -> fa;
            q -> fa = np -> fa = nq;
            for ( ; p && p -> next[c] == q; p = p -> fa)
                p -> next[c] = nq;
        }
    }
    return np;
}
struct Edge
{
    Edge *next;
    int to;
}*last[maxn], e[maxm], *ecnt = e;
int deg[maxn], col[maxn];
inline void link(R int a, R int b)
{
    ++deg[a]; ++deg[b];
    *++ecnt = (Edge) {last[a], b}; last[a] = ecnt;
    *++ecnt = (Edge) {last[b], a}; last[b] = ecnt;
}
void dfs(R int x, R int fa)
{
    pos[x] = extend(pos[fa], col[x]);
    for (R Edge *iter = last[x]; iter; iter = iter -> next)
        if (iter -> to != fa)
            dfs(iter -> to, x);
}
int main()
{
//  setfile();
    R int n = FastIn(), c = FastIn();
    for (R int i = 1; i <= n; ++i)
        col[i] = FastIn();
    for (R int i = 1; i < n; ++i)
        link(FastIn(), FastIn());
    pos[0] = mem;
    for (R int i = 1; i <= n; ++i)
        if (deg[i] == 1)
            dfs(i, 0);
    R long long ans = 0;
    for (R sam *iter = mem + 1; iter <= tot; ++iter)
        ans += iter -> val - iter -> fa -> val;
    printf("%lld\n", ans );
    return 0;
}
/*
7 3
0 2 1 2 1 0 0 
1 2
3 4
3 5
4 6
5 7
2 5
*/
posted @ 2016-06-17 09:46  cot  阅读(141)  评论(0编辑  收藏  举报