P2664 树上游戏
路径数颜色,好耶!
xc 有一棵树,树的每个节点有个颜色。给一个长度为 \(n\) 的颜色序列,定义 \(s(i,j)\) 为 \(i\) 到 \(j\) 的颜色数量。以及
现在他想让你求出所有的 \(sum_i\)。
如果只要我们求一个点的,这就是 nt 题了。可惜求不得。
处理路径的方式,还要处理所有路径,只有倍增和点分治两种了。这题倍增处理不了几种颜色相加的情况,只能点分治。
点分治就能处理相加的情况了吗?不能。但点分治可以换一种方式规避掉相加。就是算贡献。
对于当前根的贡献自然是好算的,递归时顺便统计就好了。那么点→根的路径上的颜色也可以求出来。
贡献是多少呢?对于每一个点,经过另一个子树的点的每条路径贡献 \(1\) 的颜色总数,所以总贡献就是路径的条数。“该点→根节点”的路径已经固定了,所以我们的路径条数就是另一个点的子树大小。注意子树中不能重复统计该颜色,因为已经统计过了。
这是一组样例:
颜色分别是:
1 4 3 3 3 1 2 2 5
\(8\) 不会在作为 \(7\) 的儿子时统计到贡献,因为统计到 \(8\) 的时候 \(7\) 肯定会被统计到,而他们俩颜色相同,只用算一个,算了 \(7\) 就不用算 \(8\) 了。
注意我们统计的贡献都是根到另一子树的贡献,当前点到根的颜色数我们在统计根答案时已经顺便算出来了。还要加上。
inline void dfs(ll u, ll fa, ll rt)//num表示每种颜色的贡献,z表示贡献总和,d表示当前节点到根节点颜色数,ton表示当前路径颜色出现次数。
{
si[u] = 1;
ton[col[u]] ++;
if(ton[col[u]] == 1) p ++;//当前子树这种颜色的贡献算完了,下面的同种颜色都不用算贡献了
d[u] = p;
ans[u] += p;//和根的路径inline void deep(ll u, ll fa, ll p)
{
ll pd = 0;
if(ton[col[u]] == 0)
{
p += num[col[u]];
ton[col[u]] ++;
pd = 1;
}
dep[u] = p;
for(auto v : e[u])
{
if(vis[v] == 1 || v == fa) continue;
deep(v, u, p);
}
if(pd == 1) ton[col[u]] --;
}
if(u != rt) ans[rt] += p;//根自己的答案
for(auto v : e[u])
{
if(vis[v] == 1 || v == fa) continue;
dfs(v, u, rt);
si[u] += si[v];
}
if(ton[col[u]] == 1)
{
num[col[u]] += si[u];
z += si[u];
}
ton[col[u]] --;//避免使用memset清空
if(ton[col[u]] == 0) p --;
}
把所有贡献加起来,再减去节点所在子树的贡献,有多少条路径就加多少次当前点→根节点上的颜色,就是当前节点的答案了。。。。。吗?
如果当前点→根的路径上的颜色在其他子树上又被算了贡献呢?这不就算重了吗?得减得减。于是我们预处理出当前点到根颜色的贡献,把最终答案减去这一部分。
inline void deep(ll u, ll fa, ll p)//dep即为当前点到根颜色的贡献
{
ll pd = 0;
if(ton[col[u]] == 0)
{
p += num[col[u]];
ton[col[u]] ++;
pd = 1;
}
dep[u] = p;
for(auto v : e[u])
{
if(vis[v] == 1 || v == fa) continue;
deep(v, u, p);
}
if(pd == 1) ton[col[u]] --;
}
这就完了吗?我们减了一遍当前子树的颜色贡献,我们又减了一遍到根路径的所有颜色贡献。那么到根路径的当前子树颜色贡献不就被减了两次了吗?得加得加。
inline void pa(ll u, ll fa)//w为当前子树贡献和,val为当前子树每种颜色的贡献,now中存储当前子树节点,bye存储到根路径的当前子树颜色贡献
{
now[++ cnt] = u;
ton[col[u]] ++;
if(ton[col[u]] == 1) w += si[u], val[col[u]] += si[u];
for(auto v : e[u])
{
if(vis[v] == 1 || v == fa) continue;
pa(v, u);
}
ton[col[u]] --;
}
inline void lolo(ll u, ll fa, ll p)
{
ll pd = 0;
if(ton[col[u]] == 0) p += val[col[u]], ton[col[u]] ++, pd = 1;
bye[u] = p;
for(auto v : e[u])
{
if(vis[v] == 1 || v == fa) continue;
lolo(v, u, p);
}
if(pd == 1) ton[col[u]] --;
}
题外话:这么绕的一个东西,我说了一句算贡献来帮我卡常的大佬就全懂了,恐怖如斯。
然后就做完了!
#include <bits/stdc++.h>
#define ll long long
#define forp(i, a, b) for (ll i = (a); i <= (b); i++)
#define forc(i, a, b) for (ll i = (a); i <= (b); i++)
#define pii pair<ll, ll>
#define mp make_pair
using namespace std;
const ll maxn = 2e5 + 6;
ll n, m;
ll col[maxn], ans[maxn];
vector<ll> e[maxn];
ll mx[maxn], si[maxn], rt, Tot;
bool vis[maxn];
inline void get_rt(ll u, ll fa)
{
si[u]= mx[u] = 1;
for(auto v : e[u])
{
if(vis[v] == 1 || v == fa) continue;
get_rt(v, u);
si[u] += si[v]; mx[u] = max(mx[u], si[v]);
}
mx[u] = max(mx[u], Tot - si[u]);
if(mx[u] <= mx[rt]) rt = u;
}
ll ton[maxn], p, num[maxn], z, d[maxn];
inline void dfs(ll u, ll fa, ll rt)
{
si[u] = 1;
ton[col[u]] ++;
if(ton[col[u]] == 1) p ++;
d[u] = p;
ans[u] += p;
if(u != rt) ans[rt] += p;
for(auto v : e[u])
{
if(vis[v] == 1 || v == fa) continue;
dfs(v, u, rt);
si[u] += si[v];
}
if(ton[col[u]] == 1)
{
num[col[u]] += si[u];
z += si[u];
}
ton[col[u]] --;
if(ton[col[u]] == 0) p --;
}
ll dep[maxn];
inline void deep(ll u, ll fa, ll p)
{
ll pd = 0;
if(ton[col[u]] == 0)
{
p += num[col[u]];
ton[col[u]] ++;
pd = 1;
}
dep[u] = p;
for(auto v : e[u])
{
if(vis[v] == 1 || v == fa) continue;
deep(v, u, p);
}
if(pd == 1) ton[col[u]] --;
}
ll now[maxn], cnt, w, val[maxn], bye[maxn];
inline void pa(ll u, ll fa)
{
now[++ cnt] = u;
ton[col[u]] ++;
if(ton[col[u]] == 1) w += si[u], val[col[u]] += si[u];
for(auto v : e[u])
{
if(vis[v] == 1 || v == fa) continue;
pa(v, u);
}
ton[col[u]] --;
}
inline void lolo(ll u, ll fa, ll p)
{
ll pd = 0;
if(ton[col[u]] == 0) p += val[col[u]], ton[col[u]] ++, pd = 1;
bye[u] = p;
for(auto v : e[u])
{
if(vis[v] == 1 || v == fa) continue;
lolo(v, u, p);
}
if(pd == 1) ton[col[u]] --;
}
inline void del(ll u, ll fa)
{
val[col[u]] = bye[u] = num[col[u]] = dep[u] = 0;
for(auto v : e[u])
{
if(vis[v] == 1 || v == fa) continue;
del(v, u);
}
}
inline void calc(ll u)
{
si[u] = 1;num[col[u]] = si[u];ton[col[u]] = 1;// 把根节点加进去
ans[u] += 1;//自己跟自己
for(auto v : e[u])
{
if(vis[v] == 1) continue;
p = 1;
dfs(v, u, u);
si[u] += si[v];
}
z += si[u];
for(auto v : e[u])
{
if(vis[v] == 1) continue;
deep(v, u, si[u]);
}
for(auto v : e[u])
{
if(vis[v] == 1) continue;
cnt = w = 0;
pa(v, u);
lolo(v, u, 0);
for(ll j = 1;j <= cnt;j ++) ans[now[j]] += ((z - w) - (dep[now[j]] - bye[now[j]])) + d[now[j]] * (si[u] - si[v] - 1);//全颜色贡献-全颜色子树贡献 - (根节点路径颜色贡献 - 根节点路径颜色子树贡献)
del(v, u);
}
z = 0;
ton[col[u]] = 0;num[col[u]] = 0;
}
void get_si(int u, int fa)
{
si[u] = 1;
for(auto v : e[u])
{
if(v == fa || vis[v] == 1) continue;
get_si(v, u);
si[u] += si[v];
}
}
inline void solve(ll u)
{
vis[u] = 1;
calc(u);
for(auto v : e[u])
{
if(vis[v]) continue;
get_si(v, u);
Tot = si[v];
rt = 0;get_rt(v, u);
solve(rt);
}
}
signed main()
{
freopen("text.in", "r", stdin);
// freopen("text.out", "w[maxn]", stdout);
ios::sync_with_stdio(false);
cin.tie(0);cout.tie(0);
mx[0] = 1919810;
cin >> n;
for(ll i = 1;i <= n;i ++) cin >> col[i];
ll u, v;
for(ll i = 1;i < n;i ++)
{
cin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
}
Tot = n;
get_rt(1, -1);
solve(rt);
for(ll i = 1;i <= n;i ++) cout << ans[i] << endl;
return 0;
}
如果你 TLE 了,请检查以下问题:
-
long long 开了没?答案可能很大。
-
重心找对了没?找错了点分治的复杂度就不存在了。