Educational Codeforces Round 162 (Rated for Div. 2) E

E:Link
枚举路径两端的颜色 \(k\)

\(g[x]\) 表示满足以下条件的点 \(y\) 数量。

  • $ y \in subtree[x]$
  • \(col[y] = k\)
  • \(y\)\(fa[x]\) 的路径中不存在其他颜色为 \(k\) 的点。

那么

\[\begin{equation} g[x]=\left\{ \begin{aligned} 1 \quad col[x] = k\\ \sum g[y] \quad col[x] \ne k ,\quad y \in son[x] \end{aligned} \right . \end{equation} \]

统计最高点为 \(x\) 的合法路径数。
如果 \(col[x] = k\)\(x\) 只能作为端点之一,贡献等于 \(\sum g[y]\)
否则 \(x\) 作为路径中转点,对于任意两棵子树 \(y\)\(z\),贡献为 \(g[y] \times g[z]\)

虚树优化即可。

#include<bits/stdc++.h>
#define rep(i, a, b) for(int i = (a); i <= (b); ++ i)
#define per(i, a, b) for(int i = (a); i >= (b); -- i)
#define pb emplace_back
#define All(X) X.begin(), X.end()
using namespace std;
using ll = long long;

constexpr int N = 2e5 + 5;

vector<int> G[N];
int fa[N][19], dep[N], dfn[N], timestamp;

void dfs(int x) {
	dfn[x] = ++ timestamp;
	for(auto y : G[x]) {
		if(y != fa[x][0]) {
			dep[y] = dep[x] + 1;
			fa[y][0] = x;
			rep(i, 1, 18) fa[y][i] = fa[fa[y][i - 1]][i - 1];
			dfs(y);
		}
	}
}

int lca(int x, int y) {
	if(dep[x] < dep[y]) swap(x, y);
	per(i, 18, 0) if(dep[fa[x][i]] >= dep[y]) x = fa[x][i];
	if(x == y) return x;
	per(i, 18, 0) if(fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
	return fa[x][0];
}

int n, c[N];
vector<int> col[N], H[N];
set<int> se;

ll ans, g[N];

void dp(int x, int k) {
	g[x] = 0;
	for(int y : H[x]) {
		dp(y, k);
	}
	if(c[x] == k) {
		g[x] = 1;
		for(int y : H[x]) {
			ans += g[y];
		}
	}
	else {
		ll t = 0;
		for(int y : H[x]) {
			ans += g[y] * t;
			t += g[y];
			g[x] += g[y];
		}
	}
}

bool cmp(int x, int y) {
	return dfn[x] < dfn[y];
}

void calc(vector<int> &a, int k) {
	a.pb(1);
	sort(All(a), cmp);
	int sz = a.size();
	rep(i, 1, sz - 1) a.pb(lca(a[i - 1], a[i]));
	sort(All(a), cmp);
	a.erase(unique(All(a)), a.end());
	rep(i, 1, a.size() - 1) H[lca(a[i - 1], a[i])].pb(a[i]);
	dp(1, k);
	rep(i, 1, a.size() - 1) H[lca(a[i - 1], a[i])].clear();
}

void solve() {
	cin >> n;
	rep(i, 1, n) cin >> c[i], col[c[i]].pb(i), se.insert(c[i]);
	rep(i, 2, n) {
		int x, y; cin >> x >> y;
		G[x].pb(y);
		G[y].pb(x);
	}
	dfs(dep[1] = 1);
	for(int k : se) {
		calc(col[k], k);
	}
	cout << ans << '\n';
	rep(i, 1, n) G[i].clear(), col[i].clear();
	se.clear();
	ans = 0;
	timestamp = 0;
}

int main() {
	ios::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr);
	int T = 1;
	cin >> T;
	while(T --) solve();
	return 0;
}
posted @ 2024-02-24 03:04  Lu_xZ  阅读(132)  评论(0编辑  收藏  举报