hdu6035 Colorful Tree 树形dp 给定一棵树,每个节点有一个颜色值。定义每条路径的值为经过的节点的不同颜色数。求所有路径的值和。

/**
题目:hdu6035 Colorful Tree
链接:http://acm.hdu.edu.cn/showproblem.php?pid=6035
题意:给定一棵树,每个节点有一个颜色值。定义每条路径的值为经过的节点的不同颜色数。求所有路径的值和。

思路:看题解后,才想出来的。树形dp。

求所有路径的值和 = 路径条数*总颜色数(n*(n-1)*colors/2)-sigma(每种颜色没有经过的路径条数)

主要是求每种颜色没有经过的路径条数。

画一棵树,我直接用颜色值表示节点编号。

             2
           /   \
          3     4
         /     /  \
        1     3    2
       / \   / \  / \
      4   5  4  5 3  5        12个点。

首先求颜色值为3的不经过的路径条数x
树上有三个3.很容易想到:
x = 最左边那个3下面的3个点构成的路径条数(3*2/2=3)+中间的3的两个子树分别构成的路径条数和(0)
+最右边的3的子树的分别构成的路径条数和(0)
+(总节点数-所有的3为根的子树节点数之和)*(总节点数-所有的以3为根的子树节点数之和-1)/2 ;

所以size[i]表示以i为根的树的节点数。

sum[i]在dfs过程中,,维护。。比如假设颜色为2.上图。 那么左子树是根为3,右子树是根为4.
那么递归完左子树之后,sum[2] = 0; 然后再递归完右子树后sum[2] = 3;就是右下角的那个2为根的子树的点数。

最终sum[i]表示所有以i颜色为根的子树的所有节点数之和。
sum[2] = 12;
sum[1] = 3;
sum[4] = 8;
sum[5] = 3;
sum[3] = 8;


*/



#include<iostream>
#include<cstdio>
#include<algorithm>
#include<vector>
#include<cstring>
using namespace std;
typedef long long LL;
const int N = 2e5+100;
int size[N];
int sum[N];
int col[N];
int vis[N];
int colors, n;
LL cnt;
vector<int> G[N];
void dfs(int r,int f)
{
    int len = G[r].size(), temp = 0;
    size[r]  = 1;
    if(sum[col[r]]!=0){
        temp = sum[col[r]];
        sum[col[r]] = 0;
    }
    for(int i = 0; i < len; i++){
        int to = G[r][i];
        if(to==f) continue;
        dfs(to,r);
        size[r] += size[to];
        cnt += (LL)(size[to]-sum[col[r]])*(size[to]-sum[col[r]]-1)/2;
        sum[col[r]] = 0;
    }
    sum[col[r]] = size[r]+temp;
}
int main()
{
    int cas = 1;
    while(scanf("%d",&n)==1)
    {
        memset(vis, 0, sizeof vis);
        memset(size, 0, sizeof size);
        memset(sum, 0, sizeof sum);
        colors = 0;
        for(int i = 1; i <= n; i++) G[i].clear();
        for(int i = 1; i <= n; i++){
            scanf("%d",&col[i]);
            if(vis[col[i]]==0){
                colors++;
            }
            vis[col[i]] = 1;
        }
        int u, v;
        for(int i = 1; i <= n-1; i++){
            scanf("%d%d",&u,&v);
            G[u].push_back(v);
            G[v].push_back(u);
        }
        cnt = 0;
        dfs(1,-1);
        for(int i = 1; i <= n; i++){
            if(vis[i]==0) continue;
            cnt += (LL)(n-sum[i])*(n-sum[i]-1)/2;
        }
        printf("Case #%d: %lld\n",cas++,(LL)n*(n-1)/2*colors-cnt);
    }
    return 0;
}

 

posted on 2017-07-26 14:53  hnust_accqx  阅读(484)  评论(0编辑  收藏  举报

导航