HDU 5923 Prediction

这题是2016 CCPC 东北四省赛的B题, 其实很简单. 现场想到的就是正解, 只是在合并两个并查集这个问题上没想清楚.

做法

并查集合并 + 归并

  1. 对每个节点 \(u\), 将 \(u\) 到根的那些边添到一个初始为空的并查集中, 得到的并查集记作 \(a_u\).
  2. 询问相当于将 \(k\) 个并查集合并. 采用二路归并, 合并次数是 \(O(n \cdot \log(n))\).
    $ n/2 + n/4 + n/8 + \dots + 1 = O(n \cdot \log(n)) $

合并两个并查集

详细讨论将并查集 \(B\) 合并到并查集 \(A\) 中这一问题.
这个问题与

给定两无向图 $A, B, V_B \subset V_A; \quad A(E_A, V_A) \to A'( E_A, E_A \cup E_B) $.

等价.

做法

$ \forall u \in E_B, \quad A.\mathrm{unite}(u, B.\mathrm{root}(u)) $

正确性

只要验证

\(B\)中连通的任意两点 \(u, v\), 在$ A'$中也连通.

是否满足.

Implementation

#include <bits/stdc++.h>
using namespace std;

const int N{1<<9};
const int M=1e4+5;

int n, m;

struct DSU{
    int par[N];
    int cnt;

    int find(int x){
        return par[x]==x?x: par[x]=find(par[x]);
    }

    void unite(int x, int y){
        x=find(x);
        y=find(y);
        if(x!=y){
            par[x]=y;
            --cnt;
        }
    }

    void unite(DSU &a){
        for(int i=1; i<=n; i++){
            unite(find(i), a.find(i));  // ?
        }
    }

    void init(){
        for(int i=1; i<=n; i++){
            par[i]=i;
        }
        cnt=n;
    }

    void copy(const DSU &a){
        for(int i=1; i<=n; i++){
            par[i]=a.par[i];
        }
        cnt=a.cnt;
    }
};

DSU a[M], b[M];

vector<int> g[M];

struct Edge{
    int u, v;
    void read(){
        scanf("%d%d", &u, &v);
    }
}E[M];

void dfs(int u, int f){
    a[u].copy(a[f]);
    a[u].unite(E[u].u, E[u].v);

    for(auto v: g[u]){
        dfs(v, u);
    }
}



void solve(int n){
    for(int i=1; i<n; i<<=1){   // error-prone
        for(int j=0; j+i<n; j+=i<<1){
            b[j].unite(b[j+i]);
        }
    }
    printf("%d\n", b[0].cnt);
}

// int par[M];

int main(){

    int T, cas{};
    for(cin>>T; T--; ){
        printf("Case #%d:\n", ++cas);
        // int n, m;
        cin>>n>>m;

        for(int i=1; i<=m; ++i){
            g[i].clear();
        }

        for(int i=2; i<=m; i++){
            // scanf("%d", par+i);
            int fa;
            scanf("%d", &fa);
            g[fa].push_back(i);
        }

        for(int i=1; i<=m; ++i){
            E[i].read();
        }

        a[0].init();
        dfs(1, 0);

        int q;
        cin>>q;
        for(; q--; ){
            int k;
            scanf("%d", &k);
            for(int i=0; i<k; i++){
                int x;
                scanf("%d", &x);
                b[i].copy(a[x]);
            }
            solve(k);
        }
    }
    return 0;
}

Pitfalls

归并

for(int i=1; i<n; i<<=1){   // error-prone
    for(int j=0; j+i<n; j+=i<<1){
        b[j].unite(b[j+i]);
    }
}

容易写错.

我第一发是这样写的

for(int i=2; i<=n; i<<=1){
    for(int j=0; j+i/2<n; j+=i){
        b[j].unite(b[j+i/2]);
    }
}

n==3时, 只做了1轮归并.

应采纳第一种写法, 很清楚.


UPD
太SB了.

  1. 根本不用归并, 直接逐个合并就好了.
  2. 根本不用 b[i].copy(a[x]); , 只要从一个边集为空的图 (以下简称"空图") 开始, 不断把\(k\)个并查集合并进去就好了.
  3. 不从空图开始, 而从某个并查集开始, 会快很多.
#include <bits/stdc++.h>
using namespace std;

const int N{1<<9};
const int M=1e4+5;

int n, m;

struct DSU{
    int par[N];
    int cnt;

    int find(int x){
        return par[x]==x?x: par[x]=find(par[x]);
    }

    void unite(int x, int y){
        x=find(x);
        y=find(y);
        if(x!=y){
            par[x]=y;
            --cnt;
        }
    }

    void unite(DSU &a){
        for(int i=1; i<=n; i++){
            unite(find(i), a.find(i));  // ?
        }
    }

    void init(){
        for(int i=1; i<=n; i++){
            par[i]=i;
        }
        cnt=n;
    }

    void copy(const DSU &a){
        for(int i=1; i<=n; i++){
            par[i]=a.par[i];
        }
        cnt=a.cnt;
    }
};

DSU a[M], b[M];

vector<int> g[M];

struct Edge{
    int u, v;
    void read(){
        scanf("%d%d", &u, &v);
    }
}E[M];

void dfs(int u, int f){
    a[u].copy(a[f]);
    a[u].unite(E[u].u, E[u].v);

    for(auto v: g[u]){
        dfs(v, u);
    }
}



int solve(int n){
    if(k==0){
        return n;
    }
    int x;
    scanf("%d", &x);
    a[0].copy(a[x]);
    for(int i=1; i<n; i++){
        scanf("%d", &x);
        a[0].unite(a[x]);
    }
    return a[0].cnt;
}

int main(){

    int T, cas{};
    for(cin>>T; T--; ){
        printf("Case #%d:\n", ++cas);

        cin>>n>>m;

        for(int i=1; i<=m; ++i){
            g[i].clear();
        }

        for(int i=2; i<=m; i++){
            // scanf("%d", par+i);
            int fa;
            scanf("%d", &fa);
            g[fa].push_back(i);
        }

        for(int i=1; i<=m; ++i){
            E[i].read();
        }

        a[0].init();
        dfs(1, 0);

        int q;
        cin>>q;
        for(; q--; ){
            int k;
            scanf("%d", &k);        
            printf("%d\n", solve(k));
        }
    }
    return 0;
}
posted @ 2016-10-17 22:02  Pat  阅读(461)  评论(0编辑  收藏  举报