[JLOI2014] 松鼠的新家 — 洛谷 P3258

#include <cstdio>
#include <iostream>
#include <string>
#include <vector>
#define MN 3000050
using namespace std;
// Problem-wide state. Everything here has static storage duration and is
// therefore zero-initialised before main() runs.

// n: node count read from input; cnt: number of edge slots used in e[].
// m and root are not referenced anywhere in the visible code.
int n, m, cnt, root;

// dfn[v]: depth of v as assigned by dfs() (root gets 1, dfn[0] stays 0).
int dfn[MN];

// f[v][i]: the 2^i-th ancestor of v, or 0 when that ancestor does not exist.
int f[MN][21];

// head[u]: index in e[] of the most recently added edge leaving u (0 = none).
int head[MN];

// lg[i]: floor(log2(i)) for i >= 1, filled in main(); lg[0] is set to -1.
int lg[MN];

// a[1..n]: the node sequence read from input, walked pairwise in main().
int a[MN];

// sum[v]: difference marks first, per-node totals after search() has run.
int sum[MN];
// One directed edge of the adjacency list (forward-star layout).
struct tu {
    int v;    // node this edge points to
    int nxt;  // index of the next edge leaving the same source, 0 = end of list
};

// Edge pool; slot 0 is never used so that 0 can act as the list terminator.
tu e[MN];
void add(int u, int v) {
    e[++cnt].v = v;
    e[cnt].nxt = head[u];
    head[u] = cnt;
}
// Fills dfn[] (depth) and f[][] (binary-lifting ancestors) for `now` and every
// node reachable from it without stepping back to `fa`.
//
//   now - node to start from
//   fa  - parent of `now`; pass 0 for the root (dfn[0] is 0, so the root
//         ends up with depth 1)
//
// The traversal uses an explicit stack instead of recursion: the recursion
// depth would equal the height of the tree, which can exhaust the call stack
// on long path-shaped inputs. The values written to dfn[] and f[][] depend only
// on the tree shape, not on the visiting order, so the results are the same as
// those of the recursive formulation.
void dfs(int now, int fa) {
    std::vector<int> pending;
    f[now][0] = fa;
    pending.push_back(now);

    while (!pending.empty()) {
        const int u = pending.back();
        pending.pop_back();

        // The parent was recorded when u was pushed, and it has already been
        // fully processed, so its depth and ancestor row are final.
        const int parent = f[u][0];
        dfn[u] = dfn[parent] + 1;
        for (int i = 1; (1 << i) <= dfn[u]; i++) f[u][i] = f[f[u][i - 1]][i - 1];

        for (int i = head[u]; i; i = e[i].nxt) {
            const int v = e[i].v;
            if (v != parent) {
                f[v][0] = u;
                pending.push_back(v);
            }
        }
    }
}
// Returns the lowest common ancestor of x and y, using the depths in dfn[],
// the ancestor table f[][] and the floor-log2 table lg[].
int lca(int x, int y) {
    // Keep x as the deeper (or equally deep) endpoint.
    if (dfn[x] < dfn[y]) {
        std::swap(x, y);
    }

    // Lift x by the largest power of two not exceeding the remaining depth
    // gap, until both nodes sit on the same level.
    for (int gap = dfn[x] - dfn[y]; gap > 0; gap = dfn[x] - dfn[y]) {
        x = f[x][lg[gap]];
    }

    // y was an ancestor of x (or the two nodes were identical).
    if (x == y) {
        return x;
    }

    // Lift both nodes together, taking every jump that keeps them apart;
    // afterwards they are distinct children of the answer.
    int k = lg[dfn[x]];
    while (k >= 0) {
        const int upX = f[x][k];
        const int upY = f[y][k];
        if (upX != upY) {
            x = upX;
            y = upY;
        }
        --k;
    }
    return f[x][0];
}
// Turns the marks stored in sum[] into subtree totals for the subtree of `u`:
// on return, sum[x] holds the sum of the original sum[] values over x and all
// of its descendants, for every x in that subtree.
//
// Parent links are taken from f[][0], so dfs() must have been run first.
//
// Implemented without recursion: the recursion depth would equal the height of
// the tree, which can exhaust the call stack on long path-shaped inputs.
void search(int u) {
    // Breadth-first discovery order: every node appears after its parent.
    std::vector<int> order;
    order.push_back(u);
    for (std::vector<int>::size_type idx = 0; idx < order.size(); ++idx) {
        const int node = order[idx];  // copy: push_back below may reallocate
        for (int i = head[node]; i; i = e[i].nxt) {
            const int v = e[i].v;
            if (v == f[node][0])
                continue;
            order.push_back(v);
        }
    }

    // Walking the order backwards guarantees that a node's total is complete
    // before it is folded into its parent. Index 0 is `u` itself, whose total
    // is not propagated any further up.
    for (std::vector<int>::size_type idx = order.size(); idx-- > 1;) {
        const int v = order[idx];
        sum[f[v][0]] += sum[v];
    }
}
int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    lg[0] = -1;
    for (int i = 1; i <= n; i++) lg[i] = lg[i / 2] + 1;
    for (int i = 1; i <= n - 1; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);
        add(y, x);
    }
    dfs(1, 0);
    for (int i = 1; i <= n - 1; i++) {
        int x = a[i];
        int y = a[i + 1];
        sum[x]++;
        sum[y]++;
        sum[lca(x, y)]--;
        sum[f[lca(x, y)][0]]--;
    }
    search(1);
    for (int i = 2; i <= n; i++) sum[a[i]]--;
    for (int i = 1; i <= n; i++) printf("%d\n", sum[i]);

    return 0;
}

 

posted @ 2019-06-04 10:25  红色OI再临  阅读(161)  评论(0)  编辑  收藏  举报