[luogu4886] 快递员(点分治,树链剖分,lca)
dwq推的火题啊.
这题应该不算是点分治,但是用的点分治的思想.
每次找重心,算出每一对询问的答案找到答案最大值,考虑移动答案点,使得最大值减小.
由于这些点一定不能在u的两颗不同的子树里,否则你怎么移动都不会使得答案更优.
于是答案点就只会往一棵子树里移动.
移动答案点的时候用找重心来跳保证时间复杂度.
求lca可以用树剖.
注释很详细.
// luogu-judger-enable-o2
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#define maxn 100005
#define inf 1e9
using namespace std;
int n,m,rmx,size,ans,root,p;
int fa[maxn],siz[maxn],sz[maxn],top[maxn],dfn[maxn],dis[maxn];
int son[maxn],a[maxn],b[maxn],vis[maxn],dep[maxn],st[maxn];
int head[maxn],nxt[maxn<<1],to[maxn<<1],w[maxn<<1],cnt;
void add(int u,int v,int ww)
{
nxt[++cnt]=head[u];head[u]=cnt;
to[cnt]=v;w[cnt]=ww;
}
void getroot(int u,int ff)
{
sz[u]=1;int mx=0;
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(vis[v]||v==ff)continue;
getroot(v,u);sz[u]+=sz[v];
mx=max(mx,sz[v]);
}
mx=max(size-sz[u],mx);
if(mx<rmx)root=u,rmx=mx;
}
void dfs1(int u,int ff)
{
son[u]=0;fa[u]=ff;siz[u]=1;dep[u]=dep[ff]+1;
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];if(v==ff)continue;
dis[v]=dis[u]+w[i];
dfs1(v,u);siz[u]+=siz[v];
if(siz[son[u]]<siz[v])son[u]=v;
}
}
void dfs2(int u,int tf)
{
dfn[u]=++p;top[u]=tf;
if(!son[u])return;dfs2(son[u],tf);
for(int i=head[u];i;i=nxt[i])
if(to[i]!=fa[u]&&son[u]!=to[i])dfs2(to[i],to[i]);
}
int getlca(int x,int y)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
x=fa[top[x]];
}
return (dep[x]<dep[y])?x:y;
}
void solve(int u)
{
if(vis[u])return;vis[u]=1;
int mx=0,k,tot=0;
fa[u]=p=dis[u]=0;dfs1(u,0);dfs2(u,u);//树链剖分,记得初始化树剖的变量
for(int i=1;i<=m;i++)mx=max(mx,dis[a[i]]+dis[b[i]]);//找答案最大值
ans=min(ans,mx);//由于往下一个重心跳答案不一定最优,所以一路都要对答案取min
for(int i=1;i<=m;i++)//多个最大值存到数组中
if(mx==dis[a[i]]+dis[b[i]])
{
//cout<<a[i]<<" "<<b[i]<<" "<<getlca(a[i],b[i])<<endl;
if(getlca(a[i],b[i])==u)return;
//如果询问点对的两个点在不同子树,那么无论往哪个方向移动答案点,答案都不会更优
st[++tot]=dfn[a[i]];//a,b一定在一个子树内,只加入一个即可
}
for(int i=head[u],v;i;i=nxt[i])//找到第一对点询问所在的子树
if(st[1]>=dfn[v=to[i]]&&st[1]<dfn[v]+siz[v])k=v;
for(int i=2;i<=tot;i++)//如果不在同一子树,当前点就是答案,因为你不能往任意一棵子树靠近
if(st[i]<dfn[k]||st[i]>dfn[k]+siz[k]-1)return;
size=sz[k];rmx=sz[k]+1;getroot(k,0);solve(root);
}
int main()
{
cin>>n>>m;ans=inf;
for(int i=1,u,v,ww;i<n;i++)
scanf("%d%d%d",&u,&v,&ww),add(u,v,ww),add(v,u,ww);
for(int i=1;i<=m;i++)scanf("%d%d",&a[i],&b[i]);
size=n;rmx=n+1;getroot(1,0);solve(root);
printf("%d\n",ans);
return 0;
}