HDU 4616 Game 树形dp
题目链接:
http://acm.hdu.edu.cn/showproblem.php?pid=4616
Game
Memory Limit: 65535/32768 K (Java/Others)
输出
For each testcase, output the maximum total value of gifts you can get.
样例输入
2
3 1
23 0
12 0
123 1
0 2
2 1
3 2
23 0
12 0
123 1
0 2
2 1
样例输出
146
158
题意
给你一颗树,每个点有点权,同时一些点还有陷阱(到这个点获得价值后会掉到陷阱里),如果你掉到陷阱k次或无路可走会立马退出,问从任意一点出发,退出时能获得的最大价值,每个点只能走一次(不能走回头路)
题解
首先预处理出两个东西:
1、从某个叶子走到u,掉进陷阱j次的能获得的最大值,次大值。
2、从u走到某个叶子的,掉进陷阱j次能获得的最大值,次大值。
然后枚举每条边,把链拆成两块,考虑左右两块的所有组合。
#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<stack>
#include<ctime>
#include<vector>
#include<cstdio>
#include<string>
#include<bitset>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<functional>
using namespace std;
#define X first
#define Y second
#define mkp make_pair
#define lson (o<<1)
#define rson ((o<<1)|1)
#define mid (l+(r-l)/2)
#define sz() size()
#define pb(v) push_back(v)
#define all(o) (o).begin(),(o).end()
#define clr(a,v) memset(a,v,sizeof(a))
#define bug(a) cout<<#a<<" = "<<a<<endl
#define rep(i,a,b) for(int i=a;i<(b);i++)
#define scf scanf
#define prf printf
typedef long long LL;
typedef vector<int> VI;
typedef pair<int,int> PII;
typedef vector<pair<int,int> > VPII;
const int INF=0x3f3f3f3f;
const LL INFL=0x3f3f3f3f3f3f3f3fLL;
const double eps=1e-8;
const double PI = acos(-1.0);
//start----------------------------------------------------------------------
const int maxn=50030;
int n,c;
LL val[maxn];
int tra[maxn];
struct Edge {
int v,ne;
Edge(int v,int ne):v(v),ne(ne) {}
Edge() {}
} egs[maxn*2];
int head[maxn],tot;
void addEdge(int u,int v) {
egs[tot]=Edge(v,head[u]);
head[u]=tot++;
}
///dp[u][j][0]表示在子树中以u为起点,机会为j次能够获得的最大值,dp[u][j][1]为对应的次大值,id[u]记录最大值的更新方向
///dp2[u][j][0]表示在子树中以u为终点,机会为j次能够获得的最大值,dp2[u][j][1]为对应的次大值,id2[u]记录最大值的更新方向
LL dp[maxn][5][2],dp2[maxn][5][2];
int id[maxn][5],id2[maxn][5];
void dfs(int u,int fa) {
if(tra[u]) dp[u][1][0]=dp2[u][1][0]=val[u];
else dp[u][0][0]=dp2[u][0][0]=val[u];
bool child=false;
for(int p=head[u]; p!=-1; p=egs[p].ne) {
Edge& e=egs[p];
if(e.v==fa) continue;
child=true;
dfs(e.v,u);
int v=e.v;
///以u为终点
for(int i=0; i<=c; i++) {
int t=tra[u];
if(dp[u][i+t][0]<dp[v][i][0]+val[u]) {
dp[u][i+t][1]=dp[u][i+t][0];
dp[u][i+t][0]=dp[v][i][0]+val[u];
id[u][i+t]=v;
} else if(dp[u][i+t][1]<dp[v][i][0]+val[u]) {
dp[u][i+t][1]=dp[v][i][0]+val[u];
}
}
///以u为起点,这和上面的区别体现在限制为1的初始化上
///如果当前点有限制,那么如果你限制为1,明显是会直接停下来的,
///不可能再由都没有经过限制的儿子那里转移过来。
if(tra[u]) {
dp2[u][1][0]=val[u];
for(int i=2; i<=c; i++) {
if(dp2[u][i][0]<dp2[v][i-1][0]+val[u]) {
dp2[u][i][1]=dp2[u][i][0];
dp2[u][i][0]=dp2[v][i-1][0]+val[u];
id2[u][i]=v;
} else if(dp2[u][i][1]<dp2[v][i-1][0]+val[u]) {
dp2[u][i][1]=dp2[v][i-1][0]+val[u];
}
}
} else {
for(int i=0; i<=c; i++) {
if(dp2[u][i][0]<dp2[v][i][0]+val[u]) {
dp2[u][i][1]=dp2[u][i][0];
dp2[u][i][0]=dp2[v][i][0]+val[u];
id2[u][i]=v;
} else if(dp2[u][i][1]<dp2[v][i][0]+val[u]) {
dp2[u][i][1]=dp2[v][i][0]+val[u];
}
}
}
}
}
///枚举每条边,吧链拆分成两部分,利用dp和dp2来更新答案
LL ans;
void dfs2(int u,int fa) {
for(int p=head[u]; p!=-1; p=egs[p].ne) {
Edge& e=egs[p];
if(e.v==fa) continue;
dfs2(e.v,u);
int v=e.v;
LL u1,u0;
for(int i=0; i<=c; i++) {
if(id[u][i]==v) u0=dp[u][i][1];
else u0=dp[u][i][0];
if(id2[u][i]==v) u1=dp2[u][i][1];
else u1=dp2[u][i][0];
for(int j=0; j+i<=c; j++) {
if(i<c) ans=max(ans,u0+dp2[e.v][j][0]);
if(j<c) ans=max(ans,u1+dp[e.v][j][0]);
if(i+j<c) ans=max(ans,u0+dp[e.v][j][0]);
}
}
}
}
void init() {
clr(head,-1);
clr(dp,0),clr(dp2,0);
clr(id,-1),clr(id2,-1);
tot=0;
}
int main() {
int tc;
scf("%d",&tc);
while(tc--) {
scf("%d%d",&n,&c);
init();
rep(i,0,n) scf("%lld%d",&val[i],&tra[i]);
rep(i,0,n-1) {
int u,v;
scf("%d%d",&u,&v);
addEdge(u,v);
addEdge(v,u);
}
dfs(0,-1);
ans=0;
dfs2(0,-1);
prf("%lld\n",ans);
}
return 0;
}
//end-----------------------------------------------------------------------
/*
1
3 1
23 1
12 0
123 1
0 2
2 1
*/
来个精炼版的:
#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<stack>
#include<ctime>
#include<vector>
#include<cstdio>
#include<string>
#include<bitset>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<functional>
using namespace std;
#define X first
#define Y second
#define mkp make_pair
#define lson (o<<1)
#define rson ((o<<1)|1)
#define mid (l+(r-l)/2)
#define sz() size()
#define pb(v) push_back(v)
#define all(o) (o).begin(),(o).end()
#define clr(a,v) memset(a,v,sizeof(a))
#define bug(a) cout<<#a<<" = "<<a<<endl
#define rep(i,a,b) for(int i=a;i<(b);i++)
#define scf scanf
#define prf printf
typedef long long LL;
typedef vector<int> VI;
typedef pair<int,int> PII;
typedef vector<pair<int,int> > VPII;
const int INF=0x3f3f3f3f;
const LL INFL=0x3f3f3f3f3f3f3f3fLL;
const double eps=1e-8;
const double PI = acos(-1.0);
//start----------------------------------------------------------------------
const int maxn=50505;
int n,m;
LL gif[maxn];
int tra[maxn];
VI G[maxn];
///dp[u][j][0]表示以u为终点的,掉过j次陷阱的能获得的最大值
///dp[u][j][1]表示以u为起点的,掉过j次陷阱的能获得的最大值(注意,遇到第j个陷阱的时候回马上停下来,所以更新与上面的有所不同
LL dp[maxn][5][2];
LL ans;
void dfs(int u,int fa) {
clr(dp[u],0);
if(tra[u]) {
dp[u][1][0]=dp[u][1][1]=gif[u];
} else {
dp[u][0][0]=dp[u][0][1]=gif[u];
}
rep(i,0,G[u].sz()) {
int v=G[u][i];
if(v==fa) continue;
dfs(v,u);
///边搜边枚举
for(int j=0; j<=m; j++) {
for(int k=0; k+j<=m; k++) {
if(j<m) ans=max(ans,dp[u][j][0]+dp[v][k][1]);
if(k<m) ans=max(ans,dp[u][j][1]+dp[v][k][0]);
if(j+k<m) ans=max(ans,dp[u][j][0]+dp[v][k][0]);
}
}
for(int j=0; j<=m; j++) {
dp[u][j+tra[u]][0]=max(dp[u][j+tra[u]][0],dp[v][j][0]+gif[u]);
if(tra[u]&&j==0) dp[u][1][1]=gif[u];
else dp[u][j+tra[u]][1]=max(dp[u][j+tra[u]][1],dp[v][j][1]+gif[u]);
}
}
}
void init() {
for(int i=0; i<n; i++) G[i].clear();
}
int main() {
int tc;
scf("%d",&tc);
while(tc--) {
scf("%d%d",&n,&m);
init();
rep(i,0,n) scf("%lld%d",&gif[i],&tra[i]);
rep(i,0,n-1) {
int u,v;
scf("%d%d",&u,&v);
G[u].pb(v);
G[v].pb(u);
}
ans=0;
dfs(0,-1);
prf("%lld\n",ans);
}
return 0;
}
//end-----------------------------------------------------------------------
/*
2
3 1
23 1
12 1
123 0
0 2
2 1
*/