(洛谷)P4657 chase
基本思路 :
-
首先令a数组表示该点的权值,c数组表示该点所连接的所有点的权值和。
-
如果我们知道了该点的前驱,那该点的权值就为 $ c[now]-a[pre] $ 。
递进 :
问题在于这是一棵无根树,那我们可以任意定义一点为根,
然后设两个数组 $ s_{i,j} , d_{i,j} $ 分别表示从子树 i 中某个点走到 i 和从 i 走到子树 i 中某个点,放置了 j 次的最大收益。那问题就变成求树的直径了。
注意 :
因为我们需要知道一个点的前驱,所以 s数组 是记录该子树的根节点的, d数组 是不记录该子树的根节点的。
另外,起始点一定会放置磁铁,因为如果逃亡者降落一个点后再跑到其他点放置磁铁,那显然不如直接降落在该点放置磁铁更优。
最后一点,如果是起始点,因为其没有前驱,所以该点的贡献为 $ c[now] $ 。
树的直径 :
因为树的路径不能重复,所以我们可以枚举完一个 $ to $ 节点后先与 $ now $ 的 $ d $ 数组和 $ s $ 数组更新一边答案,再将 $ to $ 节点的答案合并到 $ now $ 的答案中。
时间复杂度 :\(O(nv)\)
- code
#include <bits/stdc++.h>
#define re register
#define int long long
#define db double
#define pir make_pair
using namespace std;
const int maxn=100010;
inline int read() {
int s=0,w=1; char ch=getchar();
while(ch<'0'||ch>'9') { if(ch=='-') w=-1;ch=getchar(); }
while(ch>='0'&&ch<='9') { s=s*10+ch-'0'; ch=getchar(); }
return s*w;
}
int cnt,head[maxn],s[maxn][110],d[maxn][110],a[maxn],c[maxn];
struct EDGE { int nxt,var; } edge[maxn<<1];
inline void add(int a,int b) { edge[++cnt]=(EDGE){head[a],b};head[a]=cnt; }
int n,v,ans; int sl[110],dl[110];
inline void dfs(int now,int pre) {
s[now][1]=c[now];
for(re int i=head[now],to;i;i=edge[i].nxt) {
if((to=edge[i].var)==pre) continue;
dfs(to,now);
int maxs=0,maxd=0;
for(re int i=v;i>=2;i--) {
sl[i]=max(s[to][i],c[now]-a[to]+s[to][i-1]);
dl[i]=max(d[to][i],c[to]-a[now]+d[to][i-1]);
maxs=max(maxs,s[now][v-i]);
maxd=max(maxd,d[now][v-i]);
ans=max(ans,sl[i]+maxd);
ans=max(ans,dl[i]+maxs);
}
maxs=max(maxs,s[now][v-1]);
maxd=max(maxd,d[now][v-1]);
sl[1]=max(s[to][1],c[now]);
dl[1]=max(d[to][1],c[to]-a[now]);
ans=max(ans,sl[1]+maxd);
ans=max(ans,dl[1]+maxs);
s[now][1]=max(s[now][1],sl[1]);
d[now][1]=max(d[now][1],dl[1]);
maxs=max(maxs,s[now][v]);
maxd=max(maxd,d[now][v]);
ans=max(ans,sl[0]+maxd);
ans=max(ans,dl[0]+maxs);
for(re int i=1;i<=v;i++) {
s[now][i]=max(s[now][i],sl[i]);
d[now][i]=max(d[now][i],dl[i]);
maxs=max(maxs,s[now][i]);
maxd=max(maxd,d[now][i]);
}
ans=max(ans,max(maxs,maxd));
}
}
signed main(void) {
// freopen("chase9.in","r",stdin);
n=read(),v=read();
for(re int i=1;i<=n;i++) a[i]=read();
for(re int i=1,u,e;i<n;i++) {
u=read(),e=read();
add(u,e); add(e,u);
c[u]+=a[e]; c[e]+=a[u];
}
if(!v) { printf("0"); return 0; }
dfs(1,0);
printf("%lld",ans);
}