集合

奇怪的树上背包,而且带了一些玄学。借这篇题解详细地总结一下树上背包的做法。

题面

链接

给定一棵 \(n\) 个点的树,给定正整数 \(k\) , 在树上找出 \(k\) 个不同的点,设为\(A_1, A_2...A_k\),使得 \(\sum_{i=2}^{k}{dis(A_i,A_{i-1})}\) 最小,输出这个最小值。其中 \(dis(x, y)\) 表示树上 x, y 之间最短路的长度。

\(1\le k\le n\le 3000\)

解法

考虑答案肯定会包含很多个点,而这些个点里肯定只有唯一的一个深度最小的点(这不是废话吗),而这个点要么是这条路径的起点或者终点(起点还是终点效果一样),要么是路径上的一个点(也就是起点和终点都在子树内,且跨越了它的儿子们),于是可以考虑对于这两种情况进行讨论。

然后就是一个比较基本的树形DP了。先是考虑对于每一个节点维护两个值,g和s。g表示该子树内画一条以根为一端点的路径的最小代价,s表示在子树内画一条通过根节点的路径的最小代价。然后发现更新的过程中会用到一个值,那就是每棵子树内从根到根的路径的最小代价。正常更新即可,就是一个树上背包。

然后TLE了几个小时。各种卡常小技巧似乎都没有起到什么作用,最后尝试照着题解的枚举顺序改了一下就过了。太玄学了,问老师也并木有问出来个所以然。

以此我得到一个教训。树上背包尽量从前往后更新,而不是从后往前枚举决策,这样可能可以降低常数。

顺便说一句我学会了书上背包复杂度的分析方法,摘录一下:

这样的复杂度相当于在以 \(x\) 为根的子树中枚举点对,且这两个点分别在不同儿子的子树中。因此,对于
任意点对 \(i,j\) ,当且仅当DFS到 \(lca(i,j)\) 时会被枚举到,因此时间复杂度综合为 \(O(N^2)\)

代码:

#include<cstdio>
#include<cstring>
//#define zczc
using namespace std;
const int N=3020;
inline void read(int &wh){
    wh=0;int f=1;char w=getchar();
    while(w<'0'||w>'9'){if(w=='-')f=-1;w=getchar();}
    while(w<='9'&&w>='0'){wh=wh*10+w-'0';w=getchar();}
    wh*=f;return;
}
inline void check(int &s1,int s2){
	if(s1>s2)s1=s2;return;
}
inline int min(int s1,int s2){
	return s1<s2?s1:s2;
}

struct edge{
	int t,v,next;
}e[N<<1];
int head[N],esum;
inline void add(int fr,int to,int val){
	esum++;
	e[esum].t=to;
	e[esum].v=val;
	e[esum].next=head[fr];
	head[fr]=esum;
}
int m,n,ans=1e9;

int size[N],f[N][N],g[N][N],s[N][N];
void dfs(int wh,int fa){
	size[wh]=1;f[wh][1]=g[wh][1]=s[wh][1]=0;
	for(int i=head[wh],th,val;i;i=e[i].next){
		th=e[i].t;val=e[i].v;
		if(th==fa)continue;
		dfs(th,wh);
		for(int j=size[wh];j;j--){
			for(int k=size[th];k;k--){
				check(f[wh][j+k],f[wh][j]+f[th][k]+(val<<1));
				check(g[wh][j+k],min(g[wh][j]+f[th][k]+val,f[wh][j]+g[th][k])+val);
				check(s[wh][j+k],min(min(s[wh][j]+f[th][k],f[wh][j]+s[th][k])+val,g[wh][j]+g[th][k])+val);
			}
		}
		size[wh]+=size[th];
	}
	check(ans,min(f[wh][n],min(g[wh][n],s[wh][n])));
	return;
}

signed main(){
	
	#ifdef zczc
	freopen("in.txt","r",stdin);
	#endif
	
	memset(f,0x3f,sizeof(f));
	memset(g,0x3f,sizeof(g));
	memset(s,0x3f,sizeof(s));
	int s1,s2,s3;
	read(m);read(n);
	for(int i=1;i<m;i++){
		read(s1);read(s2);read(s3);
		add(s1,s2,s3);
		add(s2,s1,s3);
	}
	dfs(1,0);
	printf("%d\n",ans);
	
	return 0;
}
posted @ 2021-11-07 12:52  Feyn618  阅读(33)  评论(0编辑  收藏  举报