洛谷 P2664 树上游戏

lrb有一棵树,树的每个节点有个颜色。给一个长度为n的颜色序列,定义\(s(i,j)\)\(i\)\(j\)的颜色数量。以及

\[sum_i = \sum_{j=1}^{n}s(i,j) \]

现在他想让你求出所有的\(sum[i]\)

这题真是难,点分治神题

我们考虑一个性质,对于一个点\(i\),如果它的颜色在到根的路径中是第一次出现,那么对于和\(i\)不在一个子树的点\(j\),对\(j\)都有\(i\)的子树大小\(size_i\)的贡献

然后有了这个性质,就好做了

找完重心后预处理出来实际的\(size\),用\(sum\)来记录所有点的贡献,\(s\)是这个颜色的贡献

而我们不是用点去更新答案,是用颜色来更新答案,所以要枚举子树,去掉这个子树的贡献来统计答案

于是再有\(X\)表示除了这个子树的点数和,\(co\)表示这个点到根的颜色数

然后记录下这个点到根的所有颜色的\(s\)的和,\(s\)是要被减去的

那么\(ans+=sum-s+co\times X\),然后单独更新一下根就是\(ans+=sum-s_{c_{rt}}+size_{rt}\)

Code

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
const int N = 1e5;
using namespace std;
int n,c[N + 5],size[N + 5],maxp[N + 5],rt,su,vis[N + 5],cnt[N + 5];
long long sum,s[N + 5],ros,X,ans[N + 5];
vector <int> d[N + 5];
void get_rt(int u,int fa)
{
    size[u] = 1;
    maxp[u] = 0;
    vector <int>::iterator it;
    for (it = d[u].begin();it != d[u].end();it++)
    {
        int v = (*it);
        if (v == fa || vis[v])
            continue;
        get_rt(v,u);
        size[u] += size[v];
        maxp[u] = max(maxp[u],size[v]);
    }
    maxp[u] = max(maxp[u],su - size[u]);
    if (maxp[u] < maxp[rt])
        rt = u;
}
void get_size(int u,int fa)
{
    size[u] = 1;
    vector <int>::iterator it;
    for (it = d[u].begin();it != d[u].end();it++)
    {
        int v = (*it);
        if (v == fa || vis[v])  
            continue;
        get_size(v,u);
        size[u] += size[v];
    }
}
void dfs(int u,int fa,int w)
{
    cnt[c[u]]++;
    if (cnt[c[u]] == 1)
    {
        s[c[u]] += w * size[u];
        sum += w * size[u];
    }
    if (!cnt[c[rt]])
        ros += w;
    vector <int>::iterator it;
    for (it = d[u].begin();it != d[u].end();it++)
    {
        int v = (*it);
        if (v == fa || vis[v])
            continue;
        dfs(v,u,w);
    }
    cnt[c[u]]--;
}
void upd(int u,int fa,int co,int su)
{
    cnt[c[u]]++;
    if (cnt[c[u]] == 1)
    {
        co++;
        su += s[c[u]];
    }
    ans[u] += sum - su + co * X;
    if (!cnt[c[rt]])
        ans[u] += ros;
    vector <int>::iterator it;
    for (it = d[u].begin();it != d[u].end();it++)
    {
        int v = (*it);
        if (v == fa || vis[v])  
            continue;
        upd(v,u,co,su);
    }
    cnt[c[u]]--;
}
void calc(int u)
{
    vector <int>::iterator it;
    for (it = d[u].begin();it != d[u].end();it++)
    {
        int v = (*it);
        if (vis[v])
            continue;
        dfs(v,u,1);
    }
    for (it = d[u].begin();it != d[u].end();it++)
    {
        int v = (*it);
        if (vis[v])
            continue;
        dfs(v,u,-1);
        X = size[u] - size[v];
        upd(v,0,0,0);
        dfs(v,u,1);
    }
    ans[u] += sum - s[c[u]] + size[u];
    for (it = d[u].begin();it != d[u].end();it++)
    {
        int v = (*it);
        if (vis[v])
            continue;
        dfs(v,u,-1);
    }
}
void solve(int u)
{
    vis[u] = 1;
    ros = 1;
    get_size(u,0);
    calc(u);
    vector <int>::iterator it;
    for (it = d[u].begin();it != d[u].end();it++)
    {
        int v = (*it);
        if (vis[v])
            continue;
        maxp[0] = N;
        su = size[v];
        rt = 0;
        get_rt(v,0);
        solve(rt);
    }
}
int main()
{
    scanf("%d",&n);
    for (int i = 1;i <= n;i++)
        scanf("%d",&c[i]);
    int u,v;
    for (int i = 1;i < n;i++)
    {
        scanf("%d%d",&u,&v);
        d[u].push_back(v);
        d[v].push_back(u);
    }
    su = n;
    maxp[0] = N;
    get_rt(1,0);
    get_size(rt,0);
    solve(rt);
    for (int i = 1;i <= n;i++)
        printf("%lld\n",ans[i]);
    return 0;
}
posted @ 2020-06-08 20:23  eee_hoho  阅读(102)  评论(0编辑  收藏  举报