题解 中心城镇问题
为啥孩子感觉难点主要在于 DP 啊
指针长剖明明很好写啊
考场上想 DP 的时候并没有想到这个东西是可以合并子树的
于是定义 \(f_{u, d}\) 为点 \(u\) 子树内所有选择的点的深度都 \(\geqslant d\) 的最大价值
转移考虑将 \(v\) 的子树并入当前答案
首先有 \(f_{u, dep_u}=w_u\)
然后分情况讨论:
\[f'_{u, d}=\begin{cases} f_{u, d}+f_{v, d}&2(d-dep_u)>k \\ \max\{f'_{u, d+1}, f_{u, d}+f_{v, 2dep_u+k-d+1}, f_{u, 2dep_u+k-d+1}+f_{v,d}\}&0<2(d-dep_u)\leqslant k \\ \max\{f'_{u, d+1}, f_{u, dep_u}+f_{v, dep_u+k+1}\}&d=dep_u\end{cases}
\]
对照式子的话转移是容易理解的
然后考虑优化
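注意到第二维是绝对深度：记 \(mdep_u\) 为 \(u\) 子树内的最大深度，则 \(f_{u,\cdot}\) 只有 \(d\in[dep_u, mdep_u]\) 的取值有意义（更浅的等于 \(f_{u,dep_u}\)，更深的为 \(0\)），合并轻儿子 \(v\) 也只需扫描 \([dep_u, mdep_v]\)，这正是长链剖分的形式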
于是对每条长链只动态开一段空间（重儿子直接复用父亲的数组），时空复杂度就都变成 \(O(n)\) 了
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 1000010
#define ll long long
#define int long long
// Buffered stdin reader: refills a 2 MiB buffer with fread and hands out one byte at a time.
// Yields EOF once input is exhausted.  Bytes are returned as unsigned values, so a 0xFF byte
// can never be mistaken for EOF.
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:static_cast<unsigned char>(*p1++))

// Reads the next decimal integer from stdin.
// Non-digit bytes before the number are skipped, and every '-' among them flips the sign.
// The byte that ends the number is consumed.  Returns 0 if input ends before any digit.
inline int read() {
    int ans=0, f=1;
    int c=getchar();  // kept wider than char so EOF stays distinguishable from data bytes
    while (c!=EOF && !(c>='0' && c<='9')) {
        if (c=='-') f=-f;
        c=getchar();
    }
    while (c>='0' && c<='9') {
        ans=ans*10+(c-'0');
        c=getchar();
    }
    return ans*f;
}
// Problem input: n vertices, distance parameter k, per-vertex weights w[1..n].
int n, k;
// Adjacency lists stored as a linked edge array.  head[s] is the index of the newest edge
// leaving s (-1 means none; main fills head with -1), and e[i].next links to the previous
// edge.  edge_cnt is the number of stored edges, and indices start at 1.
// The counter must not be called `size`: with `using namespace std`, that name is
// ambiguous with std::size (C++17) and the file fails to compile.
int head[N], w[N], edge_cnt;
struct edge{int to, next;}e[N<<1];
// Appends the directed edge s -> t (main adds both directions for each tree edge).
inline void add(int s, int t) {e[++edge_cnt]={t, head[s]}; head[s]=edge_cnt;}
namespace force{
int dep[N], dp[N][110], g[N][110];
void dfs(int u, int fa) {
g[u][dep[u]]=dp[u][dep[u]]=w[u];
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==fa) continue;
dep[v]=dep[u]+1;
dfs(v, u);
for (int d=dep[u]+k; d>=dep[u]; --d) {
if (2*(d-dep[u])>k) g[u][d]=dp[u][d]+dp[v][d];
else if (d==dep[u]) g[u][d]=dp[u][dep[u]]+dp[v][dep[u]+k+1];
else g[u][d]=max(dp[u][d]+dp[v][2*dep[u]+k-d+1], dp[u][2*dep[u]+k-d+1]+dp[v][d]);
g[u][d]=max(g[u][d], g[u][d+1]);
}
for (int i=1; i<=n; ++i) dp[u][i]=g[u][i];
}
for (int i=1; i<=n; ++i) dp[u][i]=g[u][i];
}
void solve() {
dep[1]=1; dfs(1, 0);
printf("%lld\n", dp[1][1]);
// cout<<"---dp---"<<endl;
// for (int i=1; i<=n; ++i) {cout<<i<<": "; for (int j=1; j<=n; ++j) cout<<dp[i][j]<<' '; cout<<endl;}
}
}
namespace task1{
// Long-path DP on a fixed-width table (the O(n * depth) predecessor of `task`).
// f[u][d] = best total weight of a valid selection inside u's subtree whose chosen
// vertices all have absolute depth >= d.  Entries with d > mdep[u] are 0 (empty selection).
// Columns are absolute depths, so the whole tree must have depth < M (checked in solve()).
const int M = 110;
int dep[N], f[N][M], mdep[N], mson[N];

// f[x][d], or 0 when d is deeper than every vertex in x's subtree.  Indices like
// dep[u]+k+1 can exceed the row width for large k; reading them directly would return
// another node's row.
inline int val(int x, int d) { return d <= mdep[x] ? f[x][d] : 0; }

// Computes dep[], mdep[u] (deepest depth in u's subtree) and mson[u] (child on that
// deepest path, 0 for a leaf).
void dfs1(int u, int fa) {
    mdep[u] = dep[u];
    for (int i = head[u], v; ~i; i = e[i].next) {
        v = e[i].to;
        if (v == fa) continue;
        dep[v] = dep[u] + 1;
        dfs1(v, u);
        if (mdep[v] > mdep[u]) mdep[u] = mdep[v], mson[u] = v;
    }
}

// Fills f[u]: inherit the deepest child's row, then merge every other child into it.
void dfs2(int u, int fa) {
    f[u][dep[u]] = w[u];
    if (!mson[u]) return;
    dfs2(mson[u], u);
    for (int i = dep[mson[u]]; i <= mdep[mson[u]]; ++i) f[u][i] = f[mson[u]][i];
    // Take u together with heavy-subtree points deeper than dep[u]+k, or skip u entirely.
    f[u][dep[u]] = max(f[u][dep[u]], f[u][dep[u]] + val(u, dep[u] + k + 1));
    f[u][dep[u]] = max(f[u][dep[u]], f[u][dep[u] + 1]);
    for (int i = head[u], v; ~i; i = e[i].next) {
        v = e[i].to;
        if (v == fa || v == mson[u]) continue;
        dfs2(v, u);
        // Ascending d: the mirror index read from f[u] is always > d, so it still holds
        // the value from before this child was merged.
        for (int d = dep[u]; d <= mdep[v]; ++d) {
            if (2 * (d - dep[u]) > k) f[u][d] = f[u][d] + f[v][d];
            else if (d == dep[u]) f[u][d] = f[u][dep[u]] + val(v, dep[u] + k + 1);
            else {
                int mirror = 2 * dep[u] + k - d + 1;  // shallowest compatible depth on the other side
                f[u][d] = max(f[u][d] + val(v, mirror), val(u, mirror) + f[v][d]);
            }
        }
        // Restore the suffix-max property over the touched range.
        for (int i = mdep[v]; i >= dep[u]; --i) f[u][i] = max(f[u][i], val(u, i + 1));
    }
}

// Runs the DP from root 1 and prints f[1][1], the best value over the whole tree.
void solve() {
    dep[1] = 1; dfs1(1, 0);
    assert(mdep[1] < M);  // every absolute depth must fit in a row of f
    dfs2(1, 0);
    printf("%lld\n", f[1][1]);
    // cout<<"---f---"<<endl;
    // for (int i=1; i<=n; ++i) {cout<<i<<": "; for (int j=1; j<=n; ++j) cout<<f[i][j]<<' '; cout<<endl;}
}
}
// Linear-time solution: long-path decomposition with one buffer shared along each long path.
// f[u][d] = best total weight of a valid selection inside u's subtree whose chosen vertices
// all have absolute depth >= d (0 for an empty selection).  f[u] is a pointer indexed by
// absolute depth.  A vertex and its deepest child share one buffer, so only the root and
// light children get their own allocation (zero-filled, never deleted).
// NOTE(review): dfs1/dfs2/dfs3 recurse once per tree level; with N around 1e6 a path-shaped
// tree needs a very deep call stack - confirm the judge's stack limit.
namespace task{
int dep[N], *f[N], mdep[N], mson[N];
// Computes dep[] (root depth is set by solve), mdep[u] = deepest depth in u's subtree and
// mson[u] = child on that deepest path (0 for a leaf).  Then gives every light child v a
// zeroed buffer covering absolute depths [dep[v], mdep[v]+4].
void dfs1(int u, int fa) {
mdep[u]=dep[u];
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==fa) continue;
dep[v]=dep[u]+1;
dfs1(v, u);
if (mdep[v]>mdep[u]) mdep[u]=mdep[v], mson[u]=v;
}
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==fa||v==mson[u]) continue;
int t=mdep[v]-dep[v]+5;
// Offset the base so that f[v][dep[v]] is the first element.  NOTE(review): forming a
// pointer before the start of an array is formally undefined behaviour in ISO C++, although
// it is a common idiom that works on mainstream compilers.
f[v]=new int[t]-dep[v];
for (int i=0; i<t; ++i) f[v][dep[v]+i]=0;
}
}
// Passes each vertex's buffer down to its deepest child, so a whole long path shares one
// row.  Light children keep the buffer allocated in dfs1.
void dfs3(int u, int fa) {
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==fa) continue;
if (v==mson[u]) f[v]=f[u];
dfs3(v, u);
}
}
// Fills f[u].  After the deepest child is processed, f[u][d] for d > dep[u] already holds
// that child's values (shared buffer); each light child is then merged in.
void dfs2(int u, int fa) {
f[u][dep[u]]=w[u];  // choose u alone
if (!mson[u]) return ;
dfs2(mson[u], u);
// (An explicit copy from the deepest child, as in task1, is unnecessary: the buffer is shared.)
// for (int i=dep[mson[u]]; i<=mdep[mson[u]]; ++i) f[u][i]=f[mson[u]][i];
// Take u together with deepest-child points at depth >= dep[u]+k+1 (if that depth exists)...
if (dep[u]+k+1<=mdep[u]) f[u][dep[u]]=max(f[u][dep[u]], f[u][dep[u]]+f[u][dep[u]+k+1]);
// ...or skip u and keep the best selection at depth >= dep[u]+1.
f[u][dep[u]]=max(f[u][dep[u]], f[u][dep[u]+1]);
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==fa||v==mson[u]) continue;
dfs2(v, u);
// Merge light child v over depths [dep[u], mdep[v]].  Ascending d matters: apart from
// f[u][d] itself, the only f[u] entry read is f[u][2*dep[u]+k-d+1], whose index is > d.
// So it still holds the value from before this merge.  Indices past mdep[...] are treated
// as an empty selection (0).
for (int d=dep[u]; d<=mdep[v]; ++d) {
// Both sides deeper than d: points in different subtrees are already more than k apart.
if (2*(d-dep[u])>k) f[u][d]=f[u][d]+f[v][d];
else if (d==dep[u]) {
// u's current best at dep[u] combined with v's points deeper than dep[u]+k.
if (dep[u]+k+1<=mdep[v]) f[u][d]=f[u][dep[u]]+f[v][dep[u]+k+1];
}
else {
// Shallowest chosen point at depth d, either on u's merged side or in v.  The other side
// must then stay at depth >= 2*dep[u]+k-d+1 to be more than k away.
if (2*dep[u]+k-d+1<=mdep[v]) f[u][d]=max(f[u][d], f[u][d]+f[v][2*dep[u]+k-d+1]);
if (2*dep[u]+k-d+1<=mdep[u]) f[u][d]=max(f[u][d], f[u][2*dep[u]+k-d+1]+f[v][d]);
else f[u][d]=max(f[u][d], f[v][d]);  // u's side has nothing that deep: it contributes 0
}
}
// Restore the suffix-max property ("depth >= i" includes "depth >= i+1") on the touched range.
for (int i=mdep[v]; i>=dep[u]; --i) f[u][i]=max(f[u][i], f[u][i+1]);
}
}
// Allocates the root buffer (valid depths [1, mdep[1]+5]; the -1 offset equals -dep[1]).
// Then shares buffers along long paths, runs the DP and prints f[1][1], the best value for the whole tree.
void solve() {
dep[1]=1; dfs1(1, 0);
f[1]=new int[mdep[1]+5]-1;
for (int i=0; i<mdep[1]+5; ++i) f[1][i+dep[1]]=0;
dfs3(1, 0); dfs2(1, 0);
printf("%lld\n", f[1][1]);
// cout<<"---f---"<<endl;
// for (int i=1; i<=n; ++i) {cout<<i<<": "; for (int j=1; j<=n; ++j) cout<<f[i][j]<<' '; cout<<endl;}
}
}
// Entry point: reads "n k", the n vertex weights and the n-1 tree edges from central.in,
// builds the undirected adjacency lists, and lets task::solve() write the answer to central.out.
signed main()
{
    freopen("central.in", "r", stdin);
    freopen("central.out", "w", stdout);
    n = read();
    k = read();
    memset(head, -1, sizeof(head));  // -1 marks the end of every adjacency list
    for (int node = 1; node <= n; ++node) w[node] = read();
    for (int edges_left = n - 1; edges_left > 0; --edges_left) {
        const int a = read();
        const int b = read();
        add(a, b);
        add(b, a);
    }
    // force::solve();
    task::solve();
    return 0;
}