【题解】P3313 - 旅行
题目大意
给定一棵有 \(n\) 个结点的树,树上的每个结点有其对应的点权和“宗教”。现给出每个结点初始的点权和宗教,每次可以:
-
修改某个结点的宗教
-
修改某个结点的点权
-
查询从结点 \(s\) 到结点 \(t\) 的最短路径上,所有宗教与 \(s\) 和 \(t\) 相同的结点的点权总和,保证 \(s\) 和 \(t\) 宗教相同
-
查询从结点 \(s\) 到结点 \(t\) 的最短路径上,所有宗教与 \(s\) 和 \(t\) 相同的结点的点权最大值,保证 \(s\) 和 \(t\) 宗教相同
结点总数 \(n\) 和操作次数 \(m \leq 10^5\),所有结点的点权不超过 \(10^4\),所有宗教的值不超过 \(10^5\)。
核心思路
显然,从修改两个结点之间的路径可以看出,这道题需要使用 树链剖分 求解。但是,这道题在模板的基础上,加上了“宗教”的限制。我们需要想办法减去“宗教”的限制,才能使用树链剖分维护。
较为浅显的一个思路是,我们用一棵线段树,维护 \(dfs\) 序中一段区间内不同宗教的点权总和以及点权最大值。但是,这种方法的弊端是线段树的每个结点都要开一个大小为 \(10^5\) 的数组,显然会 \(MLE\) 。假如用离散化处理的话,宗教总数又不确定,也无法优化空间复杂度。
于是,我们可以想到一个优化的思路:给每一个宗教都开一棵线段树,宗教 \(x\) 的线段树内只有宗教为 \(x\) 的结点。显然,因为线段树内不会有宗教不统一的结点信息,所以我们可以直接用树链剖分来维护线段树的路径总和及最大值。但是,给每一个宗教都开一棵线段树,空间开销还是太大了。故而我们还需要使用 动态开点 的思想来优化空间。
动态开点的主要思想是,因为传统线段树有许多结点在维护时没有作用,所以我们可以只在线段树保留当前有用的结点,其他结点根据需要创建。这样,我们就可以用较小的空间开销来同时维护多棵线段树,也就解决了空间复杂度的问题。
因为我们需要用一个结构体数组来维护多棵线段树,所以我们还需要记录每棵宗教线段树的根结点下标,每个结构体元素需要维护左儿子和右儿子的下标。此处不多赘述动态开点,其他题解已经做了较为清晰的说明。
于是,这道题的主要思路就确定了:给每个宗教都开一棵线段树,宗教 \(x\) 的线段树只保存宗教为 \(x\) 的结点的信息。这样,就可以直接用树链剖分来维护操作。
操作详解
操作一
对于操作一,我们可以在原来宗教的线段树内删除修改的结点,再在新宗教的线段树插入修改的结点。此处删除直接清零信息即可。
操作二
对于操作二,我们先在线段树内删除修改的结点,再重新加入点权被更新过的结点。
操作三
对于操作三,直接用树链剖分维护即可。
操作四
同上。
参考代码
#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
const int maxn = 1e5 + 5;
const int maxm = 2e5 + 5;
struct Edge
{
int to, nxt;
} edge[maxm];
struct node
{
int lson, rson, sum, val;
} tree[24 * maxn];
int n, q, cnt, tot;
int head[maxn], w[maxn], c[maxn], root[maxn], son[maxn];
int id[maxn], dep[maxn], fa[maxn], size[maxn], top[maxn];
char opt[10];
void add_edge(int u, int v)
{
cnt++;
edge[cnt].to = v;
edge[cnt].nxt = head[u];
head[u] = cnt;
}
void dfs1(int u, int f)
{
fa[u] = f;
dep[u] = dep[f] + 1;
size[u] = 1;
for (int i = head[u]; i; i = edge[i].nxt)
{
if (edge[i].to != f)
{
dfs1(edge[i].to, u);
size[u] += size[edge[i].to];
if (size[edge[i].to] > size[son[u]])
son[u] = edge[i].to;
}
}
}
void dfs2(int u, int t)
{
id[u] = ++cnt;
top[u] = t;
if (son[u])
dfs2(son[u], t);
for (int i = head[u]; i; i = edge[i].nxt)
if (edge[i].to != fa[u] && edge[i].to != son[u])
dfs2(edge[i].to, edge[i].to);
}
void push_up(int k)
{
tree[k].sum = tree[tree[k].lson].sum + tree[tree[k].rson].sum;
tree[k].val = max(tree[tree[k].lson].val, tree[tree[k].rson].val);
}
void update(int &k, int x, int l, int r, int w)
{
if (!k)
k = ++tot;
if (l == r)
{
tree[k].sum = tree[k].val = w;
return;
}
int mid = (l + r) / 2;
if (x <= mid)
update(tree[k].lson, x, l, mid, w);
else
update(tree[k].rson, x, mid + 1, r, w);
push_up(k);
}
void del(int k, int x, int l, int r)
{
if (!k)
return;
if (l == r)
{
tree[k].sum = tree[k].val = 0;
return;
}
int mid = (l + r) / 2;
if (x <= mid)
del(tree[k].lson, x, l, mid);
else
del(tree[k].rson, x, mid + 1, r);
push_up(k);
}
int query_sum(int k, int l, int r, int x, int y)
{
if (!k)
return 0;
if (l >= x && r <= y)
return tree[k].sum;
int mid = (l + r) / 2, sum = 0;
if (x <= mid)
sum += query_sum(tree[k].lson, l, mid, x, y);
if (y > mid)
sum += query_sum(tree[k].rson, mid + 1, r, x, y);
return sum;
}
int query_max(int k, int l, int r, int x, int y)
{
if (!k)
return 0;
if (l >= x && r <= y)
return tree[k].val;
int mid = (l + r) / 2, val = 0;
if (x <= mid)
val = max(val, query_max(tree[k].lson, l, mid, x, y));
if (y > mid)
val = max(val, query_max(tree[k].rson, mid + 1, r, x, y));
return val;
}
int get_sum(int rt, int x, int y)
{
int sum = 0;
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]])
swap(x, y);
sum += query_sum(rt, 1, n, id[top[x]], id[x]);
x = fa[top[x]];
}
if (dep[x] > dep[y])
swap(x, y);
sum += query_sum(rt, 1, n, id[x], id[y]);
return sum;
}
int get_max(int rt, int x, int y)
{
int val = 0;
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]])
swap(x, y);
val = max(val, query_max(rt, 1, n, id[top[x]], id[x]));
x = fa[top[x]];
}
if (dep[x] > dep[y])
swap(x, y);
val = max(val, query_max(rt, 1, n, id[x], id[y]));
return val;
}
int main()
{
int x, y;
scanf("%d%d", &n, &q);
for (int i = 1; i <= n; i++)
scanf("%d%d", &w[i], &c[i]);
for (int i = 1; i <= n - 1; i++)
{
scanf("%d%d", &x, &y);
add_edge(x, y);
add_edge(y, x);
}
cnt = 0;
dfs1(1, 0);
dfs2(1, 0);
for (int i = 1; i <= n; i++)
update(root[c[i]], id[i], 1, n, w[i]);
for (int i = 1; i <= q; i++)
{
scanf("%s", opt);
scanf("%d%d", &x, &y);
if (opt[1] == 'C')
{
del(root[c[x]], id[x], 1, n);
c[x] = y;
update(root[c[x]], id[x], 1, n, w[x]);
}
else if (opt[1] == 'W')
{
del(root[c[x]], id[x], 1, n);
w[x] = y;
update(root[c[x]], id[x], 1, n, w[x]);
}
else if (opt[1] == 'S')
printf("%d\n", get_sum(root[c[x]], x, y));
else
printf("%d\n", get_max(root[c[x]], x, y));
}
return 0;
}