[ioi2005]river
题目描述
几乎整个Byteland王国都被森林和河流所覆盖。小点的河汇聚到一起,形成了稍大点的河。就这样,所有的河水都汇聚并流进了一条大河,最后这条大河流进了大海。这条大河的入海口处有一个村庄——名叫Bytetown 在Byteland国,有n个伐木的村庄,这些村庄都座落在河边。目前在Bytetown,有一个巨大的伐木场,它处理着全国砍下的所有木料。木料被砍下后,顺着河流而被运到Bytetown的伐木场。Byteland的国王决定,为了减少运输木料的费用,再额外地建造k个伐木场。这k个伐木场将被建在其他村庄里。这些伐木场建造后,木料就不用都被送到Bytetown了,它们可以在 运输过程中第一个碰到的新伐木场被处理。显然,如果伐木场座落的那个村子就不用再付运送木料的费用了。它们可以直接被本村的伐木场处理。 注意:所有的河流都不会分叉,也就是说,每一个村子,顺流而下都只有一条路——到bytetown。 国王的大臣计算出了每个村子每年要产多少木料,你的任务是决定在哪些村子建设伐木场能获得最小的运费。其中运费的计算方法为:每一块木料每千米1分钱。 编一个程序: 1.从文件读入村子的个数,另外要建设的伐木场的数目,每年每个村子产的木料的块数以及河流的描述。 2.计算最小的运费并输出。
输入格式
第一行 包括两个数 n(2<=n<=100),k(1<=k<=50,且 k<=n)。n为村庄数,k为要建的伐木场的数目。除了bytetown外,每个村子依次被命名为1,2,3……n,bytetown被命名为0。 接下来n行,每行包涵3个整数 wi——每年i村子产的木料的块数 (0<=wi<=10000) vi——离i村子下游最近的村子(或bytetown)(0<=vi<=n) di——vi到i的距离(km)。(1<=di<=10000) 保证每年所有的木料流到bytetown的运费不超过2000,000,000分 50%的数据中n不超过20。
输出格式
输出最小花费,精确到分。
你只要记住,如果题目中给你的是静态的数(没有变动的数据),那么这棵树上一般都有最优值的传递性(至少我还没看到过反例)。所以我们可以用动态规划来解这道题。
很容易从题中提取出两个信息:一是树中的点存在选择代价,为1;二是我们可以得出选择节点的价值。而数据范围是允许我们用N^3的算法过的,所以我们可以考虑用树上背包求解。
根据经验和对问题的分析,我们很容易设计出一种状态:dp(i,j,1/0)表示以i为根的子树中建j个伐木场,0表示i处不建伐木场,1表示建。然后我们考虑怎么转移。
显然要转移的话我们首先要计算出代价和价值。代价不需要计算可以得到,而价值却不怎么好算。原因是:不同的i的祖先节点中的伐木场建造方法会造成不同的价值。但我们可以发现i的价值只与最近的建了伐木场的祖先有关。所以设这个祖先为anc,设dis[x]表示x到根节点的距离,则我们可以表示出这个价值:
\[Value=\left\{\begin{array}{rcl}
0 & 如果i建伐木场\\
w[i]*(dis[i]-dis[anc]) & 如果i不建伐木场
\end{array}\right.
\]
所以我们可以在循环上再加一维,枚举i的所有祖先作为最近的建造伐木场的祖先,并且把状态也加一维:dp(i,k,j,0/1),其中k表示最近的建造了伐木场的祖先,其余不变。
然后按照正常的树上背包问题解法一样做就可以了,求出最小值作为答案,目标状态为min(dp(0,0,k,0),dp(0,0,k,1))。
时间复杂度为O(N^2 * K^2)。
* 为了枚举祖先,可以用栈来存下遍历到的点,当回溯时将节点出栈,这样的话每次栈中节点都是当前节点的祖先。
* 一开始写完代码时觉得代码好丑,参考了网上的一些代码之后决定把dp的0/1状态合并起来,这样代码看着顺眼得多。其实就是码风的小细节,无需过多在意。
* 不开long long见祖宗。
#include<iostream>
#include<cstring>
#include<cstdio>
#define maxn 111
using namespace std;
struct edge{
int to,dis,next;
edge(){}
edge(const int &_to,const int &_dis,const int &_next){
to=_to,dis=_dis,next=_next;
}
}e[maxn];
int head[maxn],k;
long long dp[maxn][maxn][maxn][2];
int stack[maxn],top;
int n,m,val[maxn],dis[maxn];
inline int read(){
register int x(0),f(1); register char c(getchar());
while(c<'0'||'9'<c){ if(c=='-') f=-1; c=getchar(); }
while('0'<=c&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=getchar();
return x*f;
}
inline void add(const int &u,const int &v,const int &w){ e[k]=edge(v,w,head[u]),head[u]=k++; }
void dfs(int u){
stack[++top]=u;
for(register int i=head[u];~i;i=e[i].next){
int v=e[i].to;
dis[v]=dis[u]+e[i].dis;
dfs(v);
for(register int i=1;i<=top;i++){
for(register int j=m;j>=0;j--){
//初始化
dp[u][stack[i]][j][0]+=dp[v][stack[i]][0][0];
dp[u][stack[i]][j][1]+=dp[v][u][0][0];
//全部都加0状态的原因是之前把0/1的状态和并了
for(register int k=0;k<=j;k++){
dp[u][stack[i]][j][0]=min(dp[u][stack[i]][j][0],dp[u][stack[i]][j-k][0]+dp[v][stack[i]][k][0]);
dp[u][stack[i]][j][1]=min(dp[u][stack[i]][j][1],dp[u][stack[i]][j-k][1]+dp[v][u][k][0]);
}
}
}
}
//把0/1的状态合并起来
for(register int i=1;i<=top;i++){
for(register int j=0;j<=m;j++){
if(j>0) dp[u][stack[i]][j][0]=min(dp[u][stack[i]][j][0]+val[u]*(dis[u]-dis[stack[i]]),dp[u][stack[i]][j-1][1]);
else dp[u][stack[i]][j][0]+=val[u]*(dis[u]-dis[stack[i]]);
}
}
top--;
}
int main(){
memset(head,-1,sizeof head);
n=read(),m=read();
for(register int i=1;i<=n;i++){
val[i]=read();
int v=read(),w=read();
add(v,i,w);
}
dfs(0);
printf("%lld\n",dp[0][0][m][0]);
return 0;
}