2021ICPC南京站H. Crystalfly

题目大意

\(n(1\leq n\leq 10^5)\)个节点的树,每个节点 \(i\) 上有 \(a_{i}(1\leq a_{i}\leq 10^9)\) 只蝴蝶和一个时间 \(t_{i}(1\leq t_{i}\leq 3)\) ,在到达一个节点后,可以立即取走该节点上的所有蝴蝶,但每到达一个节点,距离该节点为 \(1\) 的节点上的蝴蝶都会收到惊扰,对于每个收到惊扰的节点 \(j\) ,其上面的蝴蝶数量会在接下来 \(t_{j}\) 时间结束后变为 \(0\) 。现在从节点 \(1\) 开始,每次可以在下一时间开始时走到一个相邻的节点或者保持不动,求可以取走的最多蝴蝶数量。

思路

考虑树形 \(dp\) ,显然可以注意到走要比不动好,之后考虑我们在到达一个顶点后被惊扰的节点当中,我们选择一个节点取走其蝴蝶后,来不及取走第二个 \(t=1\) 或者 \(t=2\) 的节点上的蝴蝶,因为这至少需要走 \(3\) 步,因此我们可以再通过原来的节点返回后取走一个 \(t=3\) 的节点,但代价时第一个到达的被惊扰的节点的所有儿子上的蝴蝶都会变为 \(0\)
根据这些分析,我们可以设 \(f[v]\) 为在以 \(v\) 为根的子树中,取走 \(v\) 的情况下,所能够取得的最大蝴蝶数, \(g[v]\) 为在以 \(v\) 为根的子树中,不取走 \(v\) 的情况下,所能够取得的最大蝴蝶数。 \(h[v]\) 为在以 \(v\) 为根的子树中,取走 \(h[v]\) 并且立即返回其父节点的情况下,所能够取得的最大蝴蝶数。
于是显然可以发现对于 \(v\) 的所有儿子 \(to\)

\[h[v]=a_{v}+\sum g[to] \]

\[f[v]=a_{v}+g[v] \]

于是接下来的问题就变为了如何计算 \(g[v]\) ,首先我们只考虑取走一个在到达 \(v\) 后被惊扰的节点的情况,此时会得到:

\[g[v]=g[to]+max\{a_{to}\} \]

之后我们再考虑取走一个在到达 \(v\) 后被惊扰的节点 \(i\) 后,再取走一个在到达 \(v\) 后被惊扰的 \(t=3\) 的节点 \(j\) 的情况,我们维护所有 \(to\)\(t=3\) 的节点的 \(a_{to}\) 的最大值 \(max1\) 以及次大值 \(max2\) 都初始为 \(-1\) ,于是有:

\[g[v]=\left\{ \begin{array}{**lr**} max(g[v],h[i]+\sum g[to]-g[i]+max2), & t_{i}=3,a_{i}=max1,max2\neq -1 \\ max(g[v],h[i]+\sum g[to]-g[i]+max1), & else\\ \end{array} \right. \]

最后 \(f[1]\) 即为答案,复杂度 \(O(n)\)

代码

#include<bits/stdc++.h>
#include<unordered_map>
#include<unordered_set>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> PII;
#define all(x) x.begin(),x.end()
//#define int LL
//#define lc p*2+1
//#define rc p*2+2
#define endl '\n'
#define inf 0x3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
#pragma warning(disable : 4996)
#define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0)
const double eps = 1e-8;
const LL mod = 1000000007;
const LL MOD = 998244353;
const int maxn = 1000010;

int T, N;
LL a[maxn], t[maxn], f[maxn], g[maxn], h[maxn];
vector<int>G[maxn];

void add_edge(int from, int to) 
{
    G[from].push_back(to);
    G[to].push_back(from);
}

void dfs(int v, int p)
{
	LL fi = -1, se = -1, sumg = 0, mx = 0;
	h[v] = a[v];
	for (int i = 0; i < G[v].size(); i++)
	{
		int to = G[v][i];
		if (to == p)
			continue;
		dfs(to, v);
		if (t[to] == 3)
		{
			if (a[to] >= fi)
				se = fi, fi = a[to];
			else
				se = max(se, a[to]);
		}
		sumg += g[to];
		h[v] += g[to];
		mx = max(mx, a[to]);
	}
	g[v] = mx + sumg;
	if (fi != -1)
	{
		for (int i = 0; i < G[v].size(); i++)
		{
			int to = G[v][i];
			if (to == p)
				continue;
			if (fi == a[to] && t[to] == 3)
			{
				if (se == -1)
					continue;
				g[v] = max(g[v], h[to] + sumg - g[to] + se);
			}
			else
				g[v] = max(g[v], h[to] + sumg - g[to] + fi);
		}
	}
	f[v] = g[v] + a[v];
}

void solve()
{
	dfs(1, 0);
	cout << f[1] << endl;
}

int main()
{
	IOS;
	cin >> T;
	while (T--)
	{
		cin >> N;
		for (int i = 1; i <= N; i++)
		{
			f[i] = g[i] = h[i] = 0;
			G[i].clear();
		}
		for (int i = 1; i <= N; i++)
			cin >> a[i];
		for (int i = 1; i <= N; i++)
			cin >> t[i];
		int u, v;
		for (int i = 1; i < N; i++)
		{
			cin >> u >> v;
			add_edge(u, v);
		}
		solve();
	}

	return 0;
}
posted @ 2022-03-17 00:30  Prgl  阅读(179)  评论(0编辑  收藏  举报