HDU 6035 树形dp
Colorful Tree
Time Limit: 6000/3000 MS (Java/Others) Memory Limit: 131072/131072 K (Java/Others)
Total Submission(s): 1186 Accepted Submission(s): 474
Problem Description
There is a tree with n nodes, each of which has a type of color represented by an integer, where the color of node i is ci.
The path between each two different nodes is unique, of which we define the value as the number of different colors appearing in it.
Calculate the sum of values of all paths on the tree that has n(n−1)2 paths in total.
The path between each two different nodes is unique, of which we define the value as the number of different colors appearing in it.
Calculate the sum of values of all paths on the tree that has n(n−1)2 paths in total.
Input
The input contains multiple test cases.
For each test case, the first line contains one positive integers n, indicating the number of node. (2≤n≤200000)
Next line contains n integers where the i-th integer represents ci, the color of node i. (1≤ci≤n)
Each of the next n−1 lines contains two positive integers x,y (1≤x,y≤n,x≠y), meaning an edge between node x and node y.
It is guaranteed that these edges form a tree.
For each test case, the first line contains one positive integers n, indicating the number of node. (2≤n≤200000)
Next line contains n integers where the i-th integer represents ci, the color of node i. (1≤ci≤n)
Each of the next n−1 lines contains two positive integers x,y (1≤x,y≤n,x≠y), meaning an edge between node x and node y.
It is guaranteed that these edges form a tree.
Output
For each test case, output "Case #x: y" in one line (without quotes), where x indicates the case number starting from 1 and y denotes the answer of corresponding case.
Sample Input
3
1 2 1
1 2
2 3
6
1 2 1 3 2 1
1 2
1 3
2 4
2 5
3 6
Sample Output
Case #1: 6
Case #2: 29
Source
题意:
n个节点的数,每个都节点有颜色 ,规定一条路径的价值是路径上所有的点的不同的颜色个数,求所有路径(n*(n-1)/2)的价值和
代码:
首先这题肯定是算贡献,也就是计算出每种颜色参与了多少条路径,但这样正面考虑并不容易,不妨从反面考虑,计算每种颜色没有参与多少路径,然后拿 (路径总数 * 颜色总数) - 没参与的贡献,就是答案了。
对于一种颜色x,怎么计算没参与的路径数目呢,很显然,对于每个不包含颜色x的连通块中任意两点路径都是x不参与的贡献,那么问题就转化为,对于任意一种颜色x,需要求出每个不包含x的连通块大小。
这里利用了树形dp,对于结点u,若u的颜色为x,那么dfs(u)的过程中,我们就想知道对于u的每个儿子v构成的子树中最高的一批颜色也为x的结点是哪些,要是知道这些结点子树的大小,只要拿子树v的大小减去这些节点子树的大小,就可以得到包括v在内的没有颜色x的连通块大小。
举个例子,如图:
对于结点1,我们想求出与1相邻的且不是黑色的结点组成的连通块大小,此时就要dfs结点1的两个儿子2和3,若我们处理出来结点2中,最高的一批黑色结点(4,8)的所构成子树的大小分别为2和1,那么我们拿2为根子树的大小5,减去2,减去1,就得到了结点2不包含黑色结点的连通块大小是5-2-1=2,也就是图中的2和5组成的连通块。同理对3,可以求得连通块大小是2({3,6})。
计算出了子树u中连通块大小对答案的贡献之后,就要将当前最高一批的黑色结点更新为u,最高黑色节点构成的子树大小sum加上子树u中没有计算的那部分大小5({1,2,3,5,6})。
具体实现,看代码,sum[x]表示的遍历到当前位置,颜色为x的高度最高一批结点为根的子树大小总和。
对于一种颜色x,怎么计算没参与的路径数目呢,很显然,对于每个不包含颜色x的连通块中任意两点路径都是x不参与的贡献,那么问题就转化为,对于任意一种颜色x,需要求出每个不包含x的连通块大小。
这里利用了树形dp,对于结点u,若u的颜色为x,那么dfs(u)的过程中,我们就想知道对于u的每个儿子v构成的子树中最高的一批颜色也为x的结点是哪些,要是知道这些结点子树的大小,只要拿子树v的大小减去这些节点子树的大小,就可以得到包括v在内的没有颜色x的连通块大小。
举个例子,如图:
对于结点1,我们想求出与1相邻的且不是黑色的结点组成的连通块大小,此时就要dfs结点1的两个儿子2和3,若我们处理出来结点2中,最高的一批黑色结点(4,8)的所构成子树的大小分别为2和1,那么我们拿2为根子树的大小5,减去2,减去1,就得到了结点2不包含黑色结点的连通块大小是5-2-1=2,也就是图中的2和5组成的连通块。同理对3,可以求得连通块大小是2({3,6})。
计算出了子树u中连通块大小对答案的贡献之后,就要将当前最高一批的黑色结点更新为u,最高黑色节点构成的子树大小sum加上子树u中没有计算的那部分大小5({1,2,3,5,6})。
具体实现,看代码,sum[x]表示的遍历到当前位置,颜色为x的高度最高一批结点为根的子树大小总和。
#include<iostream> #include<cstdio> #include<cstring> using namespace std; typedef long long ll; const int maxn=200009; int tot,head[maxn],col[maxn]; bool vis[maxn]; ll ans,sz[maxn],sum[maxn]; struct Edge{ int to,next; }edge[maxn*2]; void add(int x,int y){ edge[tot].to=y; edge[tot].next=head[x]; head[x]=tot++; } void dfs(int u,int fa){ ll cntson=0; sz[u]=1; for(int i=head[u];i!=-1;i=edge[i].next){ int v=edge[i].to; if(v==fa) continue; ll last=sum[col[u]]; dfs(v,u); sz[u]+=sz[v]; ll cnt=sz[v]-(sum[col[u]]-last); ans+=cnt*(cnt-1)/2LL ; cntson+=cnt; } sum[col[u]]+=cntson+1; } int main() { int cas=0; ll n; while(scanf("%lld",&n)==1){ tot=0; memset(head,-1,sizeof(head)); memset(vis,0,sizeof(vis)); memset(sum,0,sizeof(sum)); int x,y; ll m=0; for(int i=1;i<=n;i++){ scanf("%d",&col[i]); if(!vis[col[i]]) m++; vis[col[i]]=1; } for(int i=1;i<n;i++){ scanf("%d%d",&x,&y); add(x,y); add(y,x); } ans=0; dfs(1,0); ans=m*n*(n-1)/2LL-ans; for(int i=1;i<=n;i++){ if(!vis[i]) continue; ll cnt=n-sum[i]; ans-=cnt*(cnt-1)/2LL; } printf("Case #%d: %lld\n",++cas,ans); } return 0; }