真的是好题啊~不过在说做法之前先强调几个自己总是掉的坑点。
-
更新节点永远记不住往上跳\(p = fa[p]\)
-
新建节点永远记不住\(len[y] = len[p] + 1\)
-
总之就是各种各样的智障错误啦!
来说这个题目。这个题目厉害就厉害在,你要想到怎么用只有\(20\)个叶子节点这个信息。对于任意一条树上的路径,它一定对应于以某个叶节点为根的树中,一条根节点发出的链上的后缀。既然叶节点只有\(20\)个,我们要做的就是换\(20\)次根就好了。由于是要搞根节点发出链上的后缀,所以我们事实上可以把这颗树当做一个\(Trie\)来使用。这样的话,每次把整颗\(Trie\)插入\(SAM\)中(新学的骚操作\(XD\)),最终求出所有节点的\(len[i] - len[fa[i]]\),就是不同子串个数了。
记得判重。把整颗\(Trie\)插进去的操作真的是学到了。
#include <bits/stdc++.h>
using namespace std;
const int N = 4000010;
struct Suffix_Auto {
int fa[N], len[N], ch[N][10];
int rt, las, cnt;
void Init () {
rt = las = cnt = 1;
memset (ch, 0, sizeof (ch));
memset (fa, 0, sizeof (fa));
memset (len, 0, sizeof (len));
}
int extend (int p, int c) {
if (ch[p][c] != 0) {
//已有节点
int x = ch[p][c];
if (len[x] == len[p] + 1) {
las = x;
} else {
int y = ++cnt; las = y;
len[y] = len[p] + 1;
fa[y] = fa[x];
fa[x] = y; //注意顺序
memcpy (ch[y], ch[x], sizeof (ch[x]));
while (p != 0 && ch[p][c] == x) {
ch[p][c] = y;
p = fa[p];
}
}
} else {
//单纯的插入
int q = ++cnt; las = q;
len[q] = len[p] + 1;
while (p != 0 && ch[p][c] == 0) {
ch[p][c] = q;
p = fa[p]; // 不要忘了这句啊
}
if (p == 0) {
fa[q] = rt;
} else {
int x = ch[p][c];
if (len[x] == len[p] + 1) {
fa[q] = x;
} else {
int y = ++cnt;
fa[y] = fa[x];
fa[x] = fa[q] = y;
len[y] = len[p] + 1;
memcpy (ch[y], ch[x], sizeof (ch[x]));
while (p != 0 && ch[p][c] == x) {
ch[p][c] = y;
p = fa[p];
}
}
}
}
return las;
}
long long get_ans () {
long long ans = 0;
for (int i = 1; i <= cnt; ++i) {
ans += len[i] - len[fa[i]];
}
return ans;
}
}sam;
int n, m, u, v, cnt, d[N], col[N], head[N];
struct edge {
int nxt, to;
}e[N << 1];
void add_edge (int from, int to){
e[++cnt].nxt = head[from];
e[cnt].to = to;
head[from] = cnt;
}
void get_tree (int u, int fa, int p) {
p = sam.extend (p, col[u]); //在p后面接上节点col[u]
for (int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if (v != fa) {
get_tree (v, u, p);
}
}
}
void add_len (int u, int v) {
add_edge (u, v);
add_edge (v, u);
}
int main () {
cin >> n >> m;
sam.Init ();
for (int i = 1; i <= n; ++i) {
cin >> col[i];
}
for (int i = 1; i <= n - 1; ++i) {
cin >> u >> v;
add_len (u, v);
++d[u], ++d[v];
}
for (int i = 1; i <= n; ++i) {
if (d[i] == 1) {
get_tree (i, 0, sam.rt);
//以i为树根跑一遍
}
}
cout << sam.get_ans () << endl;
}