最远点对 [线段树+树的直径]
最远点对(线段树+树的直径)
题目
\(n\) 个点被 \(n-1\) 条边连接成了一颗树,给出 \([a,b]\) 和 \([c,d]\) 两个区间,表示点的标号请你求出两个区间内各选一点之间的最大距离,即你需要求出\(max\{dis(i,j)\ |\ a\leqslant i\leqslant b,c\leqslant j\leqslant d\}\)
(\(PS\): 建议使用读入优化)
输出格式
第一行一个数字 \(n\leqslant 100000\)。
第二行到第 \(n\) 行每行三个数字描述路的情况, \(x,y,z\) \((1\leqslant x,y\leqslant n,1\leqslant z\leqslant 10000)\)表示\(x\)和\(y\)之间有一条长度为\(z\)的路。
第\(n+1\)行一个数字\(m\),表示询问次数 \(m\leqslant 100000\)。
接下来\(m\)行,每行四个数\(a,b,c,d\)。
输出格式
共 \(m\) 行,表示每次询问的最远距离
样例
样例输入
5
1 2 1
2 3 2
1 4 3
4 5 4
1
2 3 4 5
样例输出
10
数据范围与提示
对于\(20\%\)的数据,保证\(n,m\leqslant 300\)。
对于另外\(20\%\)的数据,保证\(b-a,d-c\leqslant 100,m\leqslant 200\)。
对于另外\(20\%\)的数据,保证给出的树为一条链。
分析
一道超级好(恶心) 的 码农(思维)题。
题意就是给你一棵树,给你 \(T\) 对区间,找到两两区间中的两个点之间的最远距离。
暴力的话就是枚举两个区间中的点,然后找到最大的距离 \(dis\) ,这样你就会获得 \(TLE\ 0\) 分的好成绩。我们考虑一下如何维护所有区间中的最大距离,从而在每一次询问的时候可以直接查询到两个区间中最长的一段的两个端点和最长的端点之间的距离。
因为这里询问的时候给的是区间的形式,那么能够想到的就是用线段树来维护区间中的最值。而这里只需要找到最大值的两个端点即可,并不需要记录最值,那么我们的 \(push\ up\) 函数就需要一些改动。
在建树的时候,我们按照节点编号来进行建树,在到达叶子节点的时候,让当前的最大值的端点都是 \(l\) ,那么在 \(push\ up\) 的时候,我们就可以根据叶子节点之间的距离大小来进行对父亲节点的更新。
每个父亲节点的字节点都是存储着最大距离的左端点和右端点的,那么这里是有 \(6\) 种组合方式,我们根据这 \(6\) 种之中的最大距离的两个点来更新父亲的最大距离的左端点和右端点。进行六次比较即可,求距离可以用 \(LCA\) 来求解。
代码
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int maxn = 5e5+10;
int top[maxn],siz[maxn],son[maxn],fa[maxn];
int dis[maxn],dep[maxn];
struct Node{//L和R存储最大距离的u左端点和右端点。
int l,r,L,R;
}t[maxn];
struct N{//建边结构体
int v,next,val;
}e[maxn];
int head[maxn],tot;
inline int read(){//快读
int s = 0,f = 1;
char ch = getchar();
while(ch < '0' || ch > '9'){
if(ch == '-')f = -1;
ch = getchar();
}
while(ch >= '0' && ch <= '9'){
s = s * 10 + ch - '0';
ch = getchar();
}
return s * f;
}
inline void Add(int x,int y,int z){//建边
e[++tot].v = y;
e[tot].val = z;
e[tot].next = head[x];
head[x] = tot;
}
inline void dfs1(int x,int f){//树剖第一遍dfs
siz[x] = 1;
dep[x] = dep[fa[x]] + 1;
for(register int i=head[x];i;i=e[i].next){
int v = e[i].v;
if(v == f)continue;
fa[v] = x;
dis[v] = dis[x] + e[i].val;//求出每个节点距离根的距离
dfs1(v,x);
siz[x] += siz[v];
if(!son[x] || siz[son[x]] < siz[v]){
son[x] = v;
}
}
}
inline void dfs2(int x,int topf){//第二遍dfs
top[x] = topf;
if(son[x])dfs2(son[x],topf);
for(register int i=head[x];i;i=e[i].next){
int v = e[i].v;
if(v != fa[x] && v != son[x])dfs2(v,v);
}
}
inline int lca(int x,int y){//求lca
while(top[x] != top[y]){
if(dep[top[x]] < dep[top[y]])swap(x,y);
x = fa[top[x]];
}
return dep[x] < dep[y] ? x : y;
}
inline int getdis(int x,int y){//求两点之间的距离
int pre = lca(x,y);
return dis[x] + dis[y] - 2 * dis[pre];
}
inline void Merge(int a1,int a2,int b1,int b2,int &c1,int &c2){//将儿子节点的四个端点更新为父亲节点的最大距离的两个端点。
int ans = 0;
//下边记录的是六种端点情况
int a = getdis(a1,a2);
int b = getdis(a1,b1);
int c = getdis(a2,b1);
int d = getdis(a2,b2);
int e = getdis(a1,b2);
int f = getdis(b1,b2);
//以下依次找到最大距离的两个端点,通过取地址直接赋值
if(a > ans){ans = a;c1 = a1;c2 = a2;}
if(b > ans){ans = b;c1 = a1;c2 = b1;}
if(c > ans){ans = c;c1 = a2;c2 = b1;}
if(d > ans){ans = d;c1 = a2;c2 = b2;}
if(e > ans){ans = e;c1 = a1;c2 = b2;}
if(f > ans){ans = f;c1 = b1;c2 = b2;}
}
void pushup(int rt){//pushup函数更新父亲节点的最大距离左右端点
Merge(t[rt<<1].L,t[rt<<1].R,t[rt<<1|1].L,t[rt<<1|1].R,t[rt].L,t[rt].R);
}
void build(int rt,int l,int r){//建立线段树
t[rt].l = l;
t[rt].r = r;
if(l == r){t[rt].L = l;t[rt].R = l;return;}
int mid = (l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
}
void query(int rt,int l,int r,int &L,int &R){//查询区间中最大距离的左右端点,利用取地址直接赋值
if(t[rt].l >= l && t[rt].r <= r){
Merge(L,R,t[rt].L,t[rt].R,L,R);
return;
}
int mid = t[rt].l + t[rt].r >> 1;
if(l <= mid)query(rt<<1,l,r,L,R);
if(r > mid)query(rt<<1|1,l,r,L,R);
}
int main(){
int n = read(),m;
for(register int i=1;i<n;++i){//建边
int x = read(),y = read(),z = read();
Add(x,y,z);Add(y,x,z);
}
//树剖处理
dfs1(1,0);
dfs2(1,1);
build(1,1,n);//建立线段树
m = read();
while(m--){
int a = read(),b = read(),c = read(),d = read();
int e = a,f = b,g = c,h = d;//复制一下四个区间
query(1,a,b,e,f);//直接赋值为区间最大距离的左右端点
query(1,c,d,g,h);
int ans = 0;
//以下分别找到两个区间之间的最大距离
ans = max(ans,max(getdis(e,g),getdis(e,h)));
ans = max(ans,max(getdis(f,g),getdis(f,h)));
printf("%d\n",ans);//输出答案
}
return 0;
}