HDU 4303 Contest 1

说实话,挺复杂的一道题。

我采用栈的方式,DFS在搜索完一个节点的所有子结点后,通过排序,加快计算该结点所有可能的路径:子结点与子结点的连通,子结点与父结点的连通,通过父结点与各祖先结点的连通。同时记录路径数计算。思路清晰就能写出来了。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define LL __int64
const int N=300010;
using namespace std;

struct e{
	int u,v;
	int col;
	int next;
}edge[N*2];
int head[N];
int tot;
int val[N];
struct p{
	LL valsum;
	LL route;
	int col;
}DFST[N],pre,aft;
LL tmp;

void addedge(int u,int v,int col){
	edge[tot].u=u;
	edge[tot].v=v;
	edge[tot].col=col;
	edge[tot].next=head[u];
	head[u]=tot++;
	edge[tot].u=v;
	edge[tot].v=u;
	edge[tot].col=col;
	edge[tot].next=head[v];
	head[v]=tot++;
}

bool cmp(p a, p b){
	if(a.col<b.col) return true;
	return false;
}

void dfs(LL &ans,int parent,int now,int parent_col,LL &route,int pos){
	int k=-1; LL son_val,son_route;
	for(int ei=head[now];ei!=-1;ei=edge[ei].next){
		if(edge[ei].v==parent) continue;
		k++; son_val=son_route=0;
		dfs(son_val,now,edge[ei].v,edge[ei].col,son_route,pos+k);
		DFST[pos+k].valsum=son_val; DFST[pos+k].route=son_route; 
		DFST[pos+k].col=edge[ei].col;
	}
	if(k>=0){
		sort(DFST+pos,DFST+pos+k+1,cmp);
		for(int i=0;i<=k;i++){
			if(parent!=-1)
				tmp+=((LL)DFST[pos+i].valsum+(LL)val[now]*DFST[pos+i].route);
			if(DFST[pos+i].col!=parent_col){
				ans+=((LL)DFST[pos+i].valsum+(LL)val[now]*DFST[pos+i].route);
				route+=(LL)DFST[pos+i].route;
			}
		}
		if(parent!=-1){
			ans+=val[now];
			route++;
		}
		if(DFST[pos+k].col!=DFST[pos].col){
			pre=DFST[pos];
			int c=DFST[pos].col;
			for(int i=1;i<=k;i++){
				if(DFST[pos+i].col==c){
					pre.valsum=pre.valsum+DFST[pos+i].valsum;
					pre.route=pre.route+DFST[pos+i].route;
				}
				else{
					aft=DFST[pos+i];
					int si=i+1;
					while(aft.col==DFST[pos+si].col&&si<=k){
						aft.valsum+=DFST[pos+si].valsum;
						aft.route+=DFST[pos+si].route;
						si++;
					}
					i=si-1;
					tmp+=(pre.route*aft.valsum+aft.route*pre.valsum+(pre.route*aft.route)*val[now]);
					pre.route+=aft.route;
					pre.valsum+=aft.valsum;
					c=aft.col;
				}
			}
		}
	}
	else{
		ans=val[now];
		route=1;
	}
}

int main(){
	int n,u,v,c;
	while(scanf("%d",&n)!=EOF){
		tmp=0;
		for(int i=1;i<=n;i++)
		scanf("%d",&val[i]);
		memset(head,-1,sizeof(head));
		tot=0;
		for(int i=1;i<n;i++){
			scanf("%d%d%d",&u,&v,&c);
			addedge(u,v,c);
		}
		LL ans=0,route=0;
		dfs(ans,-1,1,-1,route,0);//sum,parent,nownode,parent_col,route,beginpos
		printf("%I64d\n",ans+tmp);
	}
	return 0;
}

  

posted @ 2014-10-24 18:49  chenjunjie1994  阅读(153)  评论(0编辑  收藏  举报