HDU-6035 Colorful Tree(树形DP) 2017多校第一场

题意:给出一棵树,树上的每个节点都有一个颜色,定义一种值为两点之间路径中不同颜色的个数,然后一棵树有n*(n-1)/2条

路径,求所有的路径的值加起来是多少。

 

思路:比赛的时候感觉是树形DP,但是脑袋抽了,忘记树形DP是怎么遍历的了(其实没忘也不会做:)

先给出官方题解吧:

单独考虑每一种颜色,答案就是对于每种颜色至少经过一次这种的路径条数之和。反过来思考只需要求有多少条路径没有经过这种颜色即可。直接做可以采用虚树的思想(不用真正建出来),对每种颜色的点按照 dfs 序列排个序,就能求出这些点把原来的树划分成的块的大小。这个过程实际上可以直接一次 dfs 求出。

其实感觉就前面那两句能看懂= =,不过这个也才是主要的。很显然这题是算每个节点的贡献,但是每个节点都算的话就会算重,去重不容易,所以就可以计算

有哪些路径没有经过这种颜色,这个就很容易一点了,所以对于每种颜色就是求他的联通快的大小。

然后感觉还是不容易,可能是树形DP做少了吧,通过这个题还是学到不少的技巧 比如通过访问节点前,把前面的颜色给存起来,这样到时候访问的就是这颗子树的内容了,更新这些内容的时候一定要注意他们之间的关系

/** @xigua */
#include<stdio.h>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<vector>
#include<stack>
#include<cstring>
#include<queue>
#include<set>
#include<string>
#include<map>
#define PI acos(-1)
using namespace std;
typedef long long ll;
typedef double db;
const int maxn = 2e5 + 5;
const ll maxm = 1e7;
const int mod = 1e9 + 7 + 0.1;
const int INF = 1e9 + 7;
const ll inf = 1e15 + 5;
const db eps = 1e-9;
const int state = 15;
ll ans, col[maxn];
int head[maxn], cnt, n, siz[maxn], vis[maxn], c[maxn];
struct Edge {
    int v, next;
} e[maxn<<1];

void add(int u, int v) {
    e[cnt].v = v;
    e[cnt].next = head[u];
    head[u] = cnt++;
}

void init() {
    cnt = ans = 0;
    memset(head, -1, sizeof(head));
    memset(vis, 0, sizeof(vis));
    memset(col, 0, sizeof(col));
}

void dfs(int u, int fa) {
    siz[u] = 1;
    vis[c[u]] = 1;
    int pre = col[c[u]];    //访问节点前的
    int num = 0;    //因为一个节点有好几个儿子,所以就要把前面的给剔除
    for (int i = head[u]; ~i; i = e[i].next) {
        int v = e[i].v;
        if (fa == v) continue;
        dfs(v, u);
        siz[u] += siz[v];
        ll tmp = siz[v] - (col[c[u]] - pre - num);
        ans -= (tmp - 1) * tmp / 2;
        num = col[c[u]] - pre; //num就是当前颜色的个数减去以前的就是节点u的前面的儿子的
    }
    col[c[u]] = pre + siz[u];
}

void solve() {
    int cas = 1;

    while (cin >> n) {
        init();
        for (int i = 1; i <= n; i++)
            scanf("%d", c + i);
        for (int i = 1; i < n; i++) {
            int u, v; scanf("%d%d", &u, &v);
            add(u, v); add(v, u);
        }
        dfs(1, -1);
        for (int i = 1; i <= n; i++) {
            if (vis[i]) {
                ans += (ll)n * (n - 1) / 2;
                ll tmp = n - col[i];
                ans -= tmp * (tmp - 1) / 2;
            }
        }
        printf("Case #%d: %I64d\n", cas++, ans);
    }
}

int main() {
    int t = 1, cas = 1;
    //freopen("in.txt", "r", stdin);
   // scanf("%d", &t);
  //  init();
    while(t--) {
       // printf("Case #%d:\n", cas++);
        solve();
    }
    return 0;
}

  

posted @ 2017-07-28 16:26  ost_xg  阅读(228)  评论(0编辑  收藏  举报