51nod1766-树上的最远点对【结论,线段树】
正题
题目链接:http://www.51nod.com/Challenge/Problem.html#problemId=1766
题目大意
给出\(n\)个点的一棵树,\(m\)次询问给出两个区间,要求在两个区间中各选一个点使得他们之间距离最大。
\(1\leq n,m\leq 10^5\)
解题思路
结论就是两个区间中选择的点都是在各自区间中距离最远的两个点中的一个。
证明:假设集合中最长的为\(len\),设两端的点为\(x,y\)。如果走到点\(x\)不是最远的那么走到\(y\)肯定是最远的,因为如果存在更远的点在集合内那么显然\(x\)走到那个点是比\(x\)走到\(y\)远的。
考虑如何快速询问区间的直径端点,对于两个相邻的区间\([l,mid],[mid+1,r]\),我们可以用上述的方法合并得到区间\([l,r]\)的答案。
这启示我们可以用线段树,用\(RMQ\)预处理树上距离即可。
时间复杂度:\(O(n\log n)\)
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#define mp(x,y) make_pair(x,y)
#define ll long long
using namespace std;
const ll N=2e5+10;
struct node{
ll to,next,w;
}a[N<<1];
ll n,q,tot,cnt,f[N][19],lg[N];
ll ls[N],dis[N],dep[N],rfn[N];
pair<ll,ll> w[N<<1];
void addl(ll x,ll y,ll w){
a[++tot].to=y;
a[tot].next=ls[x];
a[tot].w=w;ls[x]=tot;
return;
}
void dfs(ll x,ll fa){
f[++cnt][0]=x;rfn[x]=cnt;
dep[x]=dep[fa]+1;
for(ll i=ls[x];i;i=a[i].next){
ll y=a[i].to;
if(y==fa)continue;
dis[y]=dis[x]+a[i].w;
dfs(y,x);f[++cnt][0]=x;
}
return;
}
ll LCA(ll x,ll y){
ll l=rfn[x],r=rfn[y];
if(l>r)swap(l,r);
ll z=lg[r-l+1];
x=f[l][z];y=f[r-(1<<z)+1][z];
return (dep[x]<dep[y])?x:y;
}
ll gdis(ll x,ll y)
{return dis[x]+dis[y]-dis[LCA(x,y)]*2;}
pair<ll,ll> Merge(const pair<ll,ll> &a,const pair<ll,ll> &b){
ll p[4]={a.first,a.second,b.first,b.second};
pair<ll,ll> c=mp(p[0],p[1]);ll w=gdis(p[0],p[1]);
for(ll i=0;i<4;i++)
for(ll j=i+1;j<4;j++){
if(!i&&j==1)continue;
ll k=gdis(p[i],p[j]);
if(k>w)w=k,c=mp(p[i],p[j]);
}
return c;
}
void Build(ll x,ll L,ll R){
if(L==R){w[x]=mp(L,L);return;}
ll mid=(L+R)>>1;
Build(x*2,L,mid);
Build(x*2+1,mid+1,R);
w[x]=Merge(w[x*2],w[x*2+1]);
return;
}
pair<ll,ll> Ask(ll x,ll L,ll R,ll l,ll r){
if(L==l&&R==r)return w[x];
ll mid=(L+R)>>1;
if(r<=mid)return Ask(x*2,L,mid,l,r);
if(l>mid)return Ask(x*2+1,mid+1,R,l,r);
return Merge(Ask(x*2,L,mid,l,mid),Ask(x*2+1,mid+1,R,mid+1,r));
}
signed main()
{
scanf("%lld",&n);
for(ll i=1;i<n;i++){
ll x,y,w;
scanf("%lld%lld%lld",&x,&y,&w);
addl(x,y,w);addl(y,x,w);
}
dfs(1,0);
for(ll j=1;(1<<j)<=cnt;j++)
for(ll i=1;i+(1<<j)-1<=cnt;i++){
ll x=f[i][j-1],y=f[i+(1<<j-1)][j-1];
f[i][j]=(dep[x]<dep[y])?x:y;
}
for(ll i=2;i<=cnt;i++)lg[i]=lg[i>>1]+1;
Build(1,1,n);
scanf("%lld",&q);
while(q--){
ll l1,r1,l2,r2;
scanf("%lld%lld%lld%lld",&l1,&r1,&l2,&r2);
pair<ll,ll> a,b;
a=Ask(1,1,n,l1,r1);
b=Ask(1,1,n,l2,r2);
ll x=a.first,y=a.second;
ll l=b.first,r=b.second;
printf("%lld\n",max(max(gdis(x,l),gdis(x,r)),max(gdis(y,l),gdis(y,r))));
}
return 0;
}