2020暑假牛客多校9 B - Groundhog and Apple Tree (树形dp)
2020暑假牛客多校9 B - Groundhog and Apple Tree (树形dp)
题目大意:
给一个树,走每条边会减hp, 走到点会加hp,原地等待也会加hp, 问最少原地等待时间使得能够遍历所有点。每条边最多走两次。
题解:
首先每条边最多走两次那也就dfs一遍树的过程,既然所有点都要走,那关键也就在与每次先走哪个子树,即去考虑遍历子树的顺序。
用time
表示子树需要最小的等待时间,hp
表示遍历子树能够得到的hp。那么对于一个子树有以下几种情况:
(1) hp > time : 这时候把子树遍历一遍hp会增加,那肯定先处理这一类子树
(2) hp < time: hp会减少 ,hp无法满足遍历所需等待时间
- 所以可以得到排序规则:
先处理hp > time的子树,再处理hp < time的子树,对于hp > time的子树再按time从小到大排序,因为大家都能增加hp那我为什么不把需要time大的子树往后放,等前面time小的子树处理完,hp多增加一些后再处理time大的子树。这样肯定更优。对于hp < time的子树, 那么我可以把这些子树按hp从大到小排序。
代码:
#include<bits/stdc++.h>
using namespace std;
#define rep(i, a, n) for(int i = a; i <= n; ++ i);
#define per(i, a, n) for(int i = n; i >= a; -- i);
typedef long long ll;
const int N = 2e6+ 5;
const int mod = 998244353;
const double Pi = acos(- 1.0);
const int INF = 0x3f3f3f3f;
const int G = 3, Gi = 332748118;
ll qpow(ll a, ll b) { ll res = 1; while(b){ if(b & 1) res = (res * a) % mod; a = (a * a) % mod; b >>= 1;} return res; }
ll gcd(ll a, ll b) { return b ? gcd(b, a % b) : a; }
ll lcm(ll a, ll b) { return a * b / gcd(a, b);}
bool cmp(int a, int b){ return a > b;}
//
int T, n;
ll val[N];
struct node1{
ll hp, time;
}dp[N];
int head[N], cnt = 0;
struct node{
int to, nxt;ll c;
}edge[N << 1];
void add(int u, int v, ll w){
edge[cnt].to = v, edge[cnt].c = w, edge[cnt].nxt = head[u], head[u] = cnt ++;
edge[cnt].to = u, edge[cnt].c = w, edge[cnt].nxt = head[v], head[v] = cnt ++;
}
bool cmp1(node1 a, node1 b){
if((a.hp > a.time) ^ (b.hp > b.time)) return (a.hp > a.time);
if(a.hp > a.time) return a.time < b.time;
return a.hp > b.hp;
// if(a.hp > a.time){
// if(b.hp < b.time) return true;
// else return a.time < b.time;
// }
// else{
// if(b.time > b.hp) return a.hp > b.hp;
// else return false;
// }
}
void dfs(int u, int pre){
vector<node1> sol;
for(int i = head[u]; i != -1; i = edge[i].nxt){
int v = edge[i].to; ll w = edge[i].c;
if(v == pre) continue;
dfs(v, u);
if(w >= dp[v].hp) dp[v].time += 2 * w - dp[v].hp, dp[v].hp = 0;
else dp[v].time += w, dp[v].hp -= w;
sol.push_back(dp[v]);
}
sort(sol.begin(), sol.end(), cmp1);
int tt = sol.size();
ll minn = val[u], thp = val[u];
for(int i = 0; i < tt; ++ i){
ll hp = sol[i].hp, time = sol[i].time;
minn = min(minn, thp - time);
thp += hp - time;
}
if(minn >= 0) dp[u].time = 0, dp[u].hp = thp;
else dp[u].time -= minn, dp[u].hp = thp - minn;
}
int main()
{
scanf("%d",&T);
while(T --){
scanf("%d",&n);
cnt = 0;
for(int i = 1; i <= n; ++ i) {
dp[i].time = dp[i].hp = 0;
head[i] = -1;
scanf("%lld",&val[i]);
}
for(int i = 1; i < n; ++ i){
int x, y; ll z; scanf("%d%d%lld",&x,&y,&z);
add(x, y, z);
}
dfs(1, 0);
printf("%lld\n",dp[1].time);
}
return 0;
}