树上游戏

题目描述

lrb有一棵树,树的每个节点有个颜色。给一个长度为n的颜色序列,定义s(i,j) 为i 到j 的颜色数量

\(Sum_i = \sum_{i=1}^{n}{s_{i,j}}\)

现在他想让你求出所有的sum[i]

输入输出格式

输入格式:

第一行为一个整数n,表示树节点的数量

第二行为n个整数,分别表示n个节点的颜色c[1],c[2]……c[n]

接下来n-1行,每行为两个整数x,y,表示x和y之间有一条边

输出格式:

输出n行,第i行为sum[i]

输入输出样例

输入样例#1:

5
1 2 3 2 3
1 2
2 3
2 4
1 5

输出样例#1:

10
9
11
9
12

说明

sum[1]=s(1,1)+s(1,2)+s(1,3)+s(1,4)+s(1,5)=1+2+3+2+2=10
sum[2]=s(2,1)+s(2,2)+s(2,3)+s(2,4)+s(2,5)=2+1+2+1+3=9
sum[3]=s(3,1)+s(3,2)+s(3,3)+s(3,4)+s(3,5)=3+2+1+2+3=11
sum[4]=s(4,1)+s(4,2)+s(4,3)+s(4,4)+s(4,5)=2+1+2+1+3=9
sum[5]=s(5,1)+s(5,2)+s(5,3)+s(5,4)+s(5,5)=2+3+3+3+1=12
对于40%的数据,n<=2000

对于100%的数据,1<=n,c[i]<=10^5


题解

点分治

一堆细节

要算每两个点对之间颜色种类的个数

\(O(n^2)\)很好做,只需要枚举每个点当根,当某种颜色第一次出现时对根的答案贡献+Size[u]

考虑怎么优化这个过程

可以使用点分治来处理每个点对答案的贡献来做到不重不漏

因为点分治每次处理的是经过分治重心的路径的贡献

所以我们先处理出从分治重心开始到各个子树的每个颜色的路径条数\(val[col[u]]\)

然后对于每个子树

先减掉该子树对\(val[]\)的贡献

然后计算出sum_val\(= \sum_{i=1}^{i<=colnum}{val[i]}\)

对于从u到根(不包含根)的路径上的出现过的颜色,把他们从sum_val中减掉

再加上当前分治块的大小减掉当前分治子树的大小

表示这种颜色对答案的贡献是跨过分治重心的所有路径

最后清空一下数组就好辣

代码

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
# define int long long
const int M = 100005 ;
const int INF = 1e9 + 7 ;
using namespace std ;
inline int read() {
	char c = getchar() ; int x = 0 , w = 1 ;
	while(c>'9'||c<'0') { if(c=='-') w = -1 ; c = getchar() ; }
	while(c>='0'&&c<='9') { x = x*10+c-'0' ; c = getchar() ; }
	return x*w ;
}

bool vis[M] , hap[M] ;
int n , col[M] , hea[M] , num ;
int tot , rt , tmin , size[M] ;
int appear[M] , val[M] ;
int Sum[M] , Bsz , sum_val ;
int colnum , tcl[M] , Tag ;


struct E { int Nxt , to ; } edge[M << 1] ;
inline void add_edge(int from , int to) { 
	edge[++num].Nxt = hea[from] ; 
	edge[num].to = to ;hea[from] = num ; 
}
void Getroot(int u , int father) {
	size[u] = 1 ; int Mx = -1 ; 
	for(int i = hea[u] ; i ; i = edge[i].Nxt) {
		int v = edge[i].to ; if(v == father || vis[v]) continue ;
		Getroot(v , u) ; size[u] += size[v] ; Mx = max(Mx , size[v]) ;
	}
	Mx = max(Mx , tot - size[u]) ; if(Mx < tmin) tmin = Mx , rt = u ;
}
void FirDfs(int u , int father) {
	if(!hap[col[u]]) {
		tcl[++colnum] = col[u] ;
		hap[col[u]] = true ;
	}
	size[u] = 1 ; appear[col[u]] ++ ;
	for(int i = hea[u] ; i ; i = edge[i].Nxt) {
		int v = edge[i].to ; if(v == father || vis[v]) continue ;
		FirDfs(v , u) ; size[u] += size[v] ;
	}
	appear[col[u]] -- ;
	if(!appear[col[u]])  val[col[u]] += size[u] ;
}
void Update(int u , int father , int dlt) {
	appear[col[u]] ++ ;
	for(int i = hea[u] ; i ; i = edge[i].Nxt) {
		int v = edge[i].to ; if(v == father || vis[v]) continue ;
		Update(v , u , dlt) ; 
	}
	appear[col[u]] -- ; if(!appear[col[u]]) val[col[u]] += size[u] * dlt ;
}
void GetAns(int u , int father) {
	if(!appear[col[u]]) {
		Tag -= val[col[u]] ;
		Tag += Bsz ;
	}
	appear[col[u]] ++ ;
	Sum[u] += sum_val + Tag ;
	for(int i = hea[u] ; i ; i = edge[i].Nxt) {
		int v = edge[i].to ; if(vis[v] || v == father) continue ;
		GetAns(v , u) ;
	}
	appear[col[u]] -- ;
	if(!appear[col[u]]) {
		Tag += val[col[u]] ;
		Tag -= Bsz ;
	}
}
void Dfs(int u) {
	FirDfs(u , u) ; vis[u] = true ;
	sum_val = 0 ;
	for(int i = 1 ; i <= colnum ; i ++) 
		sum_val += val[tcl[i]] ;
	Sum[u] += sum_val ;
	for(int i = hea[u] ; i ; i = edge[i].Nxt) {
		int v = edge[i].to ; if(vis[v]) continue ;
		appear[col[u]] = 1 ; val[col[u]] -= size[v] ;
		Update(v , u , -1) ; 
		Bsz = size[u] - size[v] ; sum_val = 0 ; 
		for(int j = 1 ; j <= colnum ; j ++) sum_val += val[tcl[j]] ;
		appear[col[u]] = 0 ; Tag = 0 ;
		GetAns(v , u) ; 
		val[col[u]] += size[v] ;
		appear[col[u]] = 1 ;
		Update(v , u , 1) ;
		appear[col[u]] = 0 ;
	}
	for(int i = 1 ; i <= colnum ; i ++) val[tcl[i]] = 0 , hap[tcl[i]] = false ;
	colnum = 0 ;
	for(int i = hea[u] ; i ; i = edge[i].Nxt) {
		int v = edge[i].to ; if(vis[v]) continue ;
		tot = size[u] ; tmin = INF ;
		Getroot(v , u) ; Dfs(rt) ;
	}
}
# undef int
int main() {
# define int long long
	n = read() ;
	for(int i = 1 ; i <= n ; i ++) col[i] = read() ;
	for(int i = 1 , u , v ; i < n ; i ++) {
		u = read() , v = read() ;
		add_edge(u , v) ; add_edge(v , u) ;
	}
	tot = n , tmin = INF ;
	Getroot(1 , 1) ; Dfs(rt) ;
	for(int i = 1 ; i <= n ; i ++) printf("%lld\n",Sum[i]) ;
	return 0 ;
}
posted @ 2018-12-03 08:26  beretty  阅读(184)  评论(0编辑  收藏  举报