【bzoj2870】最长道路tree 树的直径+并查集
题目描述
给定一棵N个点的树,求树上一条链使得链的长度乘链上所有点中的最小权值所得的积最大。
其中链长度定义为链上点的个数。
输入
第一行N
第二行N个数分别表示1~N的点权v[i]
接下来N-1行每行两个数x、y,表示一条连接x和y的边
输出
一个数,表示最大的痛苦程度。
样例输入
3
5 3 5
1 2
1 3
样例输出
10
题解
树的直径+并查集
首先肯定是把权值从大到小排序,按照顺序加点,维护每个连通块的最长链乘以当前点权值作为贡献。
那么如何在加上一条边,连接两棵树后快速得出新的直径呢?
一个结论:将两棵树连成一棵,新树的直径的两端点只有可能是原来两棵树两条直径四个端点中的某两个。
证明不太容易表述。。。简单画一画就差不多出来了。实在不行可以先推加一个点的情况,然后再推加一棵树。
于是使用并查集维护树的直径长度及端点位置,使用倍增LCA求距离,就做完了。。。
注意需要开long long。
时间复杂度 $O(n\log n)$
#include <cstdio> #include <algorithm> #define N 50010 using namespace std; typedef long long ll; int v[N] , id[N] , head[N] , to[N << 1] , next[N << 1] , cnt , fa[N][17] , deep[N] , log[N] , f[N] , px[N] , py[N]; ll ans = 0; bool cmp(int a , int b) { return v[a] > v[b]; } inline void add(int x , int y) { to[++cnt] = y , next[cnt] = head[x] , head[x] = cnt; } void dfs(int x) { int i; for(i = 1 ; (1 << i) <= deep[x] ; i ++ ) fa[x][i] = fa[fa[x][i - 1]][i - 1]; for(i = head[x] ; i ; i = next[i]) if(to[i] != fa[x][0]) fa[to[i]][0] = x , deep[to[i]] = deep[x] + 1 , dfs(to[i]); } inline int lca(int x , int y) { int i; if(deep[x] < deep[y]) swap(x , y); for(i = log[deep[x] - deep[y]] ; ~i ; i -- ) if(deep[x] - deep[y] >= (1 << i)) x = fa[x][i]; if(x == y) return x; for(i = log[deep[x]] ; ~i ; i -- ) if(deep[x] >= (1 << i) && fa[x][i] != fa[y][i]) x = fa[x][i] , y = fa[y][i]; return fa[x][0]; } inline int dis(int x , int y) { return deep[x] + deep[y] - (deep[lca(x , y)] << 1); } int find(int x) { return x == f[x] ? x : f[x] = find(f[x]); } void solve(int x) { int i , tx , ty , t , vm , vx , vy; for(i = head[x] ; i ; i = next[i]) { if(f[to[i]]) { tx = find(x) , ty = find(to[i]) , vm = -1; if(vm < (t = dis(px[tx] , py[tx]))) vm = t , vx = px[tx] , vy = py[tx]; if(vm < (t = dis(px[ty] , py[ty]))) vm = t , vx = px[ty] , vy = py[ty]; if(vm < (t = dis(px[tx] , px[ty]))) vm = t , vx = px[tx] , vy = px[ty]; if(vm < (t = dis(px[tx] , py[ty]))) vm = t , vx = px[tx] , vy = py[ty]; if(vm < (t = dis(py[tx] , px[ty]))) vm = t , vx = py[tx] , vy = px[ty]; if(vm < (t = dis(py[tx] , py[ty]))) vm = t , vx = py[tx] , vy = py[ty]; f[ty] = tx , px[tx] = vx , py[tx] = vy; } } tx = find(x) , ans = max(ans , (ll)v[x] * (dis(px[tx] , py[tx]) + 1)); } int main() { int n , i , x , y; scanf("%d" , &n); for(i = 1 ; i <= n ; i ++ ) scanf("%d" , &v[i]) , id[i] = i; for(i = 2 ; i <= n ; i ++ ) scanf("%d%d" , &x , &y) , add(x , y) , add(y , x) , log[i] = log[i >> 1] + 1; dfs(1); sort(id + 1 , id + n + 1 , cmp); for(i = 1 ; i <= n ; i ++ ) f[id[i]] = px[id[i]] = py[id[i]] = id[i] , solve(id[i]); printf("%lld\n" , ans); return 0; }