719D(树形dp)

题目链接:http://codeforces.com/contest/791/problem/D

 

题意:给出一棵树,每两个点之间的距离为1,一步最多可以走距离 k,问要将任意两个点之间的路径都走一遍,最少需要走多少步;

 

思路:对于不是很简单的问题我们可以将问题分解成若干步或许会简单一点,对于本题我们可以先考虑只求所有路径的距离之和, 假设我们求得其值为ans;不过因为有些路径的长度并不是m的整倍数,所以我们不能直接用ans/m得到答案;不过如果我们能找到所有不是m的整倍数的路径并且可以求出其%m的值的话,那么我们可以给其补上一个尽量小的数并使其能整除m, 那么我们只要将所有补加的数累加到ans里面去,那么ans/m就是我们要的答案啦。。。

接下来我们需要考虑一下具体如何实现上面两步:

对于如何求得ans,我们可以计算出对于每条边经过他的路径数,就是这条边对ans的贡献值,那么累加对于每条边经过她的路径数就是ans啦;每条边都是由相邻的两个节点构成的,我们可以将两个节点看做父子节点,那么经过这条边的路径的数目为son*(n-son),其中son为子节点所在的子树的大小,那么n-son就是子树外的节点数目(这个应该挺好理解的,不理解的画下图就明白了);显然我们只要dfs搜一遍就能得到ans啦;

 

下面我们只要求出长度%m不为整数的路径就ok了。我们不防这样想,从某点出发的所有路径中任选两条可以组成一条经过该点的路径,那么所有组合即为所有经过该点的路径。

我们用dp[i][j]记录从点 i 出发%m为 j 的路径的数目,那么我们可以同过 j 的组合得到经过点 i 长度%m=j'的路径数目,显然只要求出dp[i][j]我们很容易得到补加的的值是多少。

若对于当前节点i, 我们已知dp[i][j],那么显然对于其父节点有 dp[i'][(j+1)] = dp[i][j],所以我们我们可以在dfs回溯时通过dp计算出dp[i][j]的值;对叶子节点初始化为dp[i][0]=1;

至此已经圆满解决这个问题啦。。

 

代码:

 1 #include <iostream>
 2 #include <stdio.h>
 3 #include <vector>
 4 #define ll long long
 5 using namespace std;
 6 
 7 const int MAXN=2e5+10;
 8 int n, m;
 9 bool vis[MAXN];//标记当前节点是否搜过
10 vector<int> mp[MAXN];
11 ll dp[MAXN][6], son[MAXN], ans=0;//dp[i][j]存储以i为根节点,i的子树中距离i长度mod k==j的的路径的条数,son[i]记录i的子树大小
12 
13 void dfs(int point){
14     son[point]=1;//相当于将两个数组初始化为 1
15     dp[point][0]=1;
16     for(int i=0; i<mp[point].size(); i++){
17         int v=mp[point][i];
18         if(!vis[v]){
19             vis[v]=true;
20             dfs(v);
21             son[point]+=son[v];//将子树中节点的数目加到当前节点上
22             ans+=(son[v])*(ll)(n-son[v]);//统计经过边[point,i]的路径数目
23             for(int j=0; j<m; j++){
24                 for(int k=0; k<m; k++){
25                     if((j+k+1)%m){
26                         ans+=dp[point][j]*dp[v][k]*(ll)(m-(j+k+1)%m);//i+j+k为分别由点point,及一个其子节点引出的路径长度%m再求和
27                     }
28                 }
29             }
30             for(int j=0; j<m; j++){
31                 dp[point][(j+1)%m]+=dp[v][j];//从节点mp[point][i]回溯到其父节点,那么由原来mp[point][i]的子树到其的距离%m=j的路径数目转移得到point的子树到其距离+1%m=j的路径数目
32             }
33         }
34     }
35 }
36 
37 int main(void){
38     scanf("%d%d", &n, &m);
39     for(int i=1; i<n; i++){
40         int x, y;
41         scanf("%d%d", &x, &y);
42         mp[x].push_back(y);
43         mp[y].push_back(x);
44     }
45     vis[1]=true;
46     dfs(1);
47     printf("%lld\n", ans/m);
48     return 0;
49 }
View Code

 

posted @ 2017-03-20 20:46  geloutingyu  阅读(212)  评论(0编辑  收藏  举报