[LCA]JZOJ 3717 火车

Description

A国有n个城市,城市之间有一些双向道路相连,并且城市两两之间有唯一路径。现在有火车在城市a,需要经过m个城市。火车按照以下规则行驶:每次行驶到还没有经过的城市中在m个城市中最靠前的。现在小A想知道火车经过这m个城市后所经过的道路数量。
 

Input

第一行三个整数n、m、a,表示城市数量、需要经过的城市数量,火车开始时所在位置。

接下来n-1行,每行两个整数x和y,表示x和y之间有一条双向道路。

接下来一行m个整数,表示需要经过的城市。

Output

一行一个整数,表示火车经过的道路数量。
 

Sample Input

5 4 2
1 2
2 3
3 4
4 5
4 3 1 5

Sample Output

9
 

Data Constraint

分析

显然是要跑LCA的

但是我们要考虑到LCA路径上经过了其他需要经过的节点需要抹去

那么考虑一个像树链剖分的Top的东西

表示这个点向上最早遇到的一个需要经过的节点

然后预处理好,每次求LCA就更新它

记得预处理用Bfs跑啊

#include <iostream>
#include <cstdio>
#include <queue>
using namespace std;
const int N=500001;
struct Dep {
    int x,last;
};
int top[N];
struct Edge {
    int u,v,nx;
}g[2*N];
int cnt,list[N];
int f[N][20],d[N],id[N];
int n,m,a[N];
long long ans;

void Add(int u,int v) {
    g[++cnt].u=u;g[cnt].v=v;g[cnt].nx=list[u];list[u]=cnt;
}

void Bfs() {
    queue<Dep> q;
    while (!q.empty()) q.pop();
    q.push((Dep){1,0});
    while (!q.empty()) {
        Dep u=q.front();q.pop();
        top[u.x]=u.last;
        if (id[u.x]) u.last=u.x;
        d[u.x]=d[f[u.x][0]]+1;
        for (int i=list[u.x];i;i=g[i].nx)
        if (g[i].v!=f[u.x][0]) f[g[i].v][0]=u.x,q.push((Dep){g[i].v,u.last});
    }
}

void Pre_LCA() {
    for (int k=1;k<=19;k++)
    for (int i=1;i<=n;i++)
    f[i][k]=f[f[i][k-1]][k-1];
}

int Get_LCA(int a,int b) {
    if (d[a]<d[b]) swap(a,b);
    for (int i=19;i>=0;i--)
    if (d[f[a][i]]>=d[b]) a=f[a][i],ans+=1<<i;
    if (a==b) return a;
    for (int i=19;i>=0;i--)
    if (f[a][i]!=f[b][i]) a=f[a][i],b=f[b][i],ans+=2<<i;
    ans+=2;
    return f[a][0];
}

void Jump(int u,int v) {
    int t,lca=Get_LCA(u,v);
    for (int i=u;i&&d[i]>=d[lca];t=i,i=top[i],top[t]=top[lca]) a[id[i]]=0;
    for (int i=v;i&&d[i]>=d[lca];t=i,i=top[i],top[t]=top[lca]) a[id[i]]=0;
}

int main() {
    freopen("train.in","r",stdin);
    freopen("train.out","w",stdout);
    scanf("%d%d%d",&n,&m,&a[1]);
    id[a[1]]=1;
    for (int i=1;i<n;i++) {
        int u,v;
        scanf("%d%d",&u,&v);
        Add(u,v);Add(v,u);
    }
    for (int i=2;i<=m+1;i++) {
        scanf("%d",&a[i]);
        id[a[i]]=i;
    }
    Bfs();
    Pre_LCA();
    int last=a[1],t;
    for (int i=2;i<=m+1;i++)
    if (a[i]) t=last,last=a[i],Jump(t,last);
    printf("%lld",ans);
    fclose(stdin);fclose(stdout);
}
View Code

 

posted @ 2018-08-19 21:30  Vagari  阅读(412)  评论(0编辑  收藏  举报