题解

首先有一个性质:对于一个点,它的独特点都会分布在(它到整棵树的两个直径端点中较远的那一个端点的路径)上

我们如果以当前点为根,那么下面的那一部分直径就会消除上面的那一部分的直径的部分独特点

注意,如果设转轴为O,那么点O对于点x依然是独特的

当然,直径上还会有一些支链,它们也会消除一部分独特点

为什么不用考虑下面的直径上的支链?

因为下半部分的支链长度是小于等于下半部分的直径长度的(否则就矛盾了)

所以我们可以直接用下半部分的直径来代替它们了

至于如何维护消除操作

我们可以用一个栈和一个桶来简单维护

但是如果我们要计算的点在支链上,且支链上还有支链来影响答案那该怎么办?

这个时候我们就可以借助长链剖分

把父亲长链的栈继承下来

再用递归调用一下solve就可以了

注意,父亲长链的点O必须在每一次递归调用的时候都要存在于栈内

因为有可能会被弹出,这样在下一个儿子就可能统计不到点O

(当然,为了好举例,这里没有画出下面的长链的消除影响)

 

 

代码:

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
inline int gi()
{
	char c;int num=0,flg=1;
	while((c=getchar())<'0'||c>'9')if(c=='-')flg=-1;
	while(c>='0'&&c<='9'){num=num*10+c-48;c=getchar();}
	return num*flg;
}
#define N 200005
int rt;
int fir[N],nxt[2*N],to[2*N],cnt;
int ans[N],dep[N],mxd[N],son[N];
int a[N],b[N],sum,stk[N],top;
void adde(int a,int b){to[++cnt]=b;nxt[cnt]=fir[a];fir[a]=cnt;}
void dfs1(int u,int fa,int &g)
{
	if((dep[u]=dep[fa]+1)>dep[g]) g=u;
	for(int v,p=fir[u];p;p=nxt[p]){
		if((v=to[p])!=fa){
			dep[v]=dep[u]+1;
			dfs1(v,u,g);
		}
	}
}
void dfs2(int u,int fa)
{
	dep[u]=dep[fa]+1;son[u]=mxd[u]=0;
	for(int v,p=fir[u];p;p=nxt[p]){
		if((v=to[p])!=fa){
			dfs2(v,u);
			if(mxd[v]+1>mxd[u]){
				mxd[u]=mxd[v]+1;
				son[u]=v;
			}
		}
	}
}
void Push(int x){stk[++top]=x;if(!b[a[x]]++)sum++;}
void Pop(){if(!--b[a[stk[top--]]])sum--;}
void solve(int u,int fa)
{
	if(son[u]){
		int sec=0;
		for(int v,p=fir[u];p;p=nxt[p])
			if((v=to[p])!=son[u]&&v!=fa)
				sec=max(sec,mxd[v]+1);
		while(top&&dep[stk[top]]>=dep[u]-sec)Pop();
		Push(u);solve(son[u],u);
	}
	while(top&&dep[stk[top]]>=dep[u]-mxd[u])Pop();
	ans[u]=max(ans[u],sum);
	for(int v,p=fir[u];p;p=nxt[p]){
		if((v=to[p])!=son[u]&&v!=fa){
			if(stk[top]!=u)Push(u);
			solve(v,u);
		}
	}
	if(stk[top]==u)Pop();
}
int main()
{
	freopen("star.in","r",stdin);
	freopen("star.out","w",stdout);
	int n,m,i,u,v;
	n=gi();m=gi();
	for(i=1;i<n;i++){u=gi();v=gi();adde(u,v);adde(v,u);}
	for(int i=1;i<=n;i++)a[i]=gi();
	dfs1(1,0,rt);dfs2(rt,0);solve(rt,0);
	dfs1(rt,0,rt);dfs2(rt,0);solve(rt,0);
	for(i=1;i<=n;i++)printf("%d\n",ans[i]);
}

 

再补一个std题解: