hdu7024 Penguin Love Tour(2021杭电暑假多校5)树形dp
题意
给定一棵\(n\)个点的树,树的每个边有个权值\(w\),每个点有个权值\(p\)。每个点可以把相邻的某一条边边权减\(p\)。最小化直径。(\(1\le n,w\le{10}^5,0\le p\le{10}^5\))
思路
考虑二分答案,设为\(limit\)。那么\(check\)就是每棵子树最大的两条边之和不能超过\(limit\)。设\(dp[u][0]\)为节点\(u\)这棵子树没有使用\(u\)时,某个叶子到\(u\)的最长路径的最小值。\(dp[u][1]\)为已经使用了\(u\)的最小值。那么有:
$ dp[u][0]=max_{v}{min(dp[v][0]+max(0,w_{u,v}-p[v]),dp[v][1]+w_{u,v})} \tag{1}$
\(dp[u][1]=min_{v_0}\{max(min(dp[v_0][0]+max(0,w_{u,v_0}-p[v_0]-p[u]),dp[v_0][1]+max(0,w_{u,v_0}-p[u])),\\{max_{v\not=v_0}\{min(dp[v][0]+max(0,w_{u,v}-p[v]),dp[v][1]+w_{u,v})\})}\} \tag{2}\)
然后又因为对儿子用了\(p[u]\)后最长的儿子一定会在\((1)\)中最长的三个中取,那么求\(dp[u][1]\)只需要枚举\(v_0\)为\((1)\)中最大的三个即可。
代码
#include <bits/stdc++.h>
using namespace std;
using ll=long long;
using pii=pair<int,int>;
using pli=pair<ll,int>;
constexpr ll inf=1e18;
inline char gc() {
static constexpr int BufferSize = 1 << 22 | 5;
static char buf[BufferSize], *p, *q;
static std::streambuf *i = std::cin.rdbuf();
return p == q ? p = buf, q = p + i->sgetn(p, BufferSize), p == q ? EOF : *p++ : *p++;
}
struct Reader {
template <class T>
Reader &operator>>(T &w) {
char c, p = 0;
for (; !std::isdigit(c = gc());) if (c == '-') p = 1;
for (w = c & 15; std::isdigit(c = gc()); w = w * 10 + (c & 15)) ;
if (p) w = -w;
return *this;
}
} fin;
template<int N>
struct Max{
int n=0;
array<pli,N> a;
void insert(pli x) {
if(n!=0)
for(int i=0;i<n;i++) {
if(a[i]<x)
swap(a[i],x);
}
if(n<N) a[n++]=x;
}
void erase(int id) {
for(int i=0;i<n;i++) {
if(a[i].second==id) {
for(int j=i;j<n-1;j++) a[j]=a[j+1];
n--;
break;
}
}
}
ll sum(int cnt) {
cnt=min(cnt,N);
ll ans=0;
for(int i=0;i<cnt;i++) ans+=a[i].first;
return ans;
}
bool vis(int id) {
for(int i=0;i<n;i++)
if(a[i].second==id) return true;
return false;
}
};
void solve() {
int n;
ll L=0,R=0;
fin>>n;
vector<int> p(n+1);
vector<vector<pii>> g(n+1);
for(int i=1;i<=n;i++) fin>>p[i];
for(int i=1,u,v,w;i<=n-1;i++) {
fin>>u>>v>>w;
g[u].push_back({v,w});
g[v].push_back({u,w});
R+=w;
}
vector<ll>dp[2];
ll mid;
bool flag;
function<void(int,int)> dfs=[&](int u,int f) {
int son=0;
Max<3>s;
for(int i=0;i<g[u].size();i++) {
int v=g[u][i].first;
int w=g[u][i].second;
if(v==f) continue;
dfs(v,u);
if(!flag)return;
son++;
s.insert({min(dp[0][v]+max(w-p[v],0),dp[1][v]+w),i});
}
if(son==0) {
dp[0][u]=0;
return;
}
if(s.sum(2)<=mid)
dp[0][u]=s.sum(1);
dp[1][u]=inf;
for(int i=0;i<g[u].size();i++) {
int v=g[u][i].first;
int w=g[u][i].second;
if(v==f || !s.vis(i)) continue;
Max<3> s1=s;
s1.erase(i);
s1.insert({min(dp[0][v]+max(w-p[v]-p[u],0),dp[1][v]+max(w-p[u],0)),i});
if(s1.sum(2)<=mid)
dp[1][u]=min(dp[1][u],s1.sum(1));
}
if(dp[0][u]==inf && dp[1][u]==inf)
flag=false;
};
while(L<R) {
mid=(L+R)/2;
flag=true;
dp[0]=dp[1]=vector<ll>(n+1,inf);
dfs(1,0);
if(flag && (dp[0][1]<=mid || dp[1][1]<=mid)) R=mid;
else L=mid+1;
}
cout<<L<<'\n';
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int T;
fin>>T;
while(T--) solve();
return 0;
}