P2664 树上游戏

路径数颜色,好耶!


xc 有一棵树,树的每个节点有个颜色。给一个长度为 \(n\) 的颜色序列,定义 \(s(i,j)\)\(i\)\(j\) 的颜色数量。以及

\[sum_i=\sum_{j=1}^n s(i, j) \]

现在他想让你求出所有的 \(sum_i\)


如果只要我们求一个点的,这就是 nt 题了。可惜求不得。

处理路径的方式,还要处理所有路径,只有倍增和点分治两种了。这题倍增处理不了几种颜色相加的情况,只能点分治。

点分治就能处理相加的情况了吗?不能。但点分治可以换一种方式规避掉相加。就是算贡献。

对于当前根的贡献自然是好算的,递归时顺便统计就好了。那么点→根的路径上的颜色也可以求出来。

贡献是多少呢?对于每一个点,经过另一个子树的点的每条路径贡献 \(1\) 的颜色总数,所以总贡献就是路径的条数。“该点→根节点”的路径已经固定了,所以我们的路径条数就是另一个点的子树大小。注意子树中不能重复统计该颜色,因为已经统计过了。

这是一组样例:

颜色分别是:

1 4 3 3 3 1 2 2 5

\(8\) 不会在作为 \(7\) 的儿子时统计到贡献,因为统计到 \(8\) 的时候 \(7\) 肯定会被统计到,而他们俩颜色相同,只用算一个,算了 \(7\) 就不用算 \(8\) 了。

注意我们统计的贡献都是根到另一子树的贡献,当前点到根的颜色数我们在统计根答案时已经顺便算出来了。还要加上。

inline void dfs(ll u, ll fa, ll rt)//num表示每种颜色的贡献,z表示贡献总和,d表示当前节点到根节点颜色数,ton表示当前路径颜色出现次数。
{
    si[u] = 1;
    ton[col[u]] ++;
    if(ton[col[u]] == 1) p ++;//当前子树这种颜色的贡献算完了,下面的同种颜色都不用算贡献了
    d[u] = p;
    ans[u] += p;//和根的路径inline void deep(ll u, ll fa, ll p)
{
    ll pd = 0;
    if(ton[col[u]] == 0) 
    {
        p += num[col[u]];
        ton[col[u]] ++;
        pd = 1;
    }
    dep[u] = p;
    for(auto v : e[u])
    {
        if(vis[v] == 1 || v == fa) continue;
        deep(v, u, p);
    }
    if(pd == 1) ton[col[u]] --;
}
    if(u != rt) ans[rt] += p;//根自己的答案
    for(auto v : e[u])
    {
        if(vis[v] == 1 || v == fa) continue;
        dfs(v, u, rt);
        si[u] += si[v];
    }
    if(ton[col[u]] == 1) 
    {
        num[col[u]] += si[u];
        z += si[u];
    }
    ton[col[u]] --;//避免使用memset清空
    if(ton[col[u]] == 0) p --;
}

把所有贡献加起来,再减去节点所在子树的贡献,有多少条路径就加多少次当前点→根节点上的颜色,就是当前节点的答案了。。。。。吗?

如果当前点→根的路径上的颜色在其他子树上又被算了贡献呢?这不就算重了吗?得减得减。于是我们预处理出当前点到根颜色的贡献,把最终答案减去这一部分。

inline void deep(ll u, ll fa, ll p)//dep即为当前点到根颜色的贡献
{
    ll pd = 0;
    if(ton[col[u]] == 0) 
    {
        p += num[col[u]];
        ton[col[u]] ++;
        pd = 1;
    }
    dep[u] = p;
    for(auto v : e[u])
    {
        if(vis[v] == 1 || v == fa) continue;
        deep(v, u, p);
    }
    if(pd == 1) ton[col[u]] --;
}

这就完了吗?我们减了一遍当前子树的颜色贡献,我们又减了一遍到根路径的所有颜色贡献。那么到根路径的当前子树颜色贡献不就被减了两次了吗?得加得加。

inline void pa(ll u, ll fa)//w为当前子树贡献和,val为当前子树每种颜色的贡献,now中存储当前子树节点,bye存储到根路径的当前子树颜色贡献
{
    now[++ cnt] = u;
    ton[col[u]] ++;
    if(ton[col[u]] == 1) w += si[u], val[col[u]] += si[u];
    for(auto v : e[u])
    {
        if(vis[v] == 1 || v == fa) continue;
        pa(v, u);
    }
    ton[col[u]] --;
}
inline void lolo(ll u, ll fa, ll p)
{
    ll pd = 0;
    if(ton[col[u]] == 0) p += val[col[u]], ton[col[u]] ++, pd = 1;
    bye[u] = p;
    for(auto v : e[u])
    {
        if(vis[v] == 1 || v == fa) continue;
        lolo(v, u, p);
    }
    if(pd == 1) ton[col[u]] --;
}

题外话:这么绕的一个东西,我说了一句算贡献来帮我卡常的大佬就全懂了,恐怖如斯。

然后就做完了!

#include <bits/stdc++.h>
#define ll long long
#define forp(i, a, b) for (ll i = (a); i <= (b); i++)
#define forc(i, a, b) for (ll i = (a); i <= (b); i++)
#define pii pair<ll, ll>
#define mp make_pair
using namespace std;
const ll maxn = 2e5 + 6;
ll n, m;
ll col[maxn], ans[maxn];
vector<ll> e[maxn];
ll mx[maxn], si[maxn], rt, Tot;
bool vis[maxn];
inline void get_rt(ll u, ll fa)
{
    si[u]= mx[u] = 1;
    for(auto v : e[u])
    {
        if(vis[v] == 1 || v == fa) continue;
        get_rt(v, u);
        si[u] += si[v]; mx[u] = max(mx[u], si[v]);
    }
    mx[u] = max(mx[u], Tot - si[u]);
    if(mx[u] <= mx[rt]) rt = u;
}
ll ton[maxn], p, num[maxn], z, d[maxn];
inline void dfs(ll u, ll fa, ll rt)
{
    si[u] = 1;
    ton[col[u]] ++;
    if(ton[col[u]] == 1) p ++;
    d[u] = p;
    ans[u] += p;
    if(u != rt) ans[rt] += p;
    for(auto v : e[u])
    {
        if(vis[v] == 1 || v == fa) continue;
        dfs(v, u, rt);
        si[u] += si[v];
    }
    if(ton[col[u]] == 1) 
    {
        num[col[u]] += si[u];
        z += si[u];
    }
    ton[col[u]] --;
    if(ton[col[u]] == 0) p --;
}
ll dep[maxn];
inline void deep(ll u, ll fa, ll p)
{
    ll pd = 0;
    if(ton[col[u]] == 0) 
    {
        p += num[col[u]];
        ton[col[u]] ++;
        pd = 1;
    }
    dep[u] = p;
    for(auto v : e[u])
    {
        if(vis[v] == 1 || v == fa) continue;
        deep(v, u, p);
    }
    if(pd == 1) ton[col[u]] --;
}
ll now[maxn], cnt, w, val[maxn], bye[maxn];
inline void pa(ll u, ll fa)
{
    now[++ cnt] = u;
    ton[col[u]] ++;
    if(ton[col[u]] == 1) w += si[u], val[col[u]] += si[u];
    for(auto v : e[u])
    {
        if(vis[v] == 1 || v == fa) continue;
        pa(v, u);
    }
    ton[col[u]] --;
}
inline void lolo(ll u, ll fa, ll p)
{
    ll pd = 0;
    if(ton[col[u]] == 0) p += val[col[u]], ton[col[u]] ++, pd = 1;
    bye[u] = p;
    for(auto v : e[u])
    {
        if(vis[v] == 1 || v == fa) continue;
        lolo(v, u, p);
    }
    if(pd == 1) ton[col[u]] --;
}
inline void del(ll u, ll fa)
{
    val[col[u]] = bye[u] = num[col[u]] = dep[u] = 0;
    for(auto v : e[u])
    {
        if(vis[v] == 1 || v == fa) continue;
        del(v, u);
    }
}
inline void calc(ll u)
{
    si[u] = 1;num[col[u]] = si[u];ton[col[u]] = 1;// 把根节点加进去
    ans[u] += 1;//自己跟自己
    for(auto v : e[u])
    {
        if(vis[v] == 1) continue;
        p = 1;
        dfs(v, u, u);
        si[u] += si[v];
    }
    z += si[u];
    for(auto v : e[u])
    {
        if(vis[v] == 1) continue;
        deep(v, u, si[u]);
    }
    for(auto v : e[u])
    {
        if(vis[v] == 1) continue;
        cnt = w = 0;
        pa(v, u);
        lolo(v, u, 0);
        for(ll j = 1;j <= cnt;j ++) ans[now[j]] += ((z - w) - (dep[now[j]] - bye[now[j]])) + d[now[j]] * (si[u] - si[v] - 1);//全颜色贡献-全颜色子树贡献 - (根节点路径颜色贡献 - 根节点路径颜色子树贡献)
        del(v, u);
    }
    z = 0;
    ton[col[u]] = 0;num[col[u]] = 0;
}
void get_si(int u, int fa)
{
    si[u] = 1;
    for(auto v : e[u])
    {
        if(v == fa || vis[v] == 1) continue;
        get_si(v, u);
        si[u] += si[v];
    }
}
inline void solve(ll u)
{
    vis[u] = 1;
    calc(u);
    for(auto v : e[u])
    {
        if(vis[v]) continue;
        get_si(v, u);
        Tot = si[v];
        rt = 0;get_rt(v, u);
        solve(rt);
    }
}
signed main()
{
    freopen("text.in", "r", stdin);
    // freopen("text.out", "w[maxn]", stdout);
    ios::sync_with_stdio(false);
    cin.tie(0);cout.tie(0);
    mx[0] = 1919810;
    cin >> n;
    for(ll i = 1;i <= n;i ++) cin >> col[i];
    ll u, v;
    for(ll i = 1;i < n;i ++)
    {
        cin >> u >> v;
        e[u].push_back(v);
        e[v].push_back(u);
    }
    Tot = n;
    get_rt(1, -1);
    solve(rt);
    for(ll i = 1;i <= n;i ++) cout << ans[i] << endl;
    return 0;
}

如果你 TLE 了,请检查以下问题:

  1. long long 开了没?答案可能很大。

  2. 重心找对了没?找错了点分治的复杂度就不存在了。

posted @ 2023-01-11 20:51  _maze  阅读(51)  评论(0编辑  收藏  举报