Educational Codeforces Round 162 (Rated for Div. 2) E
E:Link
枚举路径两端的颜色 \(k\)。
令 \(g[x]\) 表示满足以下条件的点 \(y\) 数量。
- $ y \in subtree[x]$
- \(col[y] = k\)
- \(y\) 到 \(fa[x]\) 的路径中不存在其他颜色为 \(k\) 的点。
那么
\[\begin{equation}
g[x]=\left\{
\begin{aligned}
1 \quad col[x] = k\\
\sum g[y] \quad col[x] \ne k ,\quad y \in son[x]
\end{aligned}
\right
.
\end{equation}
\]
统计最高点为 \(x\) 的合法路径数。
如果 \(col[x] = k\),\(x\) 只能作为端点之一,贡献等于 \(\sum g[y]\)。
否则 \(x\) 作为路径中转点,对于任意两棵子树 \(y\) 和 \(z\),贡献为 \(g[y] \times g[z]\)。
虚树优化即可。
#include<bits/stdc++.h>
#define rep(i, a, b) for(int i = (a); i <= (b); ++ i)
#define per(i, a, b) for(int i = (a); i >= (b); -- i)
#define pb emplace_back
#define All(X) X.begin(), X.end()
using namespace std;
using ll = long long;
constexpr int N = 2e5 + 5;
vector<int> G[N];
int fa[N][19], dep[N], dfn[N], timestamp;
void dfs(int x) {
dfn[x] = ++ timestamp;
for(auto y : G[x]) {
if(y != fa[x][0]) {
dep[y] = dep[x] + 1;
fa[y][0] = x;
rep(i, 1, 18) fa[y][i] = fa[fa[y][i - 1]][i - 1];
dfs(y);
}
}
}
int lca(int x, int y) {
if(dep[x] < dep[y]) swap(x, y);
per(i, 18, 0) if(dep[fa[x][i]] >= dep[y]) x = fa[x][i];
if(x == y) return x;
per(i, 18, 0) if(fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
return fa[x][0];
}
int n, c[N];
vector<int> col[N], H[N];
set<int> se;
ll ans, g[N];
void dp(int x, int k) {
g[x] = 0;
for(int y : H[x]) {
dp(y, k);
}
if(c[x] == k) {
g[x] = 1;
for(int y : H[x]) {
ans += g[y];
}
}
else {
ll t = 0;
for(int y : H[x]) {
ans += g[y] * t;
t += g[y];
g[x] += g[y];
}
}
}
bool cmp(int x, int y) {
return dfn[x] < dfn[y];
}
void calc(vector<int> &a, int k) {
a.pb(1);
sort(All(a), cmp);
int sz = a.size();
rep(i, 1, sz - 1) a.pb(lca(a[i - 1], a[i]));
sort(All(a), cmp);
a.erase(unique(All(a)), a.end());
rep(i, 1, a.size() - 1) H[lca(a[i - 1], a[i])].pb(a[i]);
dp(1, k);
rep(i, 1, a.size() - 1) H[lca(a[i - 1], a[i])].clear();
}
void solve() {
cin >> n;
rep(i, 1, n) cin >> c[i], col[c[i]].pb(i), se.insert(c[i]);
rep(i, 2, n) {
int x, y; cin >> x >> y;
G[x].pb(y);
G[y].pb(x);
}
dfs(dep[1] = 1);
for(int k : se) {
calc(col[k], k);
}
cout << ans << '\n';
rep(i, 1, n) G[i].clear(), col[i].clear();
se.clear();
ans = 0;
timestamp = 0;
}
int main() {
ios::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr);
int T = 1;
cin >> T;
while(T --) solve();
return 0;
}