P10060 [SNOI2024] 树 V 图 题解
手玩样例可以发现,把树划分为若干极大同 $f$ 连通块,
若存在 $i$ 满足极大 $f=i$ 连通块不唯一,则无解。证明比较平凡。
另外,若存在 $i$ 满足没有 $f=i$ 的点,也无解。证明比较平凡。
把每个唯一的极大 $f=i$ 连通块缩起来,建出一棵新树 $T$,
设 $f_{i,j}$ 表示钦定 $a_i=j$,然后对 $T$ 上 $i$ 子树中每个 $f=k$ 连通块选定 $a_k$ 的方案数,
可以发现这里 $j$ 一定在 $f=i$ 连通块中,证明比较平凡。
考虑转移,对于符合 $f_{i,j}$ 要求的方案,$i$ 的每个孩子的子树的方案是独立的,
于是 $f_{i,j}$ 即为 $i$ 的每个孩子的子树符合要求的方案数之积,即
$$ f_{i,j}=\prod_{v\in\text{son}(i)}\sum f_{v,k} $$
其中要求 $a_i=j,a_v=k$ 合法,考虑如何判断这个条件,
发现只需判断 $f=i,f=v$ 连通块交界处的两个点的 $f$ 值在这种情况下是否正确,
预处理相邻连通块交界处的点、原树上任意两点间的距离即可。
#include <vector>
#include <cstdio>
#include <cstring>
#define M 998244353
#define int long long
using namespace std;
struct T
{
struct E
{
int v, t;
} e[6050];
int c, h[3050];
void A(int u, int v) { e[++c] = {v, h[u]}, h[u] = c; }
} X, Y;
vector<int> V[3050];
bool b[3050];
int T, n, k, u[3050], v[3050], a[3050], f[3050][3050], d[3050][3050], p[3050][3050];
void D1(int u, int k, int S)
{
for (int i = X.h[u], v; i; i = X.e[i].t)
if ((v = X.e[i].v) != k)
d[S][v] = d[S][u] + 1, D1(v, u, S);
}
void D2(int u, int k, int S)
{
b[u] = 1;
V[S].push_back(u);
for (int i = X.h[u], v; i; i = X.e[i].t)
if ((v = X.e[i].v) != k && a[v] == S)
D2(v, u, S);
}
bool C(int x, int y, int k) { return d[x][k] == d[y][k] ? a[x] < a[y] : d[x][k] < d[y][k]; }
void Q(int u, int k)
{
for (auto i : V[u])
f[u][i] = 1;
for (int i = Y.h[u], v; i; i = Y.e[i].t)
if ((v = Y.e[i].v) != k)
{
Q(v, u);
for (auto i : V[u])
{
if (!f[u][i])
continue;
int q = 0;
for (auto j : V[v])
if (C(i, j, p[u][v]) && C(j, i, p[v][u]))
q = (q + f[v][j]) % M;
f[u][i] = f[u][i] * q % M;
}
}
}
signed main()
{
scanf("%lld", &T);
while (T--)
{
for (int i = 1; i <= 3000; ++i)
V[i].clear();
X.c = Y.c = 0;
memset(X.e, 0, sizeof X.e);
memset(Y.e, 0, sizeof Y.e);
memset(X.h, 0, sizeof X.h);
memset(Y.h, 0, sizeof Y.h);
memset(b, 0, sizeof b);
memset(u, 0, sizeof u);
memset(v, 0, sizeof v);
memset(a, 0, sizeof a);
memset(f, 0, sizeof f);
scanf("%lld%lld", &n, &k);
for (int i = 1; i < n; ++i)
scanf("%lld%lld", u + i, v + i), X.A(u[i], v[i]), X.A(v[i], u[i]);
for (int i = 1; i <= n; ++i)
scanf("%lld", a + i);
bool F = 1;
for (int u = 1; u <= k; ++u)
{
bool f = 0;
for (int i = 1; i <= n; ++i)
if (a[i] == u)
{
D2(i, 0, u);
f = 1;
break;
}
if (!f)
{
puts("0");
F = 0;
break;
}
}
if (!F)
continue;
for (int i = 1; i <= n; ++i)
if (!b[i])
{
puts("0");
F = 0;
break;
}
if (!F)
continue;
for (int i = 1; i <= n; ++i)
d[i][i] = 0, D1(i, 0, i);
for (int i = 1; i < n; ++i)
if (a[u[i]] != a[v[i]])
{
Y.A(a[u[i]], a[v[i]]);
Y.A(a[v[i]], a[u[i]]);
p[a[u[i]]][a[v[i]]] = u[i];
p[a[v[i]]][a[u[i]]] = v[i];
}
Q(1, 0);
int q = 0;
for (auto i : V[1])
q = (q + f[1][i]) % M;
printf("%lld\n", q);
}
return 0;
}