luogu P2664 树上游戏

考虑点分治。

那么现在问题就是如何快速求出跨过分治中心的点对之间的贡献。

我们考虑分治中心到叶节点路径上某种颜色的第一个节点,显然这个点的子树的每一个节点因为该种颜色产生的贡献都为\(1\),我们用\(color[i]\)记录第\(i\)种颜色以此方法产生的贡献。并记\(sum=\sum color[i]\)

现在考虑如何累加答案。对于每一个点\(x\),我们把以它作为一个端点产生的贡献分为两种:

  1. 分治中心到它(不包括分治中心)的所有颜色产生的贡献。
  2. 与它不在分治中心同一棵子树的点产生的贡献。

第一点非常好求,记分治中心到它(不包括分治中心)共有\(p\)种颜色,那么这些颜色产生的贡献就是\(siz[v]+1\),其中\(root\)为分治中心,\(v\)\(x\)的祖先且为\(root\)的儿子。

第二点就要用上刚刚的铺垫了。我们考虑\(sum\)多加了哪些贡献:对于\(p\)里面的所有颜色,\(color[i]\)已经不能产生贡献,另外在跑\(v\)这棵子树之前也应该先把这棵子树对\(color[i]\)的贡献全部删除,跑完后再加进来。

还要注意一些细节(分治中心的对答案的影响有些不同之处)。

代码:

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>

using namespace std;

typedef long long LL;
const int N=100009;
int n,head[N],cnt,point[N],del[N],siz[N],num,now;
LL col[N],Cnt[N],sum,ans[N],fuck,QWQ,C;
struct Edge
{
	int nxt,to,w;
}g[N*2];

void add(int from,int to)
{
	g[++cnt].nxt=head[from];
	g[cnt].to=to;
	head[from]=cnt;
}

void init()
{
	scanf("%d",&n);
	for (int i=1;i<=n;i++)
		scanf("%d",&point[i]);
	for (int i=1;i<n;i++)
	{
		int x,y;
		scanf("%d %d",&x,&y);
		add(x,y),add(y,x);
	}
}

void DFS(int x,int fa)
{
	siz[x]=1;
	for (int i=head[x];i;i=g[i].nxt)
	{
		int v=g[i].to;
		if(v==fa||del[v])
			continue;
		DFS(v,x);
		siz[x]+=siz[v];
	}
}

int Get_Weight(int x)
{
	DFS(x,-1);
	int k=siz[x]/2,fa=-1;
	while(1)
	{
		int tmp=0;
		for (int i=head[x];i;i=g[i].nxt)
		{
			int v=g[i].to;
			if(v==fa||del[v])
				continue;
			if(siz[tmp]<siz[v])
				tmp=v;
		}
		if(siz[tmp]<=k)
			return x;
		fa=x,x=tmp;
	}
}

void dfs_1(int x,int fa)
{
	siz[x]=1,Cnt[point[x]]++;
	for (int i=head[x];i;i=g[i].nxt)
	{
		int v=g[i].to;
		if(v==fa||del[v])
			continue;
		dfs_1(v,x);
		siz[x]+=siz[v];
	}
	if(Cnt[point[x]]==1)
		col[point[x]]+=siz[x],sum+=siz[x];
	Cnt[point[x]]--;
}

void Modify(int x,int fa,int type)
{
	Cnt[point[x]]++;
	for (int i=head[x];i;i=g[i].nxt)
	{
		int v=g[i].to;
		if(v==fa||del[v])
			continue;
		Modify(v,x,type);
	}
	if(Cnt[point[x]]==1)
		col[point[x]]+=type*siz[x],sum+=type*siz[x];
	Cnt[point[x]]--;
}

void calc(int x,int fa)
{
	Cnt[point[x]]++;
	if(Cnt[point[x]]==1)
		num++,fuck+=col[point[x]];
	ans[x]+=sum-fuck+1LL*num*now-(Cnt[C]?0:QWQ);
	for (int i=head[x];i;i=g[i].nxt)
	{
		int v=g[i].to;
		if(v==fa||del[v])
			continue;
		calc(v,x);
	}
	if(Cnt[point[x]]==1)
		num--,fuck-=col[point[x]];
	Cnt[point[x]]--;
}

void Get_Ans(int x)
{
	dfs_1(x,-1);
	ans[x]+=sum,C=point[x];
	for (int i=head[x];i;i=g[i].nxt)
	{
		int v=g[i].to;
		if(del[v])
			continue;
		Cnt[point[x]]=1,Modify(v,x,-1),Cnt[point[x]]=0;
		QWQ=siz[v],now=siz[x]-siz[v],calc(v,x);
		Cnt[point[x]]=1,Modify(v,x,1),Cnt[point[x]]=0;
	}
	num=0;
	Modify(x,-1,-1);
	//for (int i=1;i<=n;i++)
		//printf("%d ",sum);puts("");
}

void conquer(int x)
{
	int w=Get_Weight(x);
	del[w]=1;
	Get_Ans(w);
	for (int i=head[w];i;i=g[i].nxt)
	{
		int v=g[i].to;
		if(del[v])
			continue;
		conquer(v);
	}
}

void work()
{
	conquer(1);
	for (int i=1;i<=n;i++)
		printf("%lld\n",ans[i]);
}

int main()
{
	init();
	work();
	return 0;
}
posted @ 2020-04-26 06:43  With_penguin  阅读(95)  评论(0编辑  收藏  举报