【LOJ】#2137. 「ZJOI2015」诸神眷顾的幻想乡
我居然到了国赛之前才学习怎么做广义后缀自动机
这个题目……意思是……有20个叶子,肯定一条路径都是任意一个叶子为根,一个从某个点往祖先走的路径
这样的话我们可以按照dfs序,从每个节点的父亲那里的后缀自动机节点再加一个节点
这样只要对于每个后缀自动机的节点统计一下节点长度减去父亲节点长度就好了
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
#include <random>
#include <ctime>
#define fi first
#define se second
#define pii pair<int,int>
#define mp make_pair
#define pb push_back
#define space putchar(' ')
#define enter putchar('\n')
#define MAXN 100005
//#define ivorysi
using namespace std;
typedef long long int64;
typedef double db;
template<class T>
void read(T &res) {
res = 0;T f = 1;char c = getchar();
while(c < '0' || c > '9') {
if(c == '-') f = -1;
c = getchar();
}
while(c >= '0' && c <= '9') {
res = res * 10 + c - '0';
c = getchar();
}
res *= f;
}
template<class T>
void out(T x) {
if(x < 0) {x = -x;putchar('-');}
if(x >= 10) out(x / 10);
putchar('0' + x % 10);
}
struct node {
int to,next;
}E[MAXN * 2];
int N,C,col[MAXN],sumE,head[MAXN],deg[MAXN];
void add(int u,int v) {
E[++sumE].to = v;
E[sumE].next = head[u];
head[u] = sumE;
}
struct sam {
struct node {
int len,par,nxt[10];
}tr[MAXN * 40];
int tail,root,last;
void Init() {
root = last = ++tail;
}
int expend(int x,int c) {
if(tr[x].nxt[c] && tr[tr[x].nxt[c]].len == tr[x].len + 1) return tr[x].nxt[c];
int nw = ++tail,p;
tr[nw].len = tr[x].len + 1;
for(p = x ; p && !tr[p].nxt[c] ; p = tr[p].par) {
tr[p].nxt[c] = nw;
}
if(!p) tr[nw].par = root;
else {
int q = tr[p].nxt[c];
if(tr[q].len == tr[p].len + 1) tr[nw].par = q;
else {
int cq = ++tail;
tr[cq] = tr[q];
tr[cq].len = tr[p].len + 1;
tr[q].par = tr[nw].par = cq;
for(; p && tr[p].nxt[c] == q ; p = tr[p].par) {
tr[p].nxt[c] = cq;
}
}
}
return nw;
}
int64 Calc() {
int64 res = 0;
for(int i = 1 ; i <= tail ; ++i) {
res += tr[i].len - tr[tr[i].par].len;
}
return res;
}
}SAM;
void dfs(int u,int fa,int p) {
p = SAM.expend(p,col[u]);
for(int i = head[u] ; i ; i = E[i].next) {
int v = E[i].to;
if(v != fa) {
dfs(v,u,p);
}
}
}
void Solve() {
read(N);read(C);
for(int i = 1 ; i <= N ; ++i) read(col[i]);
int u,v;
for(int i = 1 ; i < N ; ++i) {
read(u);read(v);
add(u,v);add(v,u);
++deg[v];++deg[u];
}
SAM.Init();
for(int i = 1 ; i <= N ; ++i) {
if(deg[i] == 1) dfs(i,0,SAM.root);
}
out(SAM.Calc());enter;
}
int main() {
#ifdef ivorysi
freopen("f3.in","r",stdin);
#endif
Solve();
}