csu1811(树上启发式合并)

csu1811

题意

给定一棵树,每个节点有颜色,每次仅删掉第 \(i\) 条边 \((a_i, b_i)\) ,得到两颗树,问两颗树节点的颜色集合的交集。

分析

转化一下,即所求答案为每次删掉 \(u\)\(u\) 的父亲节点所连的边后形成的两颗子树的颜色集合的交集。
那么我们要求的其实和 \(u\) 的子树有关。子树的状态(颜色数量信息)是可以复用的。
可以套用 树上启发式合并 ,固定 1 为根节点,从上往下搜,每次保留子节点中节点最多的那颗子树的状态(颜色数量信息),也就是在计算当前节点所在子树的节点颜色的时候跳过这个子节点(因为前面保留了)。复杂度优化到 \(O(nlogn)\)

code

#include<cstdio>
#include<cstring>
#include<map>
#include<algorithm>
using namespace std;
typedef long long ll;
const int MAXN = 1e5 + 10;
int n;
int fa[MAXN], son[MAXN], dep[MAXN], siz[MAXN];
int col[MAXN];
int cnt, head[MAXN];
struct Edge {
    int to, next;
} e[MAXN << 1];
void addedge(int u, int v) {
    e[cnt].to = v;
    e[cnt].next = head[u];
    head[u] = cnt++;
}
void dfs(int u) {
    siz[u] = 1;
    son[u] = 0;
    for(int i = head[u]; ~i; i = e[i].next) {
        if(e[i].to != fa[u]) {
            fa[e[i].to] = u;
            dep[e[i].to] = dep[u] + 1;
            dfs(e[i].to);
            if(siz[e[i].to] > siz[son[u]]) son[u] = e[i].to;
            siz[u] += siz[e[i].to];
        }
    }
}
int COL[MAXN]; // 表示所有颜色的数量
int vis[MAXN];
int same; // 当前子树的颜色和另一个子树的颜色集合的交集
int C[MAXN]; // 当前子树某个颜色的数量
int ans[MAXN]; // ans[u]: 表示删掉 u 和它父亲所连的边后形成的两颗子树的答案
void update(int u, int c) {
    C[col[u]] += c;
    if(c > 0 && COL[col[u]] > 1) {
        if(C[col[u]] == 1) same++;
        else if(C[col[u]] == 0 || C[col[u]] == COL[col[u]]) same--;
    }
    for(int i = head[u]; ~i; i = e[i].next) {
        if(e[i].to != fa[u] && !vis[e[i].to]) update(e[i].to, c);
    }
}
void dfs1(int u, int flg) {
    for(int i = head[u]; ~i; i = e[i].next) {
        if(e[i].to != fa[u] && e[i].to != son[u]) dfs1(e[i].to, 1);
    }
    if(son[u]) {
        dfs1(son[u], 0);
        vis[son[u]] = 1;
    }
    update(u, 1);
    ans[u] = same;
    if(son[u]) vis[son[u]] = 0;
    if(flg) {
        update(u, -1);
        same = 0;
    }
}
typedef pair<int, int> P;
map<P, int> mp;
int res[MAXN];
int main() {
    while(~scanf("%d", &n)) {
        mp.clear();
        cnt = 0;
        dep[1] = 1;
        fa[1] = 1;
        memset(head, -1, sizeof head);
        memset(COL, 0, sizeof COL);
        memset(C, 0, sizeof C);
        memset(vis, 0, sizeof vis);
        same = 0;
        for(int i = 1; i <= n; i++) {
            scanf("%d", &col[i]);
            COL[col[i]]++;
        }
        for(int i = 1; i < n; i++) {
            int x, y;
            scanf("%d%d", &x, &y);
            mp[P(x, y)] = mp[P(y, x)] = i;
            addedge(x, y);
            addedge(y, x);
        }
        dfs(1);
        dfs1(1, -1);
        for(int i = 2; i <= n; i++) {
            res[mp[P(i, fa[i])]] = ans[i];
        }
        for(int i = 1; i < n; i++) {
            printf("%d\n", res[i]);
        }
    }
    return 0;
}
posted @ 2017-07-18 22:41  ftae  阅读(502)  评论(0编辑  收藏  举报