【题解】P4074 - [WC2013] 糖果公园
题目大意
给出一棵包含 \(n\) 个结点的无根树。已知这棵树上的每个结点 \(i\) 均有其唯一的“糖果类型” \(c_i\)。定义 \(v_i\) 表示第 \(i\) 个糖果类型的“美味程度”,\(w_i\) 表示第 \(i\) 次品尝某种糖果的“新奇指数”。已知第 \(i\) 次品尝糖果 \(j\) 会增加 \(w_i \times v_j\) 的“愉悦指数”。现在给出 \(q\) 个操作,每次操作可以:
-
将结点 \(x\) 得到糖果类型改为 \(y\)
-
询问结点 \(x, y\) 间路径的愉悦指数之和
\(n, m, q \leq 100000\)
\(1 \leq v_i, w_i \leq 10^6\)
\(1 \leq c_i \leq m\)
对于任意 \(1 < i \leq n\),满足 \(w_i \leq w_{i - 1}\)
解题思路
这道题是树上的带修问题,优先考虑用数据结构维护。因为题目查询树上结点路径间的愉悦指数,所以考虑用树剖或者树上莫队维护。显然树剖做法比树上莫队更加复杂,所以可以先确定这道题使用的算法是 树上带修莫队。
我们先来理解题目给出的条件。第 \(i\) 种糖果的美味指数为 \(v_i\),结点 \(u\) 的糖果类型为 \(c_u\),记为 \(k\),实际上相当于结点 \(u\) 拥有 \(v_{k}\) 的点权,这个点权需要配合新奇指数维护。想要维护愉悦指数,我们还需要维护新奇指数,而维护新奇指数就必须要维护 每种糖果类型出现的次数,这正是树上莫队的典型应用。
我们用 欧拉序 来将树上路径转化成区间。欧拉序的构造方法是深度优先遍历整棵树,第一次和最后一次访问结点 \(u\) 时都把 \(u\) 加入当前欧拉序的末尾。对于树上结点 \(u, v\) 之间的路径,记 \(st_i\) 表示结点 \(i\) 在欧拉序中第一次出现的位置,\(ed_i\) 表示结点 \(i\) 在欧拉序中最后一次出现的位置。假设 \(st_u < st_v\),如果 \(lca(u, v) = u\),说明 \(u, v\) 是祖孙关系,它们对应欧拉序中的区间 \([st_u, st_v]\)。反之,它们对应欧拉序中的区间 \([ed_u, ed_v]\)。注意,区间中出现且仅出现一次的结点才在 \(u, v\) 间的路径上。对于 \(lca(u, v) \neq u\) 的情况,我们还需要把 \(lca(u, v)\) 也一起维护。
接着考虑用树上莫队维护每种糖果类型在路径上出现的次数。假设此时在欧拉序中有一结点 \(u\),设糖果类型 \(c_u\) 的出现次数为 \(t\),则 \(c_u\) 当前的贡献为 \(\sum\limits_{i = 1}^t v_{c_u} \times w_i\),相等于 \(v_{c_u} \times \sum\limits_{i = 1}^t w_i\)。观察第二个式子,我们发现第二项其实可以用一个前缀和来维护。我们对于 \(w_i\) 作前缀和,记 \(s_n = \sum\limits_{i = 1}^n w_i\)。设 \(t_i\) 为糖果类型 \(i\) 的出现次数,那么糖果类型 \(i\) 的贡献为 \(v_i \times s_{t_i}\)。
有了上面的推导,我们就可以方便地用树上带修莫队来维护愉悦指数了。首先预处理出询问区间在欧拉序上对应的区间和 \(w\) 数组的前缀和。设答案为 \(a\),接着对于每次询问,如果左端点移动,那么它影响答案后答案更新为 原本的答案 \(-\) 原本的出现次数 \(\times\) 前缀和 \(+\) 新的出现次数 \(\times\) 前缀和。右端点移动同理。这里代码实现可以使用一个小技巧,我们用 vis[i]
表示结点 \(i\) 对于当前的路径是否存在影响。每次更新答案时若 vis[i] = false
,将当前结点的糖果类型出现次数 \(+ 1\),否则 \(- 1\),相应地按照上式统计它们的影响,最后直接将 vis[i]
取反。
每次维护更改操作的指针时,如果被更改的结点 \(i\) 对路径有影响,也就是 vis[i] = true
,直接令原本糖果类型的出现次数 \(- 1\),新糖果类型的出现次数 \(+ 1\) 并统计相应的影响。反之说明它对当前路径无影响,也就是对答案没有影响,所以不更新答案。最后无论是否影响答案都要修改糖果类型,交换操作给出的糖果类型和结点 \(i\) 的糖果类型,这样下次回溯指针时再遇到这个操作时就可以实现删除操作的效果了。详见代码。
总时间复杂度 \(O(n\sqrt{n})\)。
参考代码
#include <cstdio>
#include <cmath>
#include <iostream>
#include <algorithm>
using namespace std;
const int maxn = 1e5 + 5;
const int maxm = 1e5 + 5;
const int maxv = 1e5 + 5;
int n, m, q;
int atot, btot;
int l, r, k, tot;
int bel[2 * maxn], cnt[maxv];
int f[maxn][20], v[maxn], c[maxn], dep[maxn];
int head[maxn], st[maxn], ed[maxn], p[2 * maxn];
long long cur;
long long w[maxn], ans[maxm];
bool vis[maxn];
// bel[i] -> 位置 i 属于莫队的第 bel[i] 个块
// cnt[i] -> 糖果类型 i 的出现次数
// f[i][j] -> 结点 i 的第 2^j 辈祖先
// dep[i] -> 结点 i 的深度
// st[i], ed[i] -> 结点 i 在欧拉序中的开头和结尾
// cur -> 当前的答案,ans[i] -> 第 i 个询问的答案
// vis[i] -> 结点 i 是否对当前路径存在影响
// 链式前向星
struct node
{
int to, nxt;
} edge[maxn * 2];
// 询问结构体
struct ques
{
int l, r, k;
int lca, id;
bool operator < (const ques& rhs) const
{
if (bel[l] ^ bel[rhs.l])
return bel[l] < bel[rhs.l];
if (bel[r] ^ bel[rhs.r])
return (bel[l] & 1 ? r < rhs.r : r > rhs.r);
return k < rhs.k;
}
} a[maxm];
// 操作结构体
struct option
{
int pos, val;
} b[maxm];
// 链式前向星 -> 加边
void add_edge(int u, int v)
{
tot++;
edge[tot].to = v;
edge[tot].nxt = head[u];
head[u] = tot;
}
// 预处理欧拉序和深度
void dfs(int u, int fa)
{
dep[u] = dep[fa] + 1;
p[++tot] = u;
st[u] = tot;
for (int i = head[u]; i; i = edge[i].nxt)
{
int v = edge[i].to;
if (v != fa)
{
f[v][0] = u;
dfs(v, u);
}
}
p[++tot] = u;
ed[u] = tot;
}
// 求结点 u, v 的 lca
int lca(int u, int v)
{
if (dep[u] < dep[v])
swap(u, v);
int k = 0;
while ((1 << (k + 1)) <= dep[u])
k++;
for (int i = k; i >= 0; i--)
if (dep[f[u][i]] >= dep[v])
u = f[u][i];
if (u == v)
return u;
for (int i = k; i >= 0; i--)
{
if (f[u][i] != f[v][i])
{
u = f[u][i];
v = f[v][i];
}
}
return f[u][0];
}
// 结点 x 对应的糖果类型 c[x] 出现次数 + 1
void add(int x)
{
// v[c[x]] -> 当前糖果类型的美味程度
// w[cnt[c[x]]] -> 当前糖果类型出现次数对应的前缀和
// 两者相乘即为糖果类型 c[x] 对愉悦指数的贡献
cur -= (long long)v[c[x]] * w[cnt[c[x]]];
cnt[c[x]]++;
cur += (long long)v[c[x]] * w[cnt[c[x]]];
}
// 结点 x 对应的糖果类型 c[x] 出现次数 - 1
void del(int x)
{
cur -= (long long)v[c[x]] * w[cnt[c[x]]];
cnt[c[x]]--;
cur += (long long)v[c[x]] * w[cnt[c[x]]];
}
// 分类讨论,更新答案
void work(int x)
{
if (!vis[x])
add(x);
else
del(x);
vis[x] ^= 1;
}
// 维护操作
void update(int x)
{
bool flag = vis[b[x].pos];
if (flag)
del(b[x].pos);
swap(c[b[x].pos], b[x].val);
if (flag)
add(b[x].pos);
}
int main()
{
int opt, x, y;
l = 1, r = 0, k = 0;
scanf("%d%d%d", &n, &m, &q);
for (int i = 1; i <= m; i++)
scanf("%d", &v[i]);
for (int i = 1; i <= n; i++)
{
scanf("%lld", &w[i]);
w[i] += w[i - 1];
}
for (int i = 1; i <= n - 1; i++)
{
scanf("%d%d", &x, &y);
add_edge(x, y);
add_edge(y, x);
}
for (int i = 1; i <= n; i++)
scanf("%d", &c[i]);
int block = pow(n, 0.6667); // 块长取 n^(2/3)
for (int i = 1; i <= 2 * n; i++)
bel[i] = (i - 1) / block + 1;
tot = 0;
dfs(1, 0);
for (int j = 1; (1 << j) <= n; j++)
for (int i = 1; i <= n; i++)
f[i][j] = f[f[i][j - 1]][j - 1];
for (int i = 1; i <= q; i++)
{
scanf("%d", &opt);
if (!opt)
{
btot++;
scanf("%d%d", &b[btot].pos, &b[btot].val);
}
else
{
atot++;
a[atot].id = atot;
a[atot].k = btot;
scanf("%d%d", &a[atot].l, &a[atot].r);
if (st[a[atot].l] > st[a[atot].r])
swap(a[atot].l, a[atot].r);
x = lca(a[atot].l, a[atot].r);
if (x == a[atot].l)
{
a[atot].l = st[a[atot].l];
a[atot].r = st[a[atot].r];
}
else
{
a[atot].l = ed[a[atot].l];
a[atot].r = st[a[atot].r];
a[atot].lca = x;
}
}
}
sort(a + 1, a + atot + 1);
for (int i = 1; i <= atot; i++)
{
while (l < a[i].l)
work(p[l++]);
while (l > a[i].l)
work(p[--l]);
while (r < a[i].r)
work(p[++r]);
while (r > a[i].r)
work(p[r--]);
while (k < a[i].k)
update(++k);
while (k > a[i].k)
update(k--);
if (a[i].lca)
work(a[i].lca);
ans[a[i].id] = cur;
if (a[i].lca)
work(a[i].lca);
}
for (int i = 1; i <= atot; i++)
printf("%lld\n", ans[i]);
return 0;
}