[ SPOJ Qtree2 ] Query on a tree II
\(\\\)
\(Description\)
给定一棵\(N\)个点的树,边有边权。作以下操作:
-
\(DIST\ a\ b\) 询问点 \(a\) 至点 \(b\) 路径上的边权之和
-
\(KTH\ a\ b\ k\) 询问点 \(a\) 至点 \(b\) 有向路径上的第 \(k\) 个点的编号
有 \(T\) 组测试数据,每组数据以 \(DONE\) 结尾。
- \(T\in [1,25],N\in [1,10000]\)
\(\\\)
\(Solution\)
\(LCA\) 裸题。
第一问可以 \(DFS\) 求以下树上前缀和,询问答案即为 \(dis[a]+dis[b]-dis[lca]\times 2\) 。
第二问用节点深度确定询问节点在哪一侧,若 \(deep[a]-deep[lca]\ge k\) 就在 \(a\) 到 \(lca\) 的路径上,反之在 \(b\) 到 \(lca\) 的路径上。
确定询问的节点是 \(a\) 或 \(b\) 的多少层祖先,利用倍增数组二进制拆分即可找到,处理时注意两端节点也会被计数。
\(\\\)
\(Code\)
#include<cmath>
#include<queue>
#include<cstdio>
#include<cctype>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#define N 20010
#define R register
#define gc getchar
using namespace std;
char c;
int n,x,y,k,t,tot,d[N],hd[N],f[N][20],g[N];
struct edge{int w,to,nxt;}e[N<<1];
inline void add(int u,int v,int w){
e[++tot].to=v; e[tot].w=w;
e[tot].nxt=hd[u]; hd[u]=tot;
}
inline int rd(){
int x=0; bool f=0; char c=gc();
while(!isdigit(c)){if(c=='-')f=1;c=gc();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=gc();}
return f?-x:x;
}
queue<int> q;
inline void bfs(){
q.push(1); d[1]=1;
while(!q.empty()){
int u=q.front(); q.pop();
for(R int i=hd[u],v;i;i=e[i].nxt)
if(!d[v=e[i].to]){
d[v]=d[u]+1; f[v][0]=u;
for(R int j=1;j<=t;++j) f[v][j]=f[f[v][j-1]][j-1];
q.push(v);
}
}
}
void dfs(int u,int fa){
for(R int i=hd[u],v;i;i=e[i].nxt)
if((v=e[i].to)!=fa){
g[v]=g[u]+e[i].w; dfs(v,u);
}
}
inline int lca(int u,int v){
if(d[u]>d[v]) u^=v^=u^=v;
for(R int i=t;~i;--i) if(d[f[v][i]]>=d[u]) v=f[v][i];
if(u==v) return u;
for(R int i=t;~i;--i) if(f[u][i]!=f[v][i]) v=f[v][i],u=f[u][i];
return f[u][0];
}
inline void calc(int u,int v){
int fa=lca(u,v);
printf("%d\n",g[u]+g[v]-(g[fa]<<1));
}
inline void getkth(int u,int v,int k){
int fa=lca(u,v);
if(d[u]-d[fa]>=k){
for(R int i=t;~i;--i)
if(f[u][i]&&k>=(1<<i)) u=f[u][i],k-=(1<<i);
printf("%d\n",u);
}
else{
k=d[v]-d[fa]-(k-(d[u]-d[fa]));
for(R int i=t;~i;--i)
if(f[v][i]&&k>=(1<<i)) v=f[v][i],k-=(1<<i);
printf("%d\n",v);
}
}
inline void work(){
memset(f,0,sizeof(f));
memset(d,0,sizeof(d));
memset(g,0,sizeof(g));
memset(hd,0,sizeof(hd));
tot=0; t=log2((n=rd()))+1;
for(R int i=1,u,v,w;i<n;++i){
u=rd(); v=rd(); w=rd();
add(u,v,w); add(v,u,w);
}
bfs(); dfs(1,0);
while(1){
c=gc(); while(!isalpha(c)) c=gc();
if(c=='K'){x=rd(); y=rd(); k=rd(); getkth(x,y,k-1);}
else if(gc()=='I'){x=rd(); y=rd(); calc(x,y);}
else break;
}
}
int main(){
int ts=rd();
while(ts--) work();
return 0;
}