2021ICPC南京站H. Crystalfly
题目大意
\(n(1\leq n\leq 10^5)\)个节点的树,每个节点 \(i\) 上有 \(a_{i}(1\leq a_{i}\leq 10^9)\) 只蝴蝶和一个时间 \(t_{i}(1\leq t_{i}\leq 3)\) ,在到达一个节点后,可以立即取走该节点上的所有蝴蝶,但每到达一个节点,距离该节点为 \(1\) 的节点上的蝴蝶都会收到惊扰,对于每个收到惊扰的节点 \(j\) ,其上面的蝴蝶数量会在接下来 \(t_{j}\) 时间结束后变为 \(0\) 。现在从节点 \(1\) 开始,每次可以在下一时间开始时走到一个相邻的节点或者保持不动,求可以取走的最多蝴蝶数量。
思路
考虑树形 \(dp\) ,显然可以注意到走要比不动好,之后考虑我们在到达一个顶点后被惊扰的节点当中,我们选择一个节点取走其蝴蝶后,来不及取走第二个 \(t=1\) 或者 \(t=2\) 的节点上的蝴蝶,因为这至少需要走 \(3\) 步,因此我们可以再通过原来的节点返回后取走一个 \(t=3\) 的节点,但代价时第一个到达的被惊扰的节点的所有儿子上的蝴蝶都会变为 \(0\) 。
根据这些分析,我们可以设 \(f[v]\) 为在以 \(v\) 为根的子树中,取走 \(v\) 的情况下,所能够取得的最大蝴蝶数, \(g[v]\) 为在以 \(v\) 为根的子树中,不取走 \(v\) 的情况下,所能够取得的最大蝴蝶数。 \(h[v]\) 为在以 \(v\) 为根的子树中,取走 \(h[v]\) 并且立即返回其父节点的情况下,所能够取得的最大蝴蝶数。
于是显然可以发现对于 \(v\) 的所有儿子 \(to\) :
于是接下来的问题就变为了如何计算 \(g[v]\) ,首先我们只考虑取走一个在到达 \(v\) 后被惊扰的节点的情况,此时会得到:
之后我们再考虑取走一个在到达 \(v\) 后被惊扰的节点 \(i\) 后,再取走一个在到达 \(v\) 后被惊扰的 \(t=3\) 的节点 \(j\) 的情况,我们维护所有 \(to\) 中 \(t=3\) 的节点的 \(a_{to}\) 的最大值 \(max1\) 以及次大值 \(max2\) 都初始为 \(-1\) ,于是有:
最后 \(f[1]\) 即为答案,复杂度 \(O(n)\) 。
代码
#include<bits/stdc++.h>
#include<unordered_map>
#include<unordered_set>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> PII;
#define all(x) x.begin(),x.end()
//#define int LL
//#define lc p*2+1
//#define rc p*2+2
#define endl '\n'
#define inf 0x3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
#pragma warning(disable : 4996)
#define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0)
const double eps = 1e-8;
const LL mod = 1000000007;
const LL MOD = 998244353;
const int maxn = 1000010;
int T, N;
LL a[maxn], t[maxn], f[maxn], g[maxn], h[maxn];
vector<int>G[maxn];
void add_edge(int from, int to)
{
G[from].push_back(to);
G[to].push_back(from);
}
void dfs(int v, int p)
{
LL fi = -1, se = -1, sumg = 0, mx = 0;
h[v] = a[v];
for (int i = 0; i < G[v].size(); i++)
{
int to = G[v][i];
if (to == p)
continue;
dfs(to, v);
if (t[to] == 3)
{
if (a[to] >= fi)
se = fi, fi = a[to];
else
se = max(se, a[to]);
}
sumg += g[to];
h[v] += g[to];
mx = max(mx, a[to]);
}
g[v] = mx + sumg;
if (fi != -1)
{
for (int i = 0; i < G[v].size(); i++)
{
int to = G[v][i];
if (to == p)
continue;
if (fi == a[to] && t[to] == 3)
{
if (se == -1)
continue;
g[v] = max(g[v], h[to] + sumg - g[to] + se);
}
else
g[v] = max(g[v], h[to] + sumg - g[to] + fi);
}
}
f[v] = g[v] + a[v];
}
void solve()
{
dfs(1, 0);
cout << f[1] << endl;
}
int main()
{
IOS;
cin >> T;
while (T--)
{
cin >> N;
for (int i = 1; i <= N; i++)
{
f[i] = g[i] = h[i] = 0;
G[i].clear();
}
for (int i = 1; i <= N; i++)
cin >> a[i];
for (int i = 1; i <= N; i++)
cin >> t[i];
int u, v;
for (int i = 1; i < N; i++)
{
cin >> u >> v;
add_edge(u, v);
}
solve();
}
return 0;
}