- 一道点分难题
- 首先很自然的想法就是每种颜色的贡献可以分开计算,然后如果你会虚树就可以直接做了
- 点分也差不多,考虑每个分治重心的子树对它的贡献以及它对它子树的贡献
- 首先,处理一个\(cnt\)数组,\(cnt[i]\)表示从重心出发有多少条包含i颜色的路径,具体做法就是dfs,当该颜色第一次出现时就加上当前子树的size,还要记录子树中出现了哪几种颜色,不能每次都枚举所有颜色,显然,对分治重心的贡献就是\(\sum cnt[i]\).
- 接下来计算分治重心对子树内的贡献,比较麻烦,首先对每颗子树求出ct,定义与\(cnt\)一样,每次令\(cnt[cl[i]]=cnt[cl[i]]-ct[cl[i]]\)就是除该子树以外的有颜色i的路径数,还要特别地把根的颜色在该子树内出现次数减去,也就是减去该子树的size,设分治重心除该子树以外的的点数为\(path\),\(tot\)为现在\(cnt\)数组的和,递归该子树,对于每个节点要加上tot,如果该节点颜色在递归中第一次出现,则产生贡献\(path-cnt[cl]\),同时该贡献还会影响其子树,所以将该贡献在递归是下传即可,最后清空数组,继续分治即可
#include<bits/stdc++.h>
using namespace std;
typedef int sign;
typedef long long ll;
#define For(i,a,b) for(register sign i=(sign)a;i<=(sign)b;++i)
#define Fordown(i,a,b) for(register sign i=(sign)a;i>=(sign)b;--i)
const int N=1e5+5;
bool cmax(sign &a,sign b){return (a<b)?a=b,1:0;}
bool cmin(sign &a,sign b){return (a>b)?a=b,1:0;}
template<typename T>T read()
{
T ans=0,f=1;
char ch=getchar();
while(!isdigit(ch)&&ch!='-')ch=getchar();
if(ch=='-')f=-1,ch=getchar();
while(isdigit(ch))ans=(ans<<3)+(ans<<1)+(ch-'0'),ch=getchar();
return ans*f;
}
template<typename T>void write(T x,char y)
{
if(x==0)
{
putchar('0');putchar(y);
return;
}
if(x<0)
{
putchar('-');
x=-x;
}
static char wr[20];
int top=0;
for(;x;x/=10)wr[++top]=x%10+'0';
while(top)putchar(wr[top--]);
putchar(y);
}
void file()
{
#ifndef ONLINE_JUDGE
freopen("2664.in","r",stdin);
freopen("2664.out","w",stdout);
#endif
}
int n,c[N];
int head[N],tt,nex[N<<1],to[N<<1];
void add(int x,int y)
{
++tt,to[tt]=y,nex[tt]=head[x],head[x]=tt;
}
void input()
{
int x,y;
n=read<int>();
For(i,1,n)c[i]=read<int>();
For(i,2,n)
{
x=read<int>(),y=read<int>();
add(x,y),add(y,x);
}
}
const int inf=0x3f3f3f3f;
int sum,size[N],root,min_sz;
bool ban[N];
#define rg register
void get_root(int u,int pre)
{
int mx=0;
for(rg int i=head[u];i;i=nex[i])
{
if(to[i]==pre||ban[to[i]])continue;
get_root(to[i],u);
cmax(mx,size[to[i]]);
}
cmax(mx,sum-size[u]);
if(mx<min_sz)min_sz=mx,root=u;
}
void get_sz(int u,int pre)
{
size[u]=1;
for(rg int i=head[u];i;i=nex[i])
{
if(to[i]==pre||ban[to[i]])continue;
get_sz(to[i],u);
size[u]+=size[to[i]];
}
}
bool apper[N];
int cl[N],col[N],num[N],top,tp;
ll tot,cnt[N],ct[N],ans[N],cct[N];
void dfs(int u,int pre,ll *cnt)
{
if(!apper[c[u]])col[++top]=c[u],apper[c[u]]=true;
if(++num[c[u]]==1)cnt[c[u]]+=size[u];
for(rg int i=head[u];i;i=nex[i])
{
if(to[i]==pre||ban[to[i]])continue;
dfs(to[i],u,cnt);
}
--num[c[u]];
}
ll path;
void modify(int u,int pre,ll las)
{
ll tag=las;
if(++num[c[u]]==1)tag+=path-cnt[c[u]];
ans[u]+=tag+tot;
for(rg int i=head[u];i;i=nex[i])
{
if(to[i]==pre||ban[to[i]])continue;
modify(to[i],u,tag);
}
--num[c[u]];
}
void cal(int u)
{
get_sz(u,0);
tot=top=0;
dfs(u,0,cnt);
For(i,1,top)apper[col[i]]=false;
tp=top;
For(i,1,top)
{
tot+=cnt[cl[i]=col[i]];
cct[cl[i]]=cnt[cl[i]];
}
/* cout<<tot<<' '<<top<<endl;
For(i,1,n)cout<<cnt[i]<<' ';
puts("");*/
ans[u]+=tot;
ll temp=tot;
for(rg int i=head[u];i;i=nex[i])if(!ban[to[i]])
{
num[c[u]]=1,top=0;
dfs(to[i],u,ct);
num[c[u]]=0;
For(j,1,top)apper[col[j]]=false;
cnt[c[u]]-=size[to[i]];
tot-=size[to[i]];
For(j,1,top)
{
cnt[col[j]]-=ct[col[j]];
tot-=ct[col[j]];
}
path=size[u]-size[to[i]];
modify(to[i],u,0);
cnt[c[u]]+=size[to[i]];
tot=temp;
For(j,1,top)
{
cnt[col[j]]=cct[col[j]];
ct[col[j]]=0;
}
}
For(i,1,tp)cnt[cl[i]]=0;
}
void solve(int u)
{
ban[u]=true;
cal(u);
for(rg int i=head[u];i;i=nex[i])
{
if(ban[to[i]])continue;
get_sz(to[i],u);
min_sz=sum=size[to[i]];
get_root(to[i],u);
solve(root);
}
}
void work()
{
sum=min_sz=n;
get_sz(1,0);
get_root(1,0);
solve(1);
For(i,1,n)write(ans[i],'\n');
}
int main()
{
file();
input();
work();
return 0;
}