luogu P2664 树上游戏
考虑点分治。
那么现在问题就是如何快速求出跨过分治中心的点对之间的贡献。
我们考虑分治中心到叶节点路径上某种颜色的第一个节点,显然这个点的子树的每一个节点因为该种颜色产生的贡献都为\(1\),我们用\(color[i]\)记录第\(i\)种颜色以此方法产生的贡献。并记\(sum=\sum color[i]\)。
现在考虑如何累加答案。对于每一个点\(x\),我们把以它作为一个端点产生的贡献分为两种:
- 分治中心到它(不包括分治中心)的所有颜色产生的贡献。
- 与它不在分治中心同一棵子树的点产生的贡献。
第一点非常好求,记分治中心到它(不包括分治中心)共有\(p\)种颜色,那么这些颜色产生的贡献就是\(siz[v]+1\),其中\(root\)为分治中心,\(v\)为\(x\)的祖先且为\(root\)的儿子。
第二点就要用上刚刚的铺垫了。我们考虑\(sum\)多加了哪些贡献:对于\(p\)里面的所有颜色,\(color[i]\)已经不能产生贡献,另外在跑\(v\)这棵子树之前也应该先把这棵子树对\(color[i]\)的贡献全部删除,跑完后再加进来。
还要注意一些细节(分治中心的对答案的影响有些不同之处)。
代码:
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long LL;
const int N=100009;
int n,head[N],cnt,point[N],del[N],siz[N],num,now;
LL col[N],Cnt[N],sum,ans[N],fuck,QWQ,C;
struct Edge
{
int nxt,to,w;
}g[N*2];
void add(int from,int to)
{
g[++cnt].nxt=head[from];
g[cnt].to=to;
head[from]=cnt;
}
void init()
{
scanf("%d",&n);
for (int i=1;i<=n;i++)
scanf("%d",&point[i]);
for (int i=1;i<n;i++)
{
int x,y;
scanf("%d %d",&x,&y);
add(x,y),add(y,x);
}
}
void DFS(int x,int fa)
{
siz[x]=1;
for (int i=head[x];i;i=g[i].nxt)
{
int v=g[i].to;
if(v==fa||del[v])
continue;
DFS(v,x);
siz[x]+=siz[v];
}
}
int Get_Weight(int x)
{
DFS(x,-1);
int k=siz[x]/2,fa=-1;
while(1)
{
int tmp=0;
for (int i=head[x];i;i=g[i].nxt)
{
int v=g[i].to;
if(v==fa||del[v])
continue;
if(siz[tmp]<siz[v])
tmp=v;
}
if(siz[tmp]<=k)
return x;
fa=x,x=tmp;
}
}
void dfs_1(int x,int fa)
{
siz[x]=1,Cnt[point[x]]++;
for (int i=head[x];i;i=g[i].nxt)
{
int v=g[i].to;
if(v==fa||del[v])
continue;
dfs_1(v,x);
siz[x]+=siz[v];
}
if(Cnt[point[x]]==1)
col[point[x]]+=siz[x],sum+=siz[x];
Cnt[point[x]]--;
}
void Modify(int x,int fa,int type)
{
Cnt[point[x]]++;
for (int i=head[x];i;i=g[i].nxt)
{
int v=g[i].to;
if(v==fa||del[v])
continue;
Modify(v,x,type);
}
if(Cnt[point[x]]==1)
col[point[x]]+=type*siz[x],sum+=type*siz[x];
Cnt[point[x]]--;
}
void calc(int x,int fa)
{
Cnt[point[x]]++;
if(Cnt[point[x]]==1)
num++,fuck+=col[point[x]];
ans[x]+=sum-fuck+1LL*num*now-(Cnt[C]?0:QWQ);
for (int i=head[x];i;i=g[i].nxt)
{
int v=g[i].to;
if(v==fa||del[v])
continue;
calc(v,x);
}
if(Cnt[point[x]]==1)
num--,fuck-=col[point[x]];
Cnt[point[x]]--;
}
void Get_Ans(int x)
{
dfs_1(x,-1);
ans[x]+=sum,C=point[x];
for (int i=head[x];i;i=g[i].nxt)
{
int v=g[i].to;
if(del[v])
continue;
Cnt[point[x]]=1,Modify(v,x,-1),Cnt[point[x]]=0;
QWQ=siz[v],now=siz[x]-siz[v],calc(v,x);
Cnt[point[x]]=1,Modify(v,x,1),Cnt[point[x]]=0;
}
num=0;
Modify(x,-1,-1);
//for (int i=1;i<=n;i++)
//printf("%d ",sum);puts("");
}
void conquer(int x)
{
int w=Get_Weight(x);
del[w]=1;
Get_Ans(w);
for (int i=head[w];i;i=g[i].nxt)
{
int v=g[i].to;
if(del[v])
continue;
conquer(v);
}
}
void work()
{
conquer(1);
for (int i=1;i<=n;i++)
printf("%lld\n",ans[i]);
}
int main()
{
init();
work();
return 0;
}
由于博主比较菜,所以有很多东西待学习,大部分文章会持续更新,另外如果有出错或者不周之处,欢迎大家在评论中指出!