题解 park/chase

传送门

这题考试的时候觉得时间复杂度假了,\(n \geqslant 1000\)的部分直接瞎写了个特殊性质上去,结果假的时间复杂度能有60pts……

  • 比较大的数组无论如何不要直接整个 memset!如果在写部分分,考虑用到多少就 memset 多少。
    memset 真的可以把一份 74pts 的代码卡成 30pts:memset 一个 1e9 的 int 数组要 0.2s,long long 数组要 0.7s。

首先,枚举起点、每次跑一遍 dfs 就可以拿到 70pts:
令 \(g[i][j]\) 表示以 \(i\) 为起点、撒 \(j\) 次面包屑得到的最大收益即可。
这部分代码(里面保留了几次尝试,main 中实际调用的是 task2::solve()):

#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 100010
#define ll long long 
#define ld long double
#define usd unsigned
#define ull unsigned long long
//#define int long long 

// Buffered replacement for getchar(): refills `buf` from stdin in 2 MiB chunks and
// yields EOF once fread returns 0. Bytes come back as unsigned char values so that
// passing them to isdigit() is always well-defined (negative chars are UB there).
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:static_cast<unsigned char>(*p1++))
char buf[1<<21], *p1=buf, *p2=buf;

// Reads the next decimal integer from stdin. Every '-' met while skipping
// non-digit characters flips the sign. If the input ends before any digit is
// found, returns 0 instead of spinning forever on EOF.
inline int read() {
	int ans=0, f=1, c=getchar();
	while (c!=EOF && !isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (c!=EOF && isdigit(c)) {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

// n: number of nodes; V: number of breadcrumbs (second value read in main).
int n, V;
// p[i]: value read for node i. head[u]: index of u's first edge in e (0 = none).
// edge_cnt: number of edges stored so far. It must not be named `size`: with
// `using namespace std;` under C++17, an unqualified `size` is ambiguous with
// std::size and the file fails to compile.
int p[N], head[N], edge_cnt;
// Adjacency lists stored in a static array; `vis` is scratch space used by force::dfs.
struct edge{int to, next; bool vis;}e[N<<1];
// Appends the directed edge s -> t at the front of s's adjacency list.
inline void add(int s, int t) {edge* k=&e[++edge_cnt]; k->to=t; k->next=head[s]; head[s]=edge_cnt;}

// Brute force (partial-score / cross-check aid): enumerates walks by backtracking.
namespace force{
	ll ans;
	// none[x] == 1: x's value can no longer be collected by a crumb, either because
	// x is on the current walk or because a crumb on the walk already collected it.
	bool none[N];
	// Explores every walk that continues from u (entered from fa) and only moves away
	// from the start node. v2 is the number of crumbs left and `sum` the value collected
	// so far. sum is recorded into ans once the budget is exhausted.
	// NOTE(review): ans is only updated on entry to a call with v2 <= 0. So the last crumb
	// counts only if the walk then moves to another node, and walks that end at a leaf
	// with crumbs left are never recorded. Confirm this matches the problem statement.
	void dfs(int u, int fa, int v2, ll sum) {
		if (v2<=0) {ans=max(ans, sum); return ;}
		// Mark u as visited; cge remembers whether this call set the mark, so only it undoes it.
		bool cge=0;
		if (!none[u]) none[u]=1, cge=1;
		// Option 1: walk on to each child without dropping a crumb at u.
		for (int i=head[u],v; i; i=e[i].next) {
			v = e[i].to;
			if (v!=fa) dfs(v, u, v2, sum);
		}
		// Option 2: drop a crumb at u. This collects every neighbour not yet collected and
		// tags the edge (vis) so that exactly these marks can be undone afterwards.
		for (int i=head[u],v; i; i=e[i].next) if (!none[e[i].to]) {sum+=p[e[i].to]; none[e[i].to]=1; e[i].vis=1;}
		for (int i=head[u],v; i; i=e[i].next) {
			v = e[i].to;
			if (v!=fa) dfs(v, u, v2-1, sum);
		}
		// Backtrack: clear the marks made by the crumb at u, then u's own visited mark.
		for (int i=head[u],v; i; i=e[i].next) if (e[i].vis) {e[i].vis=0; none[e[i].to]=0;}
		if (cge) none[u]=0;
	}
	// Runs the brute force from node 1 only (the loop over all start nodes is commented
	// out), prints the best value and terminates the program.
	void solve() {
		//for (int i=1; i<=n; ++i) {
			//memset(none, 0, sizeof(none));
			dfs(1, 0, V, 0);
		//}
		printf("%lld\n", ans);
		exit(0);
	}
}

// Partial-score attempt: run a tree DP rooted at every start node in turn.
namespace task1{
	// dp[u][s][k]: best value inside u's subtree for index s in [0, V] and state k in [0, 4).
	// NOTE(review): s looks like "crumbs remaining" (the seed is placed at s = V), and
	// transitions into k = 2,3 add u's neighbour sum (a crumb at u) while k = 0,1 do not.
	// Confirm before relying on it. Transitions read dp[v][s+1] for s = V, so V <= 103 is required.
	// dp is N*105*4 long longs (~336 MB), so only dp[u] is cleared on entry. Clearing the
	// whole array for every root is far too slow.
	ll dp[N][105][4], ans;
	void dfs(int u, int fa) {
		// sum: total p over all neighbours of u; leaf: u has no neighbour other than fa.
		ll sum=0;
		bool leaf=1;
		for (int i=head[u]; i; i=e[i].next) {
			sum+=p[e[i].to];
			if (e[i].to!=fa) leaf=0;
		}
		memset(dp[u], 0, sizeof(dp[u]));
		dp[u][V][3]=p[fa];
		if (leaf) return ;
		for (int i=head[u],v; i; i=e[i].next) {
			v = e[i].to;
			if (v!=fa) {
				dfs(v, u);
				for (int s=V; s>=0; --s) {
					dp[u][s][0] = max(dp[u][s][0], max(dp[v][s+1][2], dp[v][s+1][3]));
					dp[u][s][1] = max(dp[u][s][1], max(dp[v][s][0], dp[v][s][1]));
					if (s>0) {
						dp[u][s][2] = max(dp[u][s][2], sum-p[v]+max(dp[v][s+1][2], dp[v][s+1][3]));
						dp[u][s][3] = max(max(dp[u][s][3], sum-p[v]+max(dp[v][s][0], dp[v][s][1])), sum);
					}
				}
			}
		}
	}
	// Tries every node as the root and keeps the best value over all s and states.
	// Prints the answer and terminates the program.
	void solve() {
		// Bail out with the best answer so far after ~1.6 s of CPU time. clock() ticks
		// CLOCKS_PER_SEC times per second. That is 1e6 only on POSIX, so the former
		// literal 1600000 meant 1600 s where CLOCKS_PER_SEC is 1000 (e.g. Windows).
		const clock_t time_budget = CLOCKS_PER_SEC * 16 / 10;
		for (int i=1; i<=n; ++i) {
			if (clock()>=time_budget) {printf("%lld\n", ans); exit(0);}
			dfs(i, 0);
			for (int j=0; j<=V; ++j) ans=max(ans, max(max(dp[i][j][0], dp[i][j][1]), max(dp[i][j][2], dp[i][j][3])));
		}
		printf("%lld\n", ans);
		exit(0);
	}
}

// Partial-score attempt (the one main runs): a DFS from every start node.
namespace task2{
	// f[u]: cleared on entry to dfs(u), then only f[u][1] = neighbour sum of u is set.
	// The transition for f is not implemented. g[u][j] is built from the children's g.
	// NOTE(review): the g transition reads g[v][j+1] for j = V, so V <= 103 is required.
	ll f[N][105], g[N][105], ans;
	void dfs(int u, int fa) {
		// sum: total p over all neighbours of u.
		ll sum=0;
		for (int i=head[u]; i; i=e[i].next) sum+=p[e[i].to];
		memset(f[u], 0, sizeof(f[u]));
		f[u][1]=sum;
		for (int i=head[u],v; i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			// Clear only the child's row before it is (re)computed for this root.
			memset(g[v], 0, sizeof(g[v]));
			dfs(v, u);
			for (int j=1; j<=V; ++j) {
				g[u][j] = max(g[u][j], max(g[v][j], sum-p[fa]+g[v][j+1]));
			}
		}
	}
	// Tries every node as the start and keeps the best f/g value over all nodes.
	// Prints the answer and terminates the program.
	void solve() {
		// Bail out with the best answer so far after ~1.6 s of CPU time. clock() ticks
		// CLOCKS_PER_SEC times per second. That is 1e6 only on POSIX, so the former
		// literal 1600000 meant 1600 s where CLOCKS_PER_SEC is 1000 (e.g. Windows).
		const clock_t time_budget = CLOCKS_PER_SEC * 16 / 10;
		for (int i=1; i<=n; ++i) {
			if (clock()>=time_budget) {printf("%lld\n", ans); exit(0);}
			memset(g[i], 0, sizeof(g[i]));
			dfs(i, 0);
			for (int j=1; j<=n; ++j) for (int k=0; k<=V; ++k) ans=max(ans, max(f[j][k], g[j][k]));
		}
		printf("%lld\n", ans);
		exit(0);
	}
}

// Earlier single-DFS attempt. It is not called: main runs task2::solve().
// NOTE(review): correctness unverified. The two 105x4 long long locals (~6.7 KB) live in
// every recursion frame, which a path of ~1e5 nodes cannot afford on a typical stack.
namespace task{
	ll f[N][105], g[N][105], ans;
	void dfs(int u, int fa) {
		// sum: total p over all neighbours of u.
		ll sum=0;
		for (int i=head[u]; i; i=e[i].next) sum+=p[e[i].to];
		f[u][1]=sum;
		// Per crumb count j: maxn[j][0..1] = best / second-best g[child][j] and
		// maxn[j][2..3] = best / second-best f[child][j]; maxi holds the matching child ids.
		ll maxn[105][4], maxi[105][4];
		memset(maxn, 0, sizeof(maxn));
		memset(maxi, 0, sizeof(maxi));
		for (int i=head[u],v; i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			// Push g down from u into the child v before recursing.
			for (int j=1; j<=V; ++j) {
				g[v][j] = max(g[v][j], max(g[u][j], sum-p[fa]+g[u][j-1]));
				if (g[v][j]>=maxn[j][0]) {maxn[j][1]=maxn[j][0]; maxi[j][1]=maxi[j][0]; maxn[j][0]=g[v][j]; maxi[j][0]=v;}
				else if (g[v][j]>maxn[j][1]) maxn[j][1]=g[v][j], maxi[j][1]=v;
			}
			dfs(v, u);
			// Pull f up from the finished child v.
			for (int j=1; j<=V; ++j) {
				f[u][j] = max(f[u][j], max(f[v][j], sum-p[v]+f[v][j-1]));
				if (f[v][j]>=maxn[j][2]) {maxn[j][3]=maxn[j][2]; maxi[j][3]=maxi[j][2]; maxn[j][2]=f[v][j]; maxi[j][2]=v;}
				else if (f[v][j]>maxn[j][3]) maxn[j][3]=f[v][j], maxi[j][3]=v;
			}
		}
		// Pair a g from one child with an f from a different child. If the two best come
		// from the same child, fall back to a second-best. The stdout debug tracing that
		// used to be interleaved here was removed: it was printed before the answer and
		// corrupted the program's output.
		for (int j=0; j<V; ++j) {
			if (maxi[j+1][0]!=maxi[j][2]) ans=max(ans, maxn[j+1][0]+maxn[j][2]);
			else {
				if (maxi[j+1][1]!=maxi[j][2]) ans=max(ans, maxn[j+1][1]+maxn[j][2]);
				if (maxi[j+1][0]!=maxi[j][3]) ans=max(ans, maxn[j+1][0]+maxn[j][3]);
			}
			ans=max(ans, max(maxn[j][0], maxn[j][2]));
		}
	}
	// Runs the DFS rooted at node 1, prints the answer and terminates the program.
	void solve() {
		dfs(1, 0);
		printf("%lld\n", ans);
		exit(0);
	}
}

signed main()
{
	#ifdef DEBUG
	freopen("1.in", "r", stdin);
	#endif

	// Header: node count, then the number of breadcrumbs.
	n = read();
	V = read();
	// One value per node; nodes are 1-indexed.
	for (int node = 1; node <= n; ++node) {
		p[node] = read();
	}
	// n-1 undirected tree edges, each stored in both directions.
	for (int remaining = n - 1; remaining > 0; --remaining) {
		const int a = read();
		const int b = read();
		add(a, b);
		add(b, a);
	}
	// Only the task2 attempt is wired up; its solve() prints the answer and exits.
	task2::solve();

	return 0;
}

然后考虑如何不枚举起点
那就需要换根 DP 了。令 \(f[i][j]\) 表示从以 \(i\) 为根的子树中某点出发、向上走到 \(i\),\(g[i][j]\) 表示从 \(i\) 的父亲走进 \(i\) 再向下走入其子树,两者都是撒 \(j\) 次面包屑时的最大收益。
转移的时候要特别注意先后顺序
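首先方程有了:\(ans=\max(ans,\ f[u][j]+g[to][v-j])\),其中 \(to\) 是 \(u\) 的儿子,\(v\) 是面包屑总数。
而我们要在同一次遍历中更新 \(ans\)、\(f[u][j]\)、\(g[u][j]\)。
因为 \(f\) 和 \(g\) 不能取自同一棵子树,所以 \(f\) 只能用之前已经遍历过的子树的信息:对每个儿子,先用它更新 \(ans\),再用它转移 \(f\)、\(g\)。
但这样每个 \(g\) 只和它左边(之前遍历过)的儿子的 \(f\) 匹配,显然不够,所以还要把 \(f,g\) 重新初始化后,再逆序枚举一遍儿子。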
挺有思维量的,做了巨久……还因为变量名重了没看出来陷入高度自闭

Code:

#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 100010
#define ll long long 
//#define int long long 

char buf[1<<21], *p1=buf, *p2=buf;
// Buffered replacement for getchar(): refills `buf` from stdin in 2 MiB chunks and
// yields EOF once fread returns 0. Bytes come back as unsigned char values so that
// passing them to isdigit() is always well-defined (negative chars are UB there).
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:static_cast<unsigned char>(*p1++))

// Reads the next decimal integer from stdin. Every '-' met while skipping
// non-digit characters flips the sign. If the input ends before any digit is
// found, returns 0 instead of spinning forever on EOF.
inline int read() {
	int ans=0, f=1, c=getchar();
	while (c!=EOF && !isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (c!=EOF && isdigit(c)) {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

// n: number of nodes; v: number of breadcrumbs (second value read in main).
int n, v;
// head[u]: index of u's first edge in e (0 = none). edge_cnt: edges stored so far. It
// must not be named `size`: with `using namespace std;` under C++17, an unqualified
// `size` is ambiguous with std::size and the file fails to compile.
// sta/top: stack of u's children that dfs replays in reverse order.
// p[i]: value of node i; sum[u]: total p over u's neighbours (filled by dfs); ans: best so far.
// f[u][j] / g[u][j]: best value of a walk that comes up from u's subtree and ends at u /
// enters u from its parent and goes down into u's subtree, using j crumbs. dfs indexes
// up to [v], so the second extent (105) requires v <= 104.
int head[N], edge_cnt, sta[N], top; ll p[N], sum[N], ans, f[N][105], g[N][105];
struct edge{int to, next;}e[N<<1];
// Appends the directed edge s -> t at the front of s's adjacency list.
inline void add(int s, int t) {edge* k=&e[++edge_cnt]; k->to=t; k->next=head[s]; head[s]=edge_cnt;}

// Tree DP over the subtree of u (parent fa), called once from the root.
// After the call, for 1 <= j <= v:
//   f[u][j]: best value of a walk that starts somewhere in u's subtree and ends at u
//            with j breadcrumbs;
//   g[u][j]: best value of a walk that enters u from fa and continues down into u's
//            subtree with j breadcrumbs.
// While merging children, every "f from one child + g from a different child" pair is
// folded into the global ans. sum[u] ends up as the total p over all neighbours of u.
// f[.][0] and g[.][0] are never written and stay 0 (globals are zero-initialised).
void dfs(int u, int fa) {
	// Build sum[u] and finish every child first. This makes all f[to]/g[to] final, and
	// the shared stack sta is empty again because each child's call drains it below.
	// Note: this loop-local v shadows the global crumb budget v inside this loop only.
	for (int i=head[u],v; i; i=e[i].next) {
		v = e[i].to;
		sum[u]+=p[v];
		if (v!=fa) dfs(v, u);
	}
	int to;
	// Seed with the walk that consists of u alone (value sum[u], or sum[u]-p[fa] when
	// entering from fa), valid for any j >= 1.
	for (int i=1; i<=v; ++i) f[u][i]=sum[u], g[u][i]=sum[u]-p[fa];
	// Forward pass. For each child, ans is updated BEFORE the child is merged into f[u],
	// so g[to] is only paired with f built from children visited earlier (never itself).
	// Children are pushed onto sta for the reverse pass.
	for (int i=head[u]; i; i=e[i].next) {
		to = e[i].to;
		if (to==fa) continue;
		sta[++top]=to;
		for (int j=1; j<=v; ++j) {
			ans = max(ans, f[u][j]+g[to][v-j]);
			f[u][j]=max(f[u][j], max(f[to][j], f[to][j-1]+sum[u]-p[to]));
			g[u][j]=max(g[u][j], max(g[to][j], g[to][j-1]+sum[u]-p[fa]));
		}
	}
	// Walks that stay on one side of u, using the full budget.
	ans = max(ans, max(f[u][v], g[u][v]));
	// Reverse pass: re-seed and merge the children in the opposite order, so every g[to]
	// is also paired with f from the children that came after it in the forward pass.
	for (int i=1; i<=v; ++i) f[u][i]=sum[u], g[u][i]=sum[u]-p[fa];
	while (top) {
		to=sta[top--];
		for (int j=1; j<=v; ++j) {
			ans = max(ans, f[u][j]+g[to][v-j]);
			f[u][j]=max(f[u][j], max(f[to][j], f[to][j-1]+sum[u]-p[to]));
			g[u][j]=max(g[u][j], max(g[to][j], g[to][j-1]+sum[u]-p[fa]));
		}
	}
	ans = max(ans, max(f[u][v], g[u][v]));
}

signed main()
{
	// Header: node count, then the crumb budget (stored in the global v).
	n = read();
	v = read();
	// One value per node; nodes are 1-indexed.
	for (int node = 1; node <= n; ++node) {
		p[node] = read();
	}
	// n-1 undirected tree edges, each stored in both directions.
	for (int remaining = n - 1; remaining > 0; --remaining) {
		const int a = read();
		const int b = read();
		add(a, b);
		add(b, a);
	}
	// Root the tree at node 1; dfs folds the best value into the global ans.
	dfs(1, 0);
	printf("%lld\n", ans);

	return 0;
}
posted @ 2021-08-08 15:56  Administrator-09  阅读(18)  评论(0)  编辑  收藏  举报