题解 中心城镇问题

传送门

为啥孩子感觉难点主要在于 DP 啊
指针长剖明明很好写啊

考场上想 DP 的时候并没有想到这个东西是可以合并子树的
于是定义 \(f_{u, d}\) 为点 \(u\) 子树内所有选择的点的深度都 \(\geqslant d\) 的最大价值
转移考虑将 \(v\) 的子树并入当前答案
首先有 \(f_{u, dep_u}=w_u\)
然后分情况讨论:

\[f'_{u, d}=\begin{cases} f_{u, d}+f{v, d}&2(d-dep_u)>k \\ \max\{f'_{u, d+1}, f_{u, d}+f_{v, 2*dep_u+k-d+1}, f_{u, 2*dep_u+k-d+1}+f_{v,d}\}&0<2(d-dep_u)\leqslant k \\ \max\{f'_{u, d+1}, f_{u, dep_u}+f_{v, dep_u+k+1}\}&d=dep_u\end{cases} \]

对照式子的话转移是容易理解的
然后考虑优化
发现是个长链剖分的形式
于是对每条长链动态开空间,时空复杂度就都变成 \(O(n)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 1000010
#define ll long long
#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n, k;
int head[N], w[N], size;
struct edge{int to, next;}e[N<<1];
inline void add(int s, int t) {e[++size]={t, head[s]}; head[s]=size;}

namespace force{
	int dep[N], dp[N][110], g[N][110];
	void dfs(int u, int fa) {
		g[u][dep[u]]=dp[u][dep[u]]=w[u];
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			dep[v]=dep[u]+1;
			dfs(v, u);
			for (int d=dep[u]+k; d>=dep[u]; --d) {
				if (2*(d-dep[u])>k) g[u][d]=dp[u][d]+dp[v][d];
				else if (d==dep[u]) g[u][d]=dp[u][dep[u]]+dp[v][dep[u]+k+1];
				else g[u][d]=max(dp[u][d]+dp[v][2*dep[u]+k-d+1], dp[u][2*dep[u]+k-d+1]+dp[v][d]);
				g[u][d]=max(g[u][d], g[u][d+1]);
			}
			for (int i=1; i<=n; ++i) dp[u][i]=g[u][i];
		}
		for (int i=1; i<=n; ++i) dp[u][i]=g[u][i];
	}
	void solve() {
		dep[1]=1; dfs(1, 0);
		printf("%lld\n", dp[1][1]);
		// cout<<"---dp---"<<endl;
		// for (int i=1; i<=n; ++i) {cout<<i<<": "; for (int j=1; j<=n; ++j) cout<<dp[i][j]<<' '; cout<<endl;}
	}
}

namespace task1{
	int dep[N], f[N][110], g[N][110], mdep[N], mson[N];
	void dfs1(int u, int fa) {
		mdep[u]=dep[u];
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			dep[v]=dep[u]+1;
			dfs1(v, u);
			if (mdep[v]>mdep[u]) mdep[u]=mdep[v], mson[u]=v;
		}
	}
	void dfs2(int u, int fa) {
		f[u][dep[u]]=w[u];
		if (!mson[u]) return ;
		dfs2(mson[u], u);
		for (int i=dep[mson[u]]; i<=mdep[mson[u]]; ++i) f[u][i]=f[mson[u]][i];
		f[u][dep[u]]=max(f[u][dep[u]], f[u][dep[u]]+f[u][dep[u]+k+1]);
		f[u][dep[u]]=max(f[u][dep[u]], f[u][dep[u]+1]);
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa||v==mson[u]) continue;
			dfs2(v, u);
			for (int d=dep[u]; d<=mdep[v]; ++d) {
				if (2*(d-dep[u])>k) f[u][d]=f[u][d]+f[v][d];
				else if (d==dep[u]) f[u][d]=f[u][dep[u]]+f[v][dep[u]+k+1];
				else f[u][d]=max(f[u][d]+f[v][2*dep[u]+k-d+1], f[u][2*dep[u]+k-d+1]+f[v][d]);
			}
			for (int i=mdep[v]; i>=dep[u]; --i) f[u][i]=max(f[u][i], f[u][i+1]);
		}
	}
	void solve() {
		dep[1]=1; dfs1(1, 0); dfs2(1, 0);
		printf("%lld\n", f[1][1]);
		// cout<<"---f---"<<endl;
		// for (int i=1; i<=n; ++i) {cout<<i<<": "; for (int j=1; j<=n; ++j) cout<<f[i][j]<<' '; cout<<endl;}
	}
}

namespace task{
	int dep[N], *f[N], mdep[N], mson[N];
	void dfs1(int u, int fa) {
		mdep[u]=dep[u];
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			dep[v]=dep[u]+1;
			dfs1(v, u);
			if (mdep[v]>mdep[u]) mdep[u]=mdep[v], mson[u]=v;
		}
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa||v==mson[u]) continue;
			int t=mdep[v]-dep[v]+5;
			f[v]=new int[t]-dep[v];
			for (int i=0; i<t; ++i) f[v][dep[v]+i]=0;
		}
	}
	void dfs3(int u, int fa) {
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			if (v==mson[u]) f[v]=f[u];
			dfs3(v, u);
		}
	}
	void dfs2(int u, int fa) {
		f[u][dep[u]]=w[u];
		if (!mson[u]) return ;
		dfs2(mson[u], u);
		// for (int i=dep[mson[u]]; i<=mdep[mson[u]]; ++i) f[u][i]=f[mson[u]][i];
		if (dep[u]+k+1<=mdep[u]) f[u][dep[u]]=max(f[u][dep[u]], f[u][dep[u]]+f[u][dep[u]+k+1]);
		f[u][dep[u]]=max(f[u][dep[u]], f[u][dep[u]+1]);
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa||v==mson[u]) continue;
			dfs2(v, u);
			for (int d=dep[u]; d<=mdep[v]; ++d) {
				if (2*(d-dep[u])>k) f[u][d]=f[u][d]+f[v][d];
				else if (d==dep[u]) {
					if (dep[u]+k+1<=mdep[v]) f[u][d]=f[u][dep[u]]+f[v][dep[u]+k+1];
				}
				else {
					if (2*dep[u]+k-d+1<=mdep[v]) f[u][d]=max(f[u][d], f[u][d]+f[v][2*dep[u]+k-d+1]);
					if (2*dep[u]+k-d+1<=mdep[u]) f[u][d]=max(f[u][d], f[u][2*dep[u]+k-d+1]+f[v][d]);
					else f[u][d]=max(f[u][d], f[v][d]);
				}
			}
			for (int i=mdep[v]; i>=dep[u]; --i) f[u][i]=max(f[u][i], f[u][i+1]);
		}
	}
	void solve() {
		dep[1]=1; dfs1(1, 0);
		f[1]=new int[mdep[1]+5]-1;
		for (int i=0; i<mdep[1]+5; ++i) f[1][i+dep[1]]=0;
		dfs3(1, 0); dfs2(1, 0);
		printf("%lld\n", f[1][1]);
		// cout<<"---f---"<<endl;
		// for (int i=1; i<=n; ++i) {cout<<i<<": "; for (int j=1; j<=n; ++j) cout<<f[i][j]<<' '; cout<<endl;}
	}
}

signed main()
{
	freopen("central.in", "r", stdin);
	freopen("central.out", "w", stdout);

	n=read(); k=read();
	memset(head, -1, sizeof(head));
	for (int i=1; i<=n; ++i) w[i]=read();
	for (int i=1,u,v; i<n; ++i) {
		u=read(); v=read();
		add(u, v); add(v, u);
	}
	// force::solve();
	task::solve();

	return 0;
}
posted @ 2022-01-20 10:27  Administrator-09  阅读(1)  评论(0编辑  收藏  举报