BZOJ2870 最长道路tree
BZOJ2870 最长道路tree
题面:又是权限题。。。
解析
这种含多个参数(其中有有关最值的参数)的最优值求解都是套路啊。大概就是通过排序来消除一个参数的影响,然后维护答案。那么这道题就是按权值从大到小排序依次加入,然后当前权值乘上当前最长链即可。如何维护当前最长链呢?发现当两棵树合并后,新树直径的两个端点一定是原来两个直径的四个端点中的两个,然后就可以用并查集维护了。证明用反证法即可。
代码
#include<cmath>
#include<cstdio>
#include<iostream>
#include<algorithm>
#define LL long long
#define N 50005
using namespace std;
const int nlog=16;
inline int In(){
char c=getchar(); int x=0,ft=1;
for(;c<'0'||c>'9';c=getchar()) if(c=='-') ft=-1;
for(;c>='0'&&c<='9';c=getchar()) x=x*10+c-'0';
return x*ft;
}
int n,val[N],id[N],t[10];
bool vis[N];
inline bool cmp(int a,int b){
return val[a]>val[b];
}
int h[N],e_tot=0;
struct E{ int to,nex; }e[N<<1];
inline void add(int u,int v){
e[++e_tot]=(E){v,h[u]}; h[u]=e_tot;
e[++e_tot]=(E){u,h[v]}; h[v]=e_tot;
}
int d[N],p[N*2][nlog+5],c[N],dfs_clock=0;
void dfs(int u,int pre,int dep){
d[u]=dep; p[++dfs_clock][0]=u; c[u]=dfs_clock;
for(int i=h[u],v;i;i=e[i].nex){
v=e[i].to; if(v==pre) continue;
dfs(v,u,dep+1); p[++dfs_clock][0]=u;
}
}
int po[nlog+5];
inline void get_st(){
po[0]=1; for(int i=1;i<=nlog;++i) po[i]=po[i-1]*2;
for(int j=1;j<=nlog;++j)
for(int i=1;i+po[j]<=dfs_clock;++i)
p[i][j]=d[p[i][j-1]]<d[p[i+po[j-1]][j-1]]
?p[i][j-1]:p[i+po[j-1]][j-1];
}
inline int LCA(int x,int y){
if(c[x]>c[y]) swap(x,y);
int k=log(c[y]-c[x]+1)/log(2);
return d[p[c[x]][k]]<d[p[c[y]-po[k]+1][k]]
?p[c[x]][k]:p[c[y]-po[k]+1][k];
}
inline int dist(int x,int y){
return d[x]+d[y]-2*d[LCA(x,y)]+1;
}
int mx_d; LL ans=0;
int f[N],usz[N],_p[N][2],_w[N];
inline int find(int x){
return f[x]=(f[x]==x?x:find(f[x]));
}
inline void unit(int x,int y){
int a=find(x),b=find(y);
if(usz[a]<usz[b]) swap(a,b);
usz[a]+=usz[b]; f[b]=a;
t[1]=_p[a][0]; t[2]=_p[a][1];
t[3]=_p[b][0]; t[4]=_p[b][1];
for(int i=1;i<4;++i)
for(int j=i+1;j<=4;++j){
int dis=dist(t[i],t[j]);
if(dis>_w[a]){
_w[a]=dis;
_p[a][0]=t[i];
_p[a][1]=t[j];
}
}
mx_d=max(mx_d,_w[a]);
}
int main(){
n=In(); mx_d=1;
for(int i=1;i<=n;++i){
val[i]=In(); id[i]=i;
f[i]=i; usz[i]=1;
_p[i][0]=_p[i][1]=i; _w[i]=1;
ans=max(ans,(LL)val[i]);
}
sort(id+1,id+1+n,cmp);
for(int i=1;i<n;++i) add(In(),In());
dfs(1,0,1); get_st();
for(int j=1,u;j<=n;++j){
u=id[j]; vis[u]=1;
for(int i=h[u],v;i;i=e[i].nex){
v=e[i].to; if(vis[v]==1) unit(u,v);
}
ans=max(ans,(LL)val[u]*mx_d);
}
printf("%lld\n",ans);
return 0;
}