HDU-6035 Colorful Tree(树形DP) 2017多校第一场
题意:给出一棵树,树上的每个节点都有一个颜色,定义一种值为两点之间路径中不同颜色的个数,然后一棵树有n*(n-1)/2条
路径,求所有的路径的值加起来是多少。
思路:比赛的时候感觉是树形DP,但是脑袋抽了,忘记树形DP是怎么遍历的了(其实没忘也不会做:)
先给出官方题解吧:
单独考虑每一种颜色,答案就是对于每种颜色至少经过一次这种的路径条数之和。反过来思考只需要求有多少条路径没有经过这种颜色即可。直接做可以采用虚树的思想(不用真正建出来),对每种颜色的点按照 dfs 序列排个序,就能求出这些点把原来的树划分成的块的大小。这个过程实际上可以直接一次 dfs 求出。
其实感觉就前面那两句能看懂= =,不过这个也才是主要的。很显然这题是算每个节点的贡献,但是每个节点都算的话就会算重,去重不容易,所以就可以计算
有哪些路径没有经过这种颜色,这个就很容易一点了,所以对于每种颜色就是求他的联通快的大小。
然后感觉还是不容易,可能是树形DP做少了吧,通过这个题还是学到不少的技巧 比如通过访问节点前,把前面的颜色给存起来,这样到时候访问的就是这颗子树的内容了,更新这些内容的时候一定要注意他们之间的关系
/** @xigua */ #include<stdio.h> #include<cmath> #include<iostream> #include<algorithm> #include<vector> #include<stack> #include<cstring> #include<queue> #include<set> #include<string> #include<map> #define PI acos(-1) using namespace std; typedef long long ll; typedef double db; const int maxn = 2e5 + 5; const ll maxm = 1e7; const int mod = 1e9 + 7 + 0.1; const int INF = 1e9 + 7; const ll inf = 1e15 + 5; const db eps = 1e-9; const int state = 15; ll ans, col[maxn]; int head[maxn], cnt, n, siz[maxn], vis[maxn], c[maxn]; struct Edge { int v, next; } e[maxn<<1]; void add(int u, int v) { e[cnt].v = v; e[cnt].next = head[u]; head[u] = cnt++; } void init() { cnt = ans = 0; memset(head, -1, sizeof(head)); memset(vis, 0, sizeof(vis)); memset(col, 0, sizeof(col)); } void dfs(int u, int fa) { siz[u] = 1; vis[c[u]] = 1; int pre = col[c[u]]; //访问节点前的 int num = 0; //因为一个节点有好几个儿子,所以就要把前面的给剔除 for (int i = head[u]; ~i; i = e[i].next) { int v = e[i].v; if (fa == v) continue; dfs(v, u); siz[u] += siz[v]; ll tmp = siz[v] - (col[c[u]] - pre - num); ans -= (tmp - 1) * tmp / 2; num = col[c[u]] - pre; //num就是当前颜色的个数减去以前的就是节点u的前面的儿子的 } col[c[u]] = pre + siz[u]; } void solve() { int cas = 1; while (cin >> n) { init(); for (int i = 1; i <= n; i++) scanf("%d", c + i); for (int i = 1; i < n; i++) { int u, v; scanf("%d%d", &u, &v); add(u, v); add(v, u); } dfs(1, -1); for (int i = 1; i <= n; i++) { if (vis[i]) { ans += (ll)n * (n - 1) / 2; ll tmp = n - col[i]; ans -= tmp * (tmp - 1) / 2; } } printf("Case #%d: %I64d\n", cas++, ans); } } int main() { int t = 1, cas = 1; //freopen("in.txt", "r", stdin); // scanf("%d", &t); // init(); while(t--) { // printf("Case #%d:\n", cas++); solve(); } return 0; }