51nod 1766 树上的最远点对——线段树
n个点被n-1条边连接成了一颗树,给出a~b和c~d两个区间,表示点的标号请你求出两个区间内各选一点之间的最大距离,即你需要求出max{dis(i,j) |a<=i<=b,c<=j<=d}
(PS 建议使用读入优化)
Input
第一行一个数字 n n<=100000。
第二行到第n行每行三个数字描述路的情况, x,y,z (1<=x,y<=n,1<=z<=10000)表示x和y之间有一条长度为z的路。
第n+1行一个数字m,表示询问次数 m<=100000。
接下来m行,每行四个数a,b,c,d。
Output
共m行,表示每次询问的最远距离
Input示例
5
1 2 1
2 3 2
1 4 3
4 5 4
1
2 3 4 5
Output示例
10
————————————————————————————
这道题可以证明两个区间并起来的最远点对 一定是两个区间单独最远点对中的四个点
然后我们就可以利用线段树来维护辣
#include<cstdio> #include<cstring> #include<algorithm> using std::swap; const int M=2e5+7,inf=0x3f3f3f3f; int read(){ int ans=0,f=1,c=getchar(); while(c<'0'||c>'9'){if(c=='-') f=-1; c=getchar();} while(c>='0'&&c<='9'){ans=ans*10+(c-'0'); c=getchar();} return ans*f; } int max(int x,int y){return x>y?x:y;} int n,m; int first[M],cnt; struct node{int to,next,w;}e[2*M]; void ins(int a,int b,int w){e[++cnt]=(node){b,first[a],w}; first[a]=cnt;} void insert(int a,int b,int w){ins(a,b,w); ins(b,a,w);} int sz[M],son[M],dep[M],fa[M],top[M],id[M],idp=1,dis[M]; void f1(int x){ sz[x]=1; for(int i=first[x];i;i=e[i].next){ int now=e[i].to; if(now==fa[x]) continue; fa[now]=x; dep[now]=dep[x]+1; dis[now]=dis[x]+e[i].w; f1(now); sz[x]+=sz[now]; if(sz[now]>sz[son[x]]) son[x]=now; } } void f2(int x,int tp){ top[x]=tp; id[x]=idp++; if(son[x]) f2(son[x],tp); for(int i=first[x];i;i=e[i].next){ int now=e[i].to; if(now!=fa[x]&&now!=son[x]) f2(now,now); } } int cntq; struct pos{int mx,p1,p2;}tr[2*M+1007]; int find(int x,int y){ while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y); x=fa[top[x]]; } if(dep[x]>dep[y]) swap(x,y); return x; } int calc(int x,int y){ int lca=find(x,y); return dis[x]+dis[y]-2*dis[lca]; } void up(int x,int ls,int rs){ int k; tr[x].mx=tr[ls].mx; tr[x].p1=tr[ls].p1; tr[x].p2=tr[ls].p2; if(tr[rs].mx>=tr[x].mx) tr[x].mx=tr[rs].mx,tr[x].p1=tr[rs].p1,tr[x].p2=tr[rs].p2; int x1=tr[ls].p1,y1=tr[ls].p2,x2=tr[rs].p1,y2=tr[rs].p2; if(x1!=-1){ if(x2!=-1&&(k=calc(x1,x2))>=tr[x].mx) tr[x].mx=k,tr[x].p1=x1,tr[x].p2=x2; if(y2!=-1&&(k=calc(x1,y2))>=tr[x].mx) tr[x].mx=k,tr[x].p1=x1,tr[x].p2=y2; } if(y1!=-1){ if(x2!=-1&&(k=calc(y1,x2))>=tr[x].mx) tr[x].mx=k,tr[x].p1=x2,tr[x].p2=y1; if(y2!=-1&&(k=calc(y1,y2))>=tr[x].mx) tr[x].mx=k,tr[x].p1=y1,tr[x].p2=y2; } } void build(int x,int l,int r){ if(l==r){ tr[x].p1=l; tr[x].p2=l; tr[x].mx=0; return ; } int mid=(l+r)>>1; build(x<<1,l,mid); build(x<<1^1,mid+1,r); up(x,x<<1,x<<1^1); } int L,R; int push_ans(int x,int l,int r){ if(L<=l&&r<=R) return x; int mid=(l+r)>>1,ly=++cntq; tr[ly]=(pos){0,-1,-1}; int s1=0,s2=0; if(L<=mid) s1=push_ans(x<<1,l,mid); if(R>mid) s2=push_ans(x<<1^1,mid+1,r); up(ly,s1,s2); return ly; } int ans,ly,a,b,c,d,s1,s2; int main(){ int x,y,w; n=read(); tr[0].p1=tr[0].p2=-1; tr[0].mx=-inf; for(int i=1;i<n;i++) x=read(),y=read(),w=read(),insert(x,y,w); f1(1); f2(1,1); build(1,1,n); m=read(); for(int i=1;i<=m;i++){ a=read(); b=read(); c=read(); d=read(); cntq=2*M; L=a; R=b; int s1=push_ans(1,1,n); L=c; R=d; int s2=push_ans(1,1,n); int x1=tr[s1].p1,y1=tr[s1].p2; //printf("[%d %d]\n",x1,y1); int x2=tr[s2].p1,y2=tr[s2].p2; //printf("[%d %d]\n",x2,y2); ans=max(max(calc(x1,y2),calc(x1,x2)),max(calc(y1,y2),calc(y1,x2))); printf("%d\n",ans); } return 0; }