E74 树形DP P4657 [CEOI2017] Chase
视频链接:E74 树形DP P4657 [CEOI2017] Chase_哔哩哔哩_bilibili
P4657 [CEOI2017] Chase - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
// 树形DP O(n*m) #include <bits/stdc++.h> #define LL long long using namespace std; const int N=100010,M=110; int idx,head[N]; struct E{int to,ne;}e[N<<1]; void add(int x,int y){ e[++idx]={y,head[x]},head[x]=idx; } int n,m; LL a[N],s[N],f[N][M],g[N][M],ans; void dfs(int u,int fa){ vector<int> stk; for(int i=head[u];i;i=e[i].ne){ int v=e[i].to; if(v==fa) continue; dfs(v,u); stk.push_back(v); } for(int i=1;i<=m;++i) f[u][i]=s[u],g[u][i]=s[u]-a[fa]; for(int v:stk){ for(int i=0;i<=m;++i) ans=max(ans,f[u][i]+g[v][m-i]); for(int i=1;i<=m;++i) f[u][i]=max(f[u][i],max(f[v][i],f[v][i-1]+s[u]-a[v])), g[u][i]=max(g[u][i],max(g[v][i],g[v][i-1]+s[u]-a[fa])); } for(int i=1;i<=m;++i) f[u][i]=s[u],g[u][i]=s[u]-a[fa]; reverse(stk.begin(),stk.end()); for(int v:stk){ for(int i=0;i<=m;++i) ans=max(ans,f[u][i]+g[v][m-i]); for(int i=1;i<=m;++i) f[u][i]=max(f[u][i],max(f[v][i],f[v][i-1]+s[u]-a[v])), g[u][i]=max(g[u][i],max(g[v][i],g[v][i-1]+s[u]-a[fa])); } } int main(){ scanf("%d%d",&n,&m); for(int i=1;i<=n;++i)scanf("%lld",&a[i]); for(int i=1,x,y;i<n;++i){ scanf("%d%d",&x,&y); add(x,y),add(y,x); s[x]+=a[y],s[y]+=a[x]; } dfs(1,0); printf("%lld\n",ans); }
// 树形DP O(n*m) #include <bits/stdc++.h> #define LL long long using namespace std; const int N=100010,M=110; int idx,head[N]; struct E{int to,ne;}e[N<<1]; void add(int x,int y){ e[++idx]={y,head[x]},head[x]=idx; } int n,m,stk[N]; LL a[N],s[N],f[N][M],g[N][M],ans; void dfs(int u,int fa){ for(int i=head[u];i;i=e[i].ne){ int v=e[i].to; if(v!=fa) dfs(v,u); } int top=0; for(int i=head[u];i;i=e[i].ne){ int v=e[i].to; if(v!=fa) stk[++top]=v; } for(int i=1;i<=m;++i) f[u][i]=s[u],g[u][i]=0; for(int i=1;i<=top;i++){ int v=stk[i]; for(int j=0;j<=m;++j) ans=max(ans,f[u][j]+g[v][m-j]); for(int j=1;j<=m;++j) f[u][j]=max(f[u][j],max(f[v][j],f[v][j-1]+s[u]-a[v])), g[u][j]=max(g[u][j],max(g[v][j],g[v][j-1]+s[u]-a[fa])); } for(int i=1;i<=m;++i) f[u][i]=s[u],g[u][i]=0; for(int i=top;i>=1;i--){ int v=stk[i]; for(int j=0;j<=m;++j) ans=max(ans,f[u][j]+g[v][m-j]); for(int j=1;j<=m;++j) f[u][j]=max(f[u][j],max(f[v][j],f[v][j-1]+s[u]-a[v])), g[u][j]=max(g[u][j],max(g[v][j],g[v][j-1]+s[u]-a[fa])); } } int main(){ scanf("%d%d",&n,&m); for(int i=1;i<=n;++i)scanf("%lld",&a[i]); for(int i=1,x,y;i<n;++i){ scanf("%d%d",&x,&y); add(x,y),add(y,x); s[x]+=a[y],s[y]+=a[x]; } dfs(1,0); printf("%lld\n",ans); }