Loading

【题解】P3313 - 旅行

题目大意

题目链接

给定一棵有 \(n\) 个结点的树,树上的每个结点有其对应的点权和“宗教”。现给出每个结点初始的点权和宗教,每次可以:

  • 修改某个结点的宗教

  • 修改某个结点的点权

  • 查询从结点 \(s\) 到结点 \(t\) 的最短路径上,所有宗教与 \(s\)\(t\) 相同的结点的点权总和,保证 \(s\)\(t\) 宗教相同

  • 查询从结点 \(s\) 到结点 \(t\) 的最短路径上,所有宗教与 \(s\)\(t\) 相同的结点的点权最大值,保证 \(s\)\(t\) 宗教相同

结点总数 \(n\) 和操作次数 \(m \leq 10^5\),所有结点的点权不超过 \(10^4\),所有宗教的值不超过 \(10^5\)

核心思路

显然,从修改两个结点之间的路径可以看出,这道题需要使用 树链剖分 求解。但是,这道题在模板的基础上,加上了“宗教”的限制。我们需要想办法减去“宗教”的限制,才能使用树链剖分维护。

较为浅显的一个思路是,我们用一棵线段树,维护 \(dfs\) 序中一段区间内不同宗教的点权总和以及点权最大值。但是,这种方法的弊端是线段树的每个结点都要开一个大小为 \(10^5\) 的数组,显然会 \(MLE\) 。假如用离散化处理的话,宗教总数又不确定,也无法优化空间复杂度。

于是,我们可以想到一个优化的思路:给每一个宗教都开一棵线段树,宗教 \(x\) 的线段树内只有宗教为 \(x\) 的结点。显然,因为线段树内不会有宗教不统一的结点信息,所以我们可以直接用树链剖分来维护线段树的路径总和及最大值。但是,给每一个宗教都开一棵线段树,空间开销还是太大了。故而我们还需要使用 动态开点 的思想来优化空间。

动态开点的主要思想是,因为传统线段树有许多结点在维护时没有作用,所以我们可以只在线段树保留当前有用的结点,其他结点根据需要创建。这样,我们就可以用较小的空间开销来同时维护多棵线段树,也就解决了空间复杂度的问题。

因为我们需要用一个结构体数组来维护多棵线段树,所以我们还需要记录每棵宗教线段树的根结点下标,每个结构体元素需要维护左儿子和右儿子的下标。此处不多赘述动态开点,其他题解已经做了较为清晰的说明。

于是,这道题的主要思路就确定了:给每个宗教都开一棵线段树,宗教 \(x\) 的线段树只保存宗教为 \(x\) 的结点的信息。这样,就可以直接用树链剖分来维护操作。

操作详解

操作一

对于操作一,我们可以在原来宗教的线段树内删除修改的结点,再在新宗教的线段树插入修改的结点。此处删除直接清零信息即可。

操作二

对于操作二,我们先在线段树内删除修改的结点,再重新加入点权被更新过的结点。

操作三

对于操作三,直接用树链剖分维护即可。

操作四

同上。

参考代码

#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;

const int maxn = 1e5 + 5;
const int maxm = 2e5 + 5;

struct Edge
{
	int to, nxt;
} edge[maxm];

struct node
{
	int lson, rson, sum, val;
} tree[24 * maxn];

int n, q, cnt, tot;
int head[maxn], w[maxn], c[maxn], root[maxn], son[maxn];
int id[maxn], dep[maxn], fa[maxn], size[maxn], top[maxn];
char opt[10];

void add_edge(int u, int v)
{
	cnt++;
	edge[cnt].to = v;
	edge[cnt].nxt = head[u];
	head[u] = cnt;
}

void dfs1(int u, int f)
{
	fa[u] = f;
	dep[u] = dep[f] + 1;
	size[u] = 1;
	for (int i = head[u]; i; i = edge[i].nxt)
	{
		if (edge[i].to != f)
		{
			dfs1(edge[i].to, u);
			size[u] += size[edge[i].to];
			if (size[edge[i].to] > size[son[u]])
				son[u] = edge[i].to;
		}
	}
}

void dfs2(int u, int t)
{
	id[u] = ++cnt;
	top[u] = t;
	if (son[u])
		dfs2(son[u], t);
	for (int i = head[u]; i; i = edge[i].nxt)
		if (edge[i].to != fa[u] && edge[i].to != son[u])
			dfs2(edge[i].to, edge[i].to);
}

void push_up(int k)
{
	tree[k].sum = tree[tree[k].lson].sum + tree[tree[k].rson].sum;
	tree[k].val = max(tree[tree[k].lson].val, tree[tree[k].rson].val);
}

void update(int &k, int x, int l, int r, int w)
{
	if (!k)
		k = ++tot;
	if (l == r)
	{
		tree[k].sum = tree[k].val = w;
		return;
	}
	int mid = (l + r) / 2;
	if (x <= mid)
		update(tree[k].lson, x, l, mid, w);
	else
		update(tree[k].rson, x, mid + 1, r, w);
	push_up(k);
}

void del(int k, int x, int l, int r)
{
	if (!k)
		return;
	if (l == r)
	{
		tree[k].sum = tree[k].val = 0;
		return;
	}
	int mid = (l + r) / 2;
	if (x <= mid)
		del(tree[k].lson, x, l, mid);
	else
		del(tree[k].rson, x, mid + 1, r);
	push_up(k);
}

int query_sum(int k, int l, int r, int x, int y)
{
	if (!k)
		return 0;
	if (l >= x && r <= y)
		return tree[k].sum;
	int mid = (l + r) / 2, sum = 0;
	if (x <= mid)
		sum += query_sum(tree[k].lson, l, mid, x, y);
	if (y > mid)
		sum += query_sum(tree[k].rson, mid + 1, r, x, y);
	return sum;
}

int query_max(int k, int l, int r, int x, int y)
{
	if (!k)
		return 0;
	if (l >= x && r <= y)
		return tree[k].val;
	int mid = (l + r) / 2, val = 0;
	if (x <= mid)
		val = max(val, query_max(tree[k].lson, l, mid, x, y));
	if (y > mid)
		val = max(val, query_max(tree[k].rson, mid + 1, r, x, y));
	return val;
}

int get_sum(int rt, int x, int y)
{
	int sum = 0;
	while (top[x] != top[y])
	{
		if (dep[top[x]] < dep[top[y]])
			swap(x, y);
		sum += query_sum(rt, 1, n, id[top[x]], id[x]);
		x = fa[top[x]];
	}
	if (dep[x] > dep[y])
		swap(x, y);
	sum += query_sum(rt, 1, n, id[x], id[y]);
	return sum;
}

int get_max(int rt, int x, int y)
{
	int val = 0;
	while (top[x] != top[y])
	{
		if (dep[top[x]] < dep[top[y]])
			swap(x, y);
		val = max(val, query_max(rt, 1, n, id[top[x]], id[x]));
		x = fa[top[x]];
	}
	if (dep[x] > dep[y])
		swap(x, y);
	val = max(val, query_max(rt, 1, n, id[x], id[y]));
	return val; 
}

int main()
{
	int x, y;
	scanf("%d%d", &n, &q);
	for (int i = 1; i <= n; i++)
		scanf("%d%d", &w[i], &c[i]);
	for (int i = 1; i <= n - 1; i++)
	{
		scanf("%d%d", &x, &y);
		add_edge(x, y);
		add_edge(y, x);
	}
	cnt = 0;
	dfs1(1, 0);
	dfs2(1, 0);
	for (int i = 1; i <= n; i++)
		update(root[c[i]], id[i], 1, n, w[i]);
	for (int i = 1; i <= q; i++)
	{
		scanf("%s", opt);
		scanf("%d%d", &x, &y);
		if (opt[1] == 'C')
		{
			del(root[c[x]], id[x], 1, n);
			c[x] = y;
			update(root[c[x]], id[x], 1, n, w[x]);
		}
		else if (opt[1] == 'W')
		{
			del(root[c[x]], id[x], 1, n);
			w[x] = y;
			update(root[c[x]], id[x], 1, n, w[x]);
		}
		else if (opt[1] == 'S')
			printf("%d\n", get_sum(root[c[x]], x, y));
		else
			printf("%d\n", get_max(root[c[x]], x, y));
	}
	return 0;	
} 
posted @ 2021-07-24 23:25  kymru  阅读(35)  评论(0编辑  收藏  举报