洛谷 P2664 树上游戏
lrb有一棵树,树的每个节点有个颜色。给一个长度为n的颜色序列,定义\(s(i,j)\)为\(i\)到\(j\)的颜色数量。以及
\[sum_i = \sum_{j=1}^{n}s(i,j)
\]
现在他想让你求出所有的\(sum[i]\)
这题真是难,点分治神题
我们考虑一个性质,对于一个点\(i\),如果它的颜色在到根的路径中是第一次出现,那么对于和\(i\)不在一个子树的点\(j\),对\(j\)都有\(i\)的子树大小\(size_i\)的贡献
然后有了这个性质,就好做了
找完重心后预处理出来实际的\(size\),用\(sum\)来记录所有点的贡献,\(s\)是这个颜色的贡献
而我们不是用点去更新答案,是用颜色来更新答案,所以要枚举子树,去掉这个子树的贡献来统计答案
于是再有\(X\)表示除了这个子树的点数和,\(co\)表示这个点到根的颜色数
然后记录下这个点到根的所有颜色的\(s\)的和,\(s\)是要被减去的
那么\(ans+=sum-s+co\times X\),然后单独更新一下根就是\(ans+=sum-s_{c_{rt}}+size_{rt}\)
Code
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
const int N = 1e5;
using namespace std;
int n,c[N + 5],size[N + 5],maxp[N + 5],rt,su,vis[N + 5],cnt[N + 5];
long long sum,s[N + 5],ros,X,ans[N + 5];
vector <int> d[N + 5];
void get_rt(int u,int fa)
{
size[u] = 1;
maxp[u] = 0;
vector <int>::iterator it;
for (it = d[u].begin();it != d[u].end();it++)
{
int v = (*it);
if (v == fa || vis[v])
continue;
get_rt(v,u);
size[u] += size[v];
maxp[u] = max(maxp[u],size[v]);
}
maxp[u] = max(maxp[u],su - size[u]);
if (maxp[u] < maxp[rt])
rt = u;
}
void get_size(int u,int fa)
{
size[u] = 1;
vector <int>::iterator it;
for (it = d[u].begin();it != d[u].end();it++)
{
int v = (*it);
if (v == fa || vis[v])
continue;
get_size(v,u);
size[u] += size[v];
}
}
void dfs(int u,int fa,int w)
{
cnt[c[u]]++;
if (cnt[c[u]] == 1)
{
s[c[u]] += w * size[u];
sum += w * size[u];
}
if (!cnt[c[rt]])
ros += w;
vector <int>::iterator it;
for (it = d[u].begin();it != d[u].end();it++)
{
int v = (*it);
if (v == fa || vis[v])
continue;
dfs(v,u,w);
}
cnt[c[u]]--;
}
void upd(int u,int fa,int co,int su)
{
cnt[c[u]]++;
if (cnt[c[u]] == 1)
{
co++;
su += s[c[u]];
}
ans[u] += sum - su + co * X;
if (!cnt[c[rt]])
ans[u] += ros;
vector <int>::iterator it;
for (it = d[u].begin();it != d[u].end();it++)
{
int v = (*it);
if (v == fa || vis[v])
continue;
upd(v,u,co,su);
}
cnt[c[u]]--;
}
void calc(int u)
{
vector <int>::iterator it;
for (it = d[u].begin();it != d[u].end();it++)
{
int v = (*it);
if (vis[v])
continue;
dfs(v,u,1);
}
for (it = d[u].begin();it != d[u].end();it++)
{
int v = (*it);
if (vis[v])
continue;
dfs(v,u,-1);
X = size[u] - size[v];
upd(v,0,0,0);
dfs(v,u,1);
}
ans[u] += sum - s[c[u]] + size[u];
for (it = d[u].begin();it != d[u].end();it++)
{
int v = (*it);
if (vis[v])
continue;
dfs(v,u,-1);
}
}
void solve(int u)
{
vis[u] = 1;
ros = 1;
get_size(u,0);
calc(u);
vector <int>::iterator it;
for (it = d[u].begin();it != d[u].end();it++)
{
int v = (*it);
if (vis[v])
continue;
maxp[0] = N;
su = size[v];
rt = 0;
get_rt(v,0);
solve(rt);
}
}
int main()
{
scanf("%d",&n);
for (int i = 1;i <= n;i++)
scanf("%d",&c[i]);
int u,v;
for (int i = 1;i < n;i++)
{
scanf("%d%d",&u,&v);
d[u].push_back(v);
d[v].push_back(u);
}
su = n;
maxp[0] = N;
get_rt(1,0);
get_size(rt,0);
solve(rt);
for (int i = 1;i <= n;i++)
printf("%lld\n",ans[i]);
return 0;
}