集合
奇怪的树上背包,而且带了一些玄学。借这篇题解详细地总结一下树上背包的做法。
题面
给定一棵 \(n\) 个点的树,给定正整数 \(k\) , 在树上找出 \(k\) 个不同的点,设为\(A_1, A_2...A_k\),使得 \(\sum_{i=2}^{k}{dis(A_i,A_{i-1})}\) 最小,输出这个最小值。其中 \(dis(x, y)\) 表示树上 x, y 之间最短路的长度。
\(1\le k\le n\le 3000\)。
解法
考虑答案肯定会包含很多个点,而这些个点里肯定只有唯一的一个深度最小的点(这不是废话吗),而这个点要么是这条路径的起点或者终点(起点还是终点效果一样),要么是路径上的一个点(也就是起点和终点都在子树内,且跨越了它的儿子们),于是可以考虑对于这两种情况进行讨论。
然后就是一个比较基本的树形DP了。先是考虑对于每一个节点维护两个值,g和s。g表示该子树内画一条以根为一端点的路径的最小代价,s表示在子树内画一条通过根节点的路径的最小代价。然后发现更新的过程中会用到一个值,那就是每棵子树内从根到根的路径的最小代价。正常更新即可,就是一个树上背包。
然后TLE了几个小时。各种卡常小技巧似乎都没有起到什么作用,最后尝试照着题解的枚举顺序改了一下就过了。太玄学了,问老师也并木有问出来个所以然。
以此我得到一个教训。树上背包尽量从前往后更新,而不是从后往前枚举决策,这样可能可以降低常数。
顺便说一句我学会了书上背包复杂度的分析方法,摘录一下:
这样的复杂度相当于在以 \(x\) 为根的子树中枚举点对,且这两个点分别在不同儿子的子树中。因此,对于
任意点对 \(i,j\) ,当且仅当DFS到 \(lca(i,j)\) 时会被枚举到,因此时间复杂度综合为 \(O(N^2)\)。
代码:
#include<cstdio>
#include<cstring>
//#define zczc
using namespace std;
const int N=3020;
inline void read(int &wh){
wh=0;int f=1;char w=getchar();
while(w<'0'||w>'9'){if(w=='-')f=-1;w=getchar();}
while(w<='9'&&w>='0'){wh=wh*10+w-'0';w=getchar();}
wh*=f;return;
}
inline void check(int &s1,int s2){
if(s1>s2)s1=s2;return;
}
inline int min(int s1,int s2){
return s1<s2?s1:s2;
}
struct edge{
int t,v,next;
}e[N<<1];
int head[N],esum;
inline void add(int fr,int to,int val){
esum++;
e[esum].t=to;
e[esum].v=val;
e[esum].next=head[fr];
head[fr]=esum;
}
int m,n,ans=1e9;
int size[N],f[N][N],g[N][N],s[N][N];
void dfs(int wh,int fa){
size[wh]=1;f[wh][1]=g[wh][1]=s[wh][1]=0;
for(int i=head[wh],th,val;i;i=e[i].next){
th=e[i].t;val=e[i].v;
if(th==fa)continue;
dfs(th,wh);
for(int j=size[wh];j;j--){
for(int k=size[th];k;k--){
check(f[wh][j+k],f[wh][j]+f[th][k]+(val<<1));
check(g[wh][j+k],min(g[wh][j]+f[th][k]+val,f[wh][j]+g[th][k])+val);
check(s[wh][j+k],min(min(s[wh][j]+f[th][k],f[wh][j]+s[th][k])+val,g[wh][j]+g[th][k])+val);
}
}
size[wh]+=size[th];
}
check(ans,min(f[wh][n],min(g[wh][n],s[wh][n])));
return;
}
signed main(){
#ifdef zczc
freopen("in.txt","r",stdin);
#endif
memset(f,0x3f,sizeof(f));
memset(g,0x3f,sizeof(g));
memset(s,0x3f,sizeof(s));
int s1,s2,s3;
read(m);read(n);
for(int i=1;i<m;i++){
read(s1);read(s2);read(s3);
add(s1,s2,s3);
add(s2,s1,s3);
}
dfs(1,0);
printf("%d\n",ans);
return 0;
}