POJ-3107 Godfather：对树上每个节点，求删去它后剩余各连通块中最大那块的大小；输出使该值最小的全部节点（即树的重心），按编号升序。

dp[n][2]：dp[i][0] 维护儿子方向的连通块大小（以 i 为根的子树大小），dp[i][1] 维护父亲方向的连通块大小（去掉 i 的子树后剩下的部分）。

第一遍 DFS 自底向上求儿子方向（子树大小），第二遍 DFS 自顶向下求父亲方向，同时得到删去每个点后最大连通块的大小 mx[i]。

#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<algorithm>
#include <string>
#include<queue>
#include<vector>
#include<set>
#include<map>
#include <iomanip>
typedef long long LL;      // wide integer type used for sizes and answers
const LL INF = 20000000;   // larger than any vertex count that fits in the arrays below
const int N = 50500;       // capacity for (1-based) vertex ids
using namespace std;
int n, m, sum, cnt, flag;  // n: vertex count of the current case; cnt: next free slot in `edge`;
                           // sum/flag are reset by ini(); m is not used in the visible code
int deg[N];                // counter incremented by add() for the edge's target vertex
int head[N];               // head[u]: index in `edge` of u's first edge, -1 = empty list
struct Node
{
    int v, next;           // target vertex; index of the next edge from the same source (-1 ends the list)
};

Node edge[N << 1];         // each undirected tree edge is stored twice (u->v and v->u)
LL dis[N];                 // NOTE(review): not used in the visible code
LL dp[N][2];               // [0]: size of the subtree rooted at the vertex (child side, filled by dfs1)
                           // [1]: size of the component on the parent side (filled by dfs2)
LL mx[N];                  // size of the largest component left after deleting the vertex
// Computes dp[v][0] = number of vertices in the subtree rooted at v, for every
// vertex v reachable from `now` without stepping onto `pre` (pass -1 at the root).
// Implemented with an explicit stack instead of recursion: a path-shaped tree
// with tens of thousands of vertices would otherwise be that deep on the call stack.
void dfs1(int now, int pre)
{
    std::vector<std::pair<int, int> > order;  // (vertex, parent) in DFS preorder
    std::vector<std::pair<int, int> > stk;    // pending (vertex, parent) pairs
    stk.push_back(std::make_pair(now, pre));
    while (!stk.empty())
    {
        std::pair<int, int> cur = stk.back();
        stk.pop_back();
        order.push_back(cur);
        dp[cur.first][0] = 1;  // the vertex itself
        for (int i = head[cur.first]; i != -1; i = edge[i].next)
        {
            int e = edge[i].v;
            if (e == cur.second) continue;  // do not walk back up to the parent
            stk.push_back(std::make_pair(e, cur.first));
        }
    }
    // In reverse preorder every vertex comes after all of its descendants, so
    // subtree sizes can be folded upward. Index 0 is the start vertex: its
    // parent `pre` lies outside the traversal and must not be updated.
    for (size_t k = order.size() - 1; k > 0; --k)
        dp[order[k].second][0] += dp[order[k].first][0];
}
void dfs2(int now, int pre)
{
    if (pre == -1)
    {
        dp[now][1] = mx[now] = 0;
    }
    else
    {
        dp[now][1] = dp[pre][0] + dp[pre][1] - dp[now][0];
        mx[now] = dp[pre][0] + dp[pre][1] - dp[now][0];
    }
    for (int i = head[now]; i != -1; i = edge[i].next)
    {
        int e = edge[i].v;
        if (e == pre) continue;
        dfs2(e, now);
        mx[now] = max(mx[now], dp[e][0]);
    }
}
void ini()
{
    for (int i = 1; i <= n; i++)
         head[i] = -1, flag = 0;
    cnt = 0, sum = n;
}
void add(int u, int v)
{
    deg[v]++;
    edge[cnt].v = v;
    edge[cnt].next = head[u];
    head[u] = cnt++;
}
// Reads test cases until input ends. Each case is n followed by n-1 undirected
// edges "a b" over vertices 1..n. Prints, on one line in increasing order, every
// vertex whose removal minimises the size of the largest remaining component.
int main()
{
    while (scanf("%d", &n) == 1)  // == 1 (not != EOF): stop on malformed input too
    {
        if (n < 1) continue;  // no vertex 1 to root the traversal at
        ini();
        for (int i = 2; i <= n; i++)
        {
            // %d must be paired with int*; the original read into long long
            // via %d, which is undefined behaviour.
            int a, b;
            if (scanf("%d%d", &a, &b) != 2)
                return 0;  // truncated input: stop rather than use garbage ids
            add(a, b);
            add(b, a);
        }
        memset(dp, 0, sizeof(dp));
        dfs1(1, -1);  // dp[v][0]: subtree sizes with vertex 1 as the root
        dfs2(1, -1);  // mx[v]: largest component after deleting v
        LL ans = INF;
        for (int i = 1; i <= n; i++)
            ans = std::min(ans, mx[i]);

        // Scanning ids in increasing order already yields sorted output, so
        // no std::set is needed. With n >= 1 at least one vertex matches.
        bool first = true;
        for (int i = 1; i <= n; i++)
        {
            if (mx[i] != ans) continue;
            if (!first) printf(" ");
            printf("%d", i);
            first = false;
        }
        printf("\n");
    }
    return 0;
}

 

posted @ 2017-08-18 19:32  Luke_Ye  阅读(217)  评论(0)  编辑  收藏  举报