【树形DP】codeforces K. Send the Fool Further! (medium)
http://codeforces.com/contest/802/problem/K
【题意】
给定一棵树,Heidi从根结点0出发沿着边走,每个结点最多经过k次,求这棵树的最大花费是多少(同一条边走n次花费只算一次)
【思路】
对于结点v:
- 如果在v的某棵子树停下,那么可以“遍历”k棵子树(有的话)
- 如果还要沿着v返回v的父节点p,那么只能“遍历”k-1棵子树(有的话)。
用dp[v][1]表示第一种情况,dp[v][0]表示第二种情况;最后要求的就是dp[0][0]。
1. 对于dp[v][1],把所有的子树从大到小排序
(t=k-1)
2. 对于dp[v][0],枚举子结点dp[u][0]中的u,剩下的k-1个dp[u][1]取最大的,所以我们可以这样预处理:
sum=
(t=k)
- 如果u<k,则target=sum-dp[u][1]+dp[u][0]
- 否则, target=sum-dp[t][1]+dp[u][0](t是从大到小排序后的第k-1个)
这样,dp[0][0]就是所求结果(dp[0][0]一定大于dp[0][1]),时间复杂度是O(nlogn)
【官方题解】
【Accepted】
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
1 #include<iostream> 2 #include<cstdio> 3 #include<cstring> 4 #include<string> 5 #include<cmath> 6 #include<vector> 7 #include<algorithm> 8 9 using namespace std; 10 int n,m; 11 vector< vector< pair<int,int> > > g; 12 const int maxn=1e5+5; 13 int dp[maxn][2]; 14 void dfs(int v,int p,int edge) 15 { 16 //从p到v的花费要算在v里 17 dp[v][0]+=edge; 18 dp[v][1]+=edge; 19 vector< pair<int,int> > s; 20 //只有根结点没有父节点,非根结点有父节点,减去1 21 if(v==0) 22 { 23 s.resize(g[v].size()); 24 } 25 else 26 { 27 s.resize(g[v].size()-1); 28 } 29 //遍历 30 int num=0; 31 for(int i=0;i<g[v].size();i++) 32 { 33 int to=g[v][i].first; 34 if(to==p) 35 { 36 continue; 37 } 38 dfs(to,v,g[v][i].second); 39 s[num++]={dp[to][1],to}; 40 } 41 //从大到小排序 42 sort(s.begin(),s.end()); 43 reverse(s.begin(),s.end()); 44 //要记录各个子结点的rank,后面dp[v][0]枚举u是要分类 45 int pos[maxn]; 46 for(int i=0;i<s.size();i++) 47 { 48 pos[s[i].second]=i; 49 } 50 //计算dp[v][1] 51 for(int i=0;i<min(m-1,(int)s.size());i++) 52 { 53 dp[v][1]+=s[i].first; 54 } 55 //计算dp[v][0] 56 int sum=0; 57 for(int i=0;i<min(m,(int)s.size());i++) 58 { 59 sum+=s[i].first; 60 } 61 int maxu=-1; 62 //枚举 63 for(int i=0;i<g[v].size();i++) 64 { 65 int to=g[v][i].first; 66 if(to==p) 67 { 68 continue; 69 } 70 if(pos[to]<m) 71 { 72 maxu=max(maxu,sum-dp[to][1]+dp[to][0]); 73 } 74 else 75 { 76 maxu=max(maxu,sum-s[m-1].first+dp[to][0]); 77 } 78 } 79 if(maxu>-1) 80 { 81 dp[v][0]+=maxu; 82 } 83 } 84 int main() 85 { 86 while(~scanf("%d%d",&n,&m)) 87 { 88 memset(dp,0,sizeof(dp)); 89 g.resize(n); 90 int u,v,c; 91 for(int i=0;i<n-1;i++) 92 { 93 scanf("%d%d%d",&u,&v,&c); 94 g[u].push_back({v,c}); 95 g[v].push_back({u,c}); 96 } 97 //根结点为0,无父结点,根结点到父结点的花费也为0 98 dfs(0,0,0); 99 printf("%d\n",dp[0][0]); 100 } 101 return 0; 102 }
注意vector开始要resize.....orz
【WA】
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
1 #include<iostream> 2 #include<cstdio> 3 #include<cstring> 4 #include<string> 5 #include<algorithm> 6 #include<cmath> 7 8 using namespace std; 9 int n,k; 10 const int maxn=2e5+3; 11 struct edge 12 { 13 int to; 14 int nxt; 15 int c; 16 }e[maxn]; 17 int head[maxn]; 18 int tot; 19 struct node 20 { 21 int x; 22 int id; 23 }sz[maxn]; 24 int rk[maxn]; 25 bool cmp(node a,node b) 26 { 27 return a.x>b.x; 28 } 29 void init() 30 { 31 memset(head,-1,sizeof(head)); 32 tot=0; 33 } 34 35 void add(int u,int v,int c) 36 { 37 e[tot].to=v; 38 e[tot].c=c; 39 e[tot].nxt=head[u]; 40 head[u]=tot++; 41 } 42 int dp[maxn][2]; 43 44 int dfs(int u,int pa,int c) 45 { 46 dp[u][1]=c; 47 dp[u][0]=c; 48 int cnt=0; 49 for(int i=head[u];i!=-1;i=e[i].nxt) 50 { 51 int v=e[i].to; 52 int c=e[i].c; 53 if(v==pa) continue; 54 dfs(v,u,c); 55 sz[cnt].x=dp[v][1]; 56 sz[cnt++].id=v; 57 } 58 sort(sz,sz+cnt,cmp); 59 for(int i=0;i<min(cnt,k-1);i++) 60 { 61 dp[u][1]+=sz[i].x; 62 } 63 int sum=0; 64 for(int i=0;i<min(cnt,k);i++) 65 { 66 sum+=sz[i].x; 67 } 68 int ans=0; 69 for(int i=0;i<cnt;i++) 70 { 71 if(i<k) 72 { 73 ans=max(ans,sum-sz[i].x+dp[sz[i].id][0]); 74 } 75 else 76 { 77 ans=max(ans,sum-sz[k-1].x+dp[sz[i].id][0]); 78 } 79 } 80 dp[u][0]+=ans; 81 } 82 int main() 83 { 84 while(~scanf("%d%d",&n,&k)) 85 { 86 init(); 87 memset(dp,0,sizeof(dp)); 88 for(int i=0;i<n-1;i++) 89 { 90 int u,v,c; 91 scanf("%d%d%d",&u,&v,&c); 92 add(u,v,c); 93 add(v,u,c); 94 } 95 dfs(0,-1,0); 96 cout<<dp[0][0]<<endl; 97 } 98 return 0; 99 }
终于弄清楚了这个为什么WA!因为我在dfs里用了一个全局变量sz来保存{dp[v][1],v}。然而这是一个全局变量,所以一层里的正确值会被另一层修改!比如当我递归到0时已经有了正确值sz[0].w=5,sz[0].v=2;然而再递归到0的另一分枝1的时候,会修改sz[0],最后再回溯到0时sz[0]已经不是当年的sz[0]了!
所以还是用vector临时申请吧!
【AC(一个更优美的代码)】
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
1 #include<iostream> 2 #include<cstdio> 3 #include<cstring> 4 #include<string> 5 #include<algorithm> 6 #include<cmath> 7 8 using namespace std; 9 int n,k; 10 const int maxn=2e5+3; 11 struct edge 12 { 13 int to; 14 int nxt; 15 int c; 16 }e[maxn]; 17 int head[maxn]; 18 int tot; 19 int dp[maxn][2]; 20 21 struct node 22 { 23 int x; 24 int id; 25 node(){} 26 node(int _x,int _id):x(_x),id(_id){} 27 bool operator<(const node & nd) const 28 { 29 return x>nd.x; 30 } 31 }; 32 33 void init() 34 { 35 memset(head,-1,sizeof(head)); 36 tot=0; 37 } 38 39 void add(int u,int v,int c) 40 { 41 e[tot].to=v; 42 e[tot].c=c; 43 e[tot].nxt=head[u]; 44 head[u]=tot++; 45 } 46 47 int dfs(int u,int pa,int c) 48 { 49 dp[u][1]=c; 50 dp[u][0]=c; 51 vector<node> s; 52 for(int i=head[u];i!=-1;i=e[i].nxt) 53 { 54 int v=e[i].to; 55 int c=e[i].c; 56 if(v==pa) continue; 57 dfs(v,u,c); 58 s.push_back(node(dp[v][1],v)); 59 } 60 sort(s.begin(),s.end()); 61 int sz=s.size(); 62 for(int i=0;i<min(sz,k-1);i++) 63 { 64 dp[u][1]+=s[i].x; 65 } 66 int sum=0; 67 for(int i=0;i<min(sz,k);i++) 68 { 69 sum+=s[i].x; 70 } 71 int ans=0; 72 for(int i=0;i<sz;i++) 73 { 74 if(i<k) 75 { 76 ans=max(ans,sum-s[i].x+dp[s[i].id][0]); 77 } 78 else 79 { 80 ans=max(ans,sum-s[k-1].x+dp[s[i].id][0]); 81 } 82 } 83 dp[u][0]+=ans; 84 } 85 int main() 86 { 87 while(~scanf("%d%d",&n,&k)) 88 { 89 init(); 90 memset(dp,0,sizeof(dp)); 91 for(int i=0;i<n-1;i++) 92 { 93 int u,v,c; 94 scanf("%d%d%d",&u,&v,&c); 95 add(u,v,c); 96 add(v,u,c); 97 } 98 dfs(0,-1,0); 99 cout<<dp[0][0]<<endl; 100 } 101 return 0; 102 }
如果是vector<pair<int,int>> 要从大到小排序,可以先sort(s.begin(),s.end()),再reverse(s.begin(),s.end())