树形背包[2/ 50] luogu [P1273]
前言
在笔者做出了上一道超级板题后,开始肝起了之前gm开的链接的题。而这次这道题,其实和某某苹果树真的很像,连dp的状态都几乎一毛一样。只是需要在时间上有亿些优化。(其实之前lh大巨佬在评讲的时候讲过思路的,只是菜鸡笔者当时没有打出来)
题目
看到这道题,笔者瞬间要素察觉:诶!树!诶!传输费用已知!诶!观看用户数最大!(逃)
本着认真负责踏实严谨的治学态度(再次 正经脸 不要脸):
我们分析一下: 如果我们选择给某一个用户传送信号,即在这个点和根的这条路径上的点都要传送到,所以我们就联系到上一篇blog做到的那种方法:(不要急着那么用,会T的)
状态转移方程:\(dp[i][j] = max(dp[i][j], dp[son][k] + dp[i][j - son] - way[i][son]\)
T掉的代码
于是笔者想到这,就很快的打出了代码
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
const int maxn = 3005;
struct data{
int w, dis;
};
int n, m;
int money[maxn];
bool vis[maxn]; //因为是单向存边,这个东东好像没必要
int dp[maxn][maxn];
vector<data> way[maxn];
void dfs(int x){
vis[x] = 1;
if(x >= (n - m + 1)){
dp[x][1] = money[x];
// printf("%d\n", x);
return;
}
for(int i = 0; i < way[x].size(); i ++){
data flag2 = way[x][i];
if(!vis[flag2.w]){
dfs(flag2.w);
for(int j = n; j >= 0; j --){
for(int k = 0; k <= j; k ++){
dp[x][j] = max(dp[x][j], dp[flag2.w][k] + dp[x][j - k] - flag2.dis);
}
}
}
}
}
int main() {
memset(dp, -0x3f3f3f3f, sizeof(dp));
for(int i = 0; i < maxn; i ++){
dp[i][0] = 0;
}
scanf("%d %d", &n, &m);
for(int i = 1; i <= (n - m); i ++){
int k;
scanf("%d", &k);
for(int j = 1; j <= k; j ++){
data flag;
scanf("%d %d", &flag.w, &flag.dis);
way[i].push_back(flag);
}
}
for(int i = 1; i <= m; i ++){
scanf("%d", &money[(n - m) + i]);
}
dfs(1);
int ans = 0;
for(int i = n; i >= 0; i --){
if(dp[1][i] >= 0){
printf("%d", i);
return 0;
}
}
return 0;
}
但是这份逻辑上没有什么毛病的代码,在你谷上不开$O_2$50 pts, 开了 60 pts, OJ 上 63 pts
正解
T了后,笔者就把那些奇奇怪怪的不必要的东西去掉后,依旧没有改进,但是笔者盲猜是自己的DFS T掉了, 于是开始打量起了代码。突然灵光乍现想到了一点点优化,于是乎,就A了
先放代码吧
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
const int maxn = 3005;
struct data{
int w, dis;
};
int n, m;
int money[maxn];
int sum[maxn];
int dp[maxn][maxn];
vector<data> way[maxn];
void findsum(int x){
for(int i = 0; i < way[x].size(); i ++){
findsum(way[x][i].w);
sum[x] += sum[way[x][i].w] + 1;
}
}
void dfs(int x){
if(x >= (n - m + 1)){
dp[x][1] = money[x];
return;
}
for(int i = 0; i < way[x].size(); i ++){
data flag2 = way[x][i];
dfs(flag2.w);
for(int j = sum[x]; j >= 0; j --){
for(int k = 0; k <= j; k ++){
dp[x][j] = max(dp[x][j], dp[flag2.w][k] + dp[x][j - k] - flag2.dis);
}
}
}
}
int main() {
memset(dp, -0x3f3f3f3f, sizeof(dp));
for(int i = 0; i < maxn; i ++){
dp[i][0] = 0;
}
scanf("%d %d", &n, &m);
for(int i = 1; i <= (n - m); i ++){
int k;
scanf("%d", &k);
for(int j = 1; j <= k; j ++){
data flag;
scanf("%d %d", &flag.w, &flag.dis);
way[i].push_back(flag);
}
}
findsum(1);
for(int i = 1; i <= m; i ++){
scanf("%d", &money[(n - m) + i]);
}
dfs(1);
int ans = 0;
for(int i = n; i >= 0; i --){
if(dp[1][i] >= 0){
printf("%d", i);
return 0;
}
}
return 0;
}
其实这当中唯一的优化只有一点点
for(int j = sum[x]; j >= 0; j --)
这个\(sum\)数值是什么东西呢,就是在这个节点下面还有多少个节点(再用一个dfs遍历就好了)
在一个节点\(x\)下,最多可以传送到的也就只有\(sum[x]\)个,所以这个优化是很好证
时间对比
夜空中最亮的星,请照亮我前行