hdu 2586 LCA

思路：裸的LCA —— 用 Tarjan 离线算法求最近公共祖先，两点距离 = dis[u] + dis[v] - 2*dis[lca]（dis 为到根结点 1 的距离）。

#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<cstdio>
#include<vector>
#include<string>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#define pb push_back
#define mp make_pair
#define Maxn 40010
#define Maxm 80002
// Portable 64-bit integer (was the MSVC-only __int64).
#define LL long long
// Every use of the argument is parenthesized so expressions like Abs(a-b) negate the whole argument.
#define Abs(x) ((x)>0?(x):(-(x)))
#define lson(x) ((x)<<1)
#define rson(x) ((x)<<1|1)
#define inf 0x7fffffff
#define Mod 1000000007
using namespace std;
// head[u]: index in edge[] of the first edge leaving u, -1 when u has none.
int head[Maxn];
// vi[u]: "entered" flag for the dfs()/LCA() traversals; main() clears it between the two passes.
int vi[Maxn];
// e: number of entries already stored in edge[] (the next free slot).
int e;
// fs[x]: size of the union-find set rooted at x (used by merg for union by size).
int fs[Maxn];
// fa[x]: union-find parent of x.
int fa[Maxn];
// anc[r]: for a union-find root r, the tree node that currently represents r's set in Tarjan's LCA.
int anc[Maxn];
// vis[u]: set once LCA() has finished all children of u.
int vis[Maxn];
// dis[u]: weighted distance from the root (node 1), filled by dfs().
int dis[Maxn];
// ans[i]: answer to the i-th query (1-based).
int ans[Maxn];
// Adjacency-list storage; each undirected edge occupies two directed entries.
struct Edge{
    int u,v,next,val;
}edge[Maxm];
// que[u]: pending queries touching u, as (other endpoint, query index).
vector< pair<int,int> > que[Maxn];
void init()
{
    memset(head,-1,sizeof(head));
    memset(vi,0,sizeof(vi));
    memset(vis,0,sizeof(vis));
    memset(fs,0,sizeof(fs));
    memset(dis,0,sizeof(dis));
    for(int i=0;i<Maxn;i++){
        fa[i]=i;
    }
    e=0;
}
void add(int u,int v,int val)
{
    edge[e].u=u,edge[e].v=v,edge[e].val=val,edge[e].next=head[u],head[u]=e++;
    edge[e].u=v,edge[e].v=u,edge[e].val=val,edge[e].next=head[v],head[v]=e++;
}
// Return the union-find root of x with full path compression: afterwards,
// every node on the path from x to the root points directly at the root.
// The first pass locates the root; the second pass re-links the path.
int find(int x)
{
    int root = x;
    while (fa[root] != root)
        root = fa[root];
    while (fa[x] != root) {
        int up = fa[x];
        fa[x] = root;
        x = up;
    }
    return root;
}
void merg(int a,int b)
{
    int x=find(a);
    int y=find(b);
    if(fs[y]<=fs[x])
        fa[y]=x,fs[x]+=fs[y];
    else fa[x]=y,fs[y]+=fs[x];
}
void LCA(int u)
{
    int i,v,sz;
    sz=que[u].size();
    vi[u]=1;
    anc[u]=u;
    for(i=head[u];i!=-1;i=edge[i].next){
        v=edge[i].v;
        if(vi[v]) continue;
        LCA(v);
        merg(u,v);
        anc[find(u)]=u;
    }
    vis[u]=1;
    for(i=0;i<sz;i++){
        v=que[u][i].first;
        if(vis[v]){
            int lca=anc[find(v)];
            int x=que[u][i].second;
            ans[x]=dis[u]+dis[v]-2*dis[lca];
        }
    }
}
void dfs(int u)
{
    int i,v;
    vi[u]=1;
    for(i=head[u];i!=-1;i=edge[i].next){
        v=edge[i].v;
        if(vi[v]) continue;
        dis[v]=dis[u]+edge[i].val;
        dfs(v);
    }
}
// Read t test cases. Each case has a tree of n nodes (n-1 weighted edges)
// and m distance queries. Print one distance per query, in input order.
// Truncated or malformed input stops processing instead of looping on
// uninitialized values.
int main()
{
    int t=0,n,m,i,u,v,val;
    if(scanf("%d",&t)!=1) return 0;
    while(t--){
        init();
        if(scanf("%d%d",&n,&m)!=2) return 0;
        for(i=1;i<n;i++){
            if(scanf("%d%d%d",&u,&v,&val)!=3) return 0;
            add(u,v,val);
        }
        for(i=1;i<=m;i++){
            if(scanf("%d%d",&u,&v)!=2) return 0;
            // Register the query at both endpoints; it is answered when the later one finishes in LCA().
            que[u].pb(mp(v,i));
            que[v].pb(mp(u,i));
        }
        dis[1]=0;
        dfs(1);                     // distances from root 1
        memset(vi,0,sizeof(vi));    // reuse vi as the "entered" flag for LCA()
        LCA(1);
        for(i=1;i<=m;i++){
            printf("%d\n",ans[i]);
        }
    }
    return 0;
}

 

posted @ 2013-09-02 21:36  fangguo  阅读(189)  评论(0)  编辑  收藏  举报