[题解]P9432 [NAPC-#1] rStage5 - Hard Conveyors
P9432 [NAPC-#1] rStage5 - Hard Conveyors
题意简述
给定一个\(N\)个节点的树形结构,其中有\(k\)个关键节点。
接下来有\(q\)次询问,每次询问给定\(x,y\),请输出\(x\)到\(y\)至少经过一个关键点的最短路径。
解题思路
我们发现,这道题相当于让我们从\(x\)到\(y\)的简单路径上,额外扩展出一路程去到达一个关键点。那么从哪里开始,到达哪一个关键点能消耗最少的步数呢?
我们不难想到,可以定义\(mindis[i]\)表示\(i\)到最近关键点的距离。这样答案就是:
先考虑\(mindis\)怎样计算。我们可以巧妙地用Dijkstra来求(只需要在堆优化Dijkstra的基础上稍作修改即可:初始化时,把所有关键点距离设为\(0\),并将它们入队列)。这一步骤时间复杂度\(O(n\log n)\)。
但我们还可以用\(O(n)\)的方法求出。\(mindis\)初始化为\(+\infty\),过程分为\(2\)次DFS。
- 第\(1\)次DFS:
- 如果\(u\)是关键点,那么\(mindis[u]=0\)。
- 给子节点\(i\)赋初值\(mindis[i]=mindis[u]+w(u,i)\)(\(w\)表示边权),并搜索子节点。
- 回溯时用子节点\(i\)的值更新\(u\):\(mindis[u]=\min(mindis[u],mindis[i]+w(u,i))\)。
- 第\(2\)次DFS:
- 第\(1\)次DFS可能会导致一些节点的\(mindis\)没有被更新,此时我们需要用它们父节点最终的\(mindis\)来更新子节点:即\(mindis[i]=\min(mindis[i],mindis[u]+w(u,i))\)。
- 搜索子节点,重复上述步骤。
\(mindis\)求出来了,但如果我们用\(O(n)\)的复杂度去遍历\(x\sim y\)路径上的所有节点,就会超时。
我们可以用倍增的思想来优化这一过程。用\(minn[i]\)来表示每个点往上\(2^i\)层的\(mindis\)。因为待会求\(dist(x,y)\)需要用LCA,所以\(minn\)就在预处理LCA的时候一块计算出来。
求\(dist(x,y)\)的话,用\(dis[u]\)表示\(u\)到根节点的距离(预处理LCA时计算出来),答案就是\(dis[x]+dis[y]-dis[lca(x,y)]\)。
怎么求\(\min\limits_{i是x\sim y路径上的节点}(mindis[i])\)呢?我们仍然用倍增的思想,让\(x,y\)往上跳,每跳一次用\(minn\)更新一下最小的\(mindis\),和求LCA的方法十分相似,所以就和求\(dist(x,y)\)合并成一个函数\(lca(x,y)\)了。返回值是一个pair
。first
是LCA,second
是最小的\(mindis\)。
用上面的式子得出结果即可,别忘了\(\times 2\)。
两种\(mindis\)求法的时间复杂度都是\(O((n+q)\log N)\),但显然DFS求法更高效。
实现细节:
- 如果用\(O(n)\)的方法求\(mindis\),别忘了距离数组初始化为\(+\infty\),但别太大,因为有\(+1\)操作。
Code
DFS求$mindis$
#include<bits/stdc++.h>
#define int long long
#define PII pair<int,int>
#define N 100010
using namespace std;
int n,q,k,dis[N],mindis[N];
//dis表示到根的距离
//mindis表示到关键点的最小距离
//minn维护每个点往上2^x层的mindis
int dep[N],fa[N][20],minn[N][20];
bool is[N],vis[N];
struct edge{int to,w;};
vector<edge> G[N];
void dfs1(int u,int father){
if(is[u]) mindis[u]=0;
for(auto i:G[u]){
if(i.to==father) continue;
mindis[i.to]=mindis[u]+i.w;
dfs1(i.to,u);
mindis[u]=min(mindis[u],mindis[i.to]+i.w);
}
}
void dfs2(int u,int father){
for(auto i:G[u]){
if(i.to==father) continue;
mindis[i.to]=min(mindis[i.to],mindis[u]+i.w);
dfs2(i.to,u);
}
}
void dfs(int u,int father){
dep[u]=dep[father]+1;
fa[u][0]=father;
for(int i=1;i<20;i++)
fa[u][i]=fa[fa[u][i-1]][i-1],
minn[u][i]=min(minn[u][i-1],minn[fa[u][i-1]][i-1]);
for(auto i:G[u])
if(i.to!=father)
minn[i.to][0]=min(mindis[i.to],mindis[u]),
dis[i.to]=dis[u]+i.w,
dfs(i.to,u);
}
pair<int,int> lca(int u,int v){
//first:LCA second:minn
if(u==v) return {u,mindis[u]};
int ans=INT_MAX;
if(dep[u]<dep[v]) swap(u,v);
for(int i=19;i>=0;i--)
if(dep[fa[u][i]]>=dep[v]){
ans=min(ans,minn[u][i]),u=fa[u][i];
}
if(u==v) return {u,ans};
for(int i=19;i>=0;i--)
if(fa[u][i]!=fa[v][i])
ans=min(ans,min(minn[u][i],minn[v][i])),
u=fa[u][i],v=fa[v][i];
ans=min(ans,min(minn[u][0],minn[v][0]));
return {fa[u][0],ans};
}
int solve(int u,int v){
auto t=lca(u,v);
return dis[u]+dis[v]-2*dis[t.first]+2*t.second;
}
signed main(){
memset(mindis,0x7f,sizeof mindis);
cin>>n>>q>>k;
for(int i=1;i<n;i++){
int u,v,w;
cin>>u>>v>>w;
G[u].push_back({v,w});
G[v].push_back({u,w});
}
for(int i=1;i<=k;i++){
int u;
cin>>u;
is[u]=1;
}
dfs1(1,0);
dfs2(1,0);
dfs(1,0);
while(q--){
int x,y;
cin>>x>>y;
cout<<solve(x,y)<<"\n";
}
return 0;
}
Dijkstra求$mindis$
#include<bits/stdc++.h>
#define int long long
#define PII pair<int,int>
#define N 100010
using namespace std;
int n,q,k,dis[N],mindis[N];
//dis表示到根的距离
//mindis表示到关键点的最小距离
//minn维护每个点往上2^x层的mindis
int dep[N],fa[N][20],minn[N][20];
bool is[N],vis[N];
struct edge{int to,w;};
vector<edge> G[N];
void dijkstra(){
priority_queue<PII,vector<PII>,greater<PII>> heap;
while(!heap.empty()) heap.pop();
for(int i=1;i<=n;i++){
if(!is[i]) mindis[i]=INT_MAX;
else heap.push({0,i});
}
while(!heap.empty()){
auto u=heap.top().second;
heap.pop();
if(vis[u]) continue;
vis[u]=1;
for(auto i:G[u]){
if(mindis[u]+i.w<mindis[i.to]){
mindis[i.to]=mindis[u]+i.w;
heap.push({mindis[i.to],i.to});
}
}
}
}
void dfs(int u,int father){
dep[u]=dep[father]+1;
fa[u][0]=father;
for(int i=1;i<20;i++)
fa[u][i]=fa[fa[u][i-1]][i-1],
minn[u][i]=min(minn[u][i-1],minn[fa[u][i-1]][i-1]);
for(auto i:G[u])
if(i.to!=father)
minn[i.to][0]=min(mindis[i.to],mindis[u]),
dis[i.to]=dis[u]+i.w,
dfs(i.to,u);
}
pair<int,int> lca(int u,int v){
//first:LCA second:minn
if(u==v) return {u,mindis[u]};
int ans=INT_MAX;
if(dep[u]<dep[v]) swap(u,v);
for(int i=19;i>=0;i--)
if(dep[fa[u][i]]>=dep[v]){
ans=min(ans,minn[u][i]),u=fa[u][i];
}
if(u==v) return {u,ans};
for(int i=19;i>=0;i--)
if(fa[u][i]!=fa[v][i])
ans=min(ans,min(minn[u][i],minn[v][i])),
u=fa[u][i],v=fa[v][i];
ans=min(ans,min(minn[u][0],minn[v][0]));
return {fa[u][0],ans};
}
int solve(int u,int v){
auto t=lca(u,v);
return dis[u]+dis[v]-2*dis[t.first]+2*t.second;
}
signed main(){
cin>>n>>q>>k;
for(int i=1;i<n;i++){
int u,v,w;
cin>>u>>v>>w;
G[u].push_back({v,w});
G[v].push_back({u,w});
}
for(int i=1;i<=k;i++){
int u;
cin>>u;
is[u]=1;
}
dijkstra();
dfs(1,0);
while(q--){
int x,y;
cin>>x>>y;
cout<<solve(x,y)<<"\n";
}
return 0;
}