听风
题面:不存在的
我们先dfs一次,只保留每个分叉最上面的颜色,并求到根的前缀和,这就是每个点到根的和,记为s1,然后我们求子树里的,这里只保留了最上面的颜色,那么我们求树链的并,做树上前缀和,这就是子树颜色的和,记为s2,s1+s2就是子树和到根颜色的并,因为只保留了最上层的颜色,也就是说对于一个点,某种颜色只会存在在子树中或到根的路径上,不会重复统计,然后是风,我们先dfs一遍计算出访问序列,然后树上差分,如果这两个点是a和b,且a的访问时间早于b,那么我们让a=fa[a][0],因为size[u]-1,所以自然是第一个访问的不算,然后就是按访问顺序加入每个点的贡献,计算答案就行了。
这道题集合了很多树上方法,值得一做
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int N = 1e5 + 5; struct edge { int nxt, to; } e[N << 1]; int n, m, top, dfs_clock, cnt = 1, C; ll ans, sum; int mir[N], head[N], vis[N], in[N], out[N], a[N], fa[N][21], tim[N], val[N], st[N], dep[N]; vector<int> G[N], c[N]; ll s1[N], s2[N], s3[N]; bool cmp1(const int &i, const int &j) { return val[i] < val[j]; } bool cmp2(const int &i, const int &j) { return in[i] < in[j]; } void link(int u, int v) { e[++cnt].nxt = head[u]; head[u] = cnt; e[cnt].to = v; } void dfs(int u, int last) { in[u] = ++dfs_clock; mir[in[u]] = u; val[u] = u; if(!vis[a[u]]) vis[a[u]] = 1, ++s1[u]; else a[u] = 0; for(int i = head[u]; i; i = e[i].nxt) if(e[i].to != last) { G[u].push_back(e[i].to); fa[e[i].to][0] = u; dep[e[i].to] = dep[u] + 1; s1[e[i].to] += s1[u]; dfs(e[i].to, u); val[u] = min(val[u], val[e[i].to]); } sort(G[u].begin(), G[u].end(), cmp1); vis[a[u]] = 0; out[u] = dfs_clock; } int lca(int u, int v) { if(dep[u] < dep[v]) swap(u, v); int d = dep[u] - dep[v]; for(int i = 20; i >= 0; --i) if(d & (1 << i)) u = fa[u][i]; if(u == v) return u; for(int i = 20; i >= 0; --i) if(fa[u][i] != fa[v][i]) { u = fa[u][i]; v = fa[v][i]; } return fa[u][0]; } void dfs(int u) { if(a[u]) c[a[u]].push_back(u); for(int i = 0; i < G[u].size(); ++i) { int v = G[u][i]; dfs(v); } st[++top] = u; tim[u] = top; } int main() { scanf("%d%d%d", &n, &m, &C); for(int i = 1; i <= n; ++i) scanf("%d", &a[i]); for(int i = 1; i < n; ++i) { int u, v; scanf("%d%d", &u, &v); link(u, v); link(v, u); } dfs(1, 0); for(int j = 1; j <= 20; ++j) for(int i = 1; i <= n; ++i) fa[i][j] = fa[fa[i][j - 1]][j - 1]; dfs(1); for(int i = 1; i <= C; ++i) if(c[i].size()) { sort(c[i].begin(), c[i].end(), cmp2); ++s2[in[c[i][0]]]; for(int j = 1; j < c[i].size(); ++j) { int x = lca(c[i][j - 1], c[i][j]); --s2[in[x]]; ++s2[in[c[i][j]]]; } } for(int i = 1; i <= m; ++i) { int x, y, z, t; scanf("%d%d%d", &x, &y, &z); t = lca(x, y); if(tim[x] > tim[y]) swap(x, y); x = fa[x][0]; s3[in[x]] += z; s3[in[y]] += z; s3[in[t]] -= z; s3[in[fa[t][0]]] -= z; } for(int i = 1; i <= n; ++i) s2[i] += s2[i - 1], s3[i] += s3[i - 1]; for(int i = 1; i <= top; ++i) { // printf("st[%d] = %d\n", i, st[i]); ll d1 = s2[out[st[i]]] - s2[in[st[i]] - 1], d2 = s3[out[st[i]]] - s3[in[st[i]] - 1]; sum += s1[fa[st[i]][0]] + s2[out[st[i]]] - s2[in[st[i]] - 1] + s3[out[st[i]]] - s3[in[st[i]] - 1] - dep[st[i]]; ans = max(ans, sum); } printf("%lld\n", ans); return 0; }