「SWTR-4」Collecting Coins 题解
挺明显的换根 DP。。
0x01
先考虑一下起点为 \(d\) 的答案该怎么算。
我们发现可以钦定 \(d\) 为树根,设 \(dp_i\) 表示以 \(i\) 为根的子树中,以 \(i\) 为起点最多可以获得的代价。由于有了进入节点的次数限制,我们肯定不能直接加和。不难发现这个限制其实就和点的度数有关,于是我们直接选择 \(i\) 的儿子中,\(dp\) 值最大的 \(k_i-1\) 个就行了,排序即可。
0x02
我们尝试根据起点为 \(d\) 时的求法拓展到其余节点。
这里就是换根 DP 了。我们记 \(f_i\) 表示以 \(i\) 为起点,且 \(i\) 只会往父亲节点移动的最大代价。考虑用 \(f_{fa}\) 去更新 \(f_i\)。如果我们只能向 \(fa\) 移动,那么 \(fa\) 肯定已经走过了一次,又因为我们的 \(fa\) 也在往上走,所以初始时 \(fa\) 走了两次。也就是我们选择 \(i\) 的兄弟节点中 \(dp\) 值最大的 \(k_{fa}-2\) 个节点,将它们的 \(dp\) 值加起来,最后再加上 \(f_{fa}\),就可以得到 \(dp_i\)。
可能有人要问了——为什么 \(fa\) 非要往上走呢?为什么不可以用 \(i\) 的兄弟节点中 \(dp\) 值最大的 \(k_{fa}-1\) 个呢??
因为我们必须要走到 \(d\) 节点至少一次,并且我们是以 \(d\) 为根进行 DP 的,那么 \(fa\) 肯定就要往上走一次。也正是如此,我们在统计 \(d\) 的儿子的时候,才应该像上面这样转移,因为 \(fa\) 就是 \(d\),没有必要往其它地方走了。
0x03
求出了 \(f_i\) 之后,我们就可以统计答案了。
我们枚举起点 \(s\)。首先要到达 \(d\) 至少一次,我们就加上一个 \(f_s\),然后再考虑将 \(k_s\) 给跑满,所以我们还要加上 \(s\) 的儿子中 \(dp\) 值最大的 \(k_s-2\) 个的和。
比较最大值即可。
#include<bits/stdc++.h>
using namespace std;
const int MAXN=1e5+5;
int head[MAXN],nxt[MAXN<<1],to[MAXN<<1],val[MAXN<<1],tot;
void add(int x,int y,int z)
{
to[++tot]=y;
val[tot]=z;
nxt[tot]=head[x];
head[x]=tot;
}
int n,d;
int a[MAXN],in[MAXN];
int dp[MAXN],f[MAXN];
int stk[MAXN],cnt;
bool cmp(int x,int y){ return dp[x]<dp[y]; }
void dfs_first(int x,int fa)
{
for(int i=head[x];i;i=nxt[i])
{
if(to[i]==fa) continue;
dfs_first(to[i],x);
}
int tot=0;
for(int i=head[x];i;i=nxt[i])
{
if(to[i]==fa) continue;
dp[to[i]]+=val[i],stk[++tot]=to[i];
}
sort(stk+1,stk+tot+1,cmp);
for(int i=max(1,tot-a[x]+2);i<=tot;i++) dp[x]+=dp[stk[i]];
}
int maxx;
void dfs(int x,int fa)
{
if(a[x]==1) return;
cnt=0;
for(int i=head[x];i;i=nxt[i])
{
if(to[i]==fa) continue;
stk[++cnt]=to[i];
}
sort(stk+1,stk+cnt+1,cmp);
int now=0;
for(int i=cnt-a[x]+3;i<=cnt;i++) now+=dp[stk[i]];
maxx=max(maxx,f[x]+now);
for(int i=cnt-a[x]+3;i<=cnt;i++) f[stk[i]]=f[x]+now-dp[stk[i]]+dp[stk[max(1,cnt-a[x]+3)-1]];
for(int i=1;i<=cnt-a[x]+2;i++) f[stk[i]]=f[x]+now;
for(int i=head[x];i;i=nxt[i])
{
if(to[i]==fa) continue;
f[to[i]]+=val[i];
dfs(to[i],x);
}
}
int main()
{
// freopen("data.in","r",stdin);
// freopen("data.out","w",stdout);
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
cin>>n>>d;
for(int i=1;i<n;i++)
{
int x,y,z;
cin>>x>>y>>z;
add(x,y,z),add(y,x,z);
in[x]++,in[y]++;
}
for(int i=1;i<=n;i++) cin>>a[i],a[i]=min(a[i],in[i]+1);
for(int i=head[d];i;i=nxt[i]) dfs_first(to[i],d);
cnt=0;
for(int i=head[d];i;i=nxt[i]) stk[++cnt]=to[i],dp[to[i]]+=val[i];
sort(stk+1,stk+cnt+1,cmp);
for(int i=max(1,cnt-a[d]+2);i<=cnt;i++) maxx+=dp[stk[i]];
for(int i=max(1,cnt-a[d]+2);i<=cnt;i++) f[stk[i]]=maxx-dp[stk[i]]+dp[stk[max(1,cnt-a[d]+2)-1]];
for(int i=1;i<=cnt-a[d]+1;i++) f[stk[i]]=maxx;
for(int i=head[d];i;i=nxt[i]) f[to[i]]+=val[i],dfs(to[i],d);
cout<<maxx;
return 0;
}