树上距离(树形DP)

问题 A: 树上距离

懒惰的温温今天上班也在偷懒。盯着窗外发呆的温温发现,透过窗户正巧能看到一棵n个节点的树。一棵n个节点的树包含n-1条边,且n个节点是联通的。树上两点之间的距离即两点之间的最短路径包含的边数。
突发奇想的温温想要知道,树上有多少个不同的点对,满足两点之间的距离恰好等于k。
注意:(u, v)和(v, u)视作同一个点对,只计算一次答案。

输入

第一行两个整数n和k。
接下来n-1行每行两个整数ai, bi,表示节点ai和bi之间存在一条边。
1 ≤ k ≤ 500
2 ≤ n ≤ 500 for 40%
2 ≤ n ≤ 50000 for 100%

输出

输出一个整数,表示满足条件的点对数量。

样例输入

[样例1]
5 2
1 2
2 3
3 4
2 5
[样例2]
5 3
1 2
2 3
3 4
4 5

样例输出

[样例1]
4
[样例2]
2

思路:

简单树形DP,\(dp[u][i]\)记录\(u\)的子树中到\(u\)距离为\(i\)的路径数。统计答案时直接计算经过\(u \to v\)且长度为\(k\)的路径条数即可。
转移很显然:
\(dp[u][i]=\sum_{fa[v]=u}dp[v][i-1]\)
时间复杂度\(O(nk)\).

#include "iostream"
#include "stdio.h"
#include "string.h"
#include "algorithm"
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
const int N=1e6+5;
const ll mod=998244353;
const double eps=1e-5;
//const double pi=acos(-1);
 
#define ls p<<1
#define rs p<<1|1
int dp[50005][505];
ll ans=0;
int n,k;
vector<int>g[N];
void dfs(int u,int fa)
{
    dp[u][0]=1;
    for(auto v:g[u])
    {
        if(v==fa) continue;
        dfs(v,u);
        for(int i=0;i<k;i++)
            ans+=1ll*dp[u][i]*dp[v][k-i-1];
//        for(int i=1;i<=k;i++) f[v][i]+=dp[u][i-1];
        for(int i=1;i<=k;i++) dp[u][i]+=dp[v][i-1];
    }
}
int main()
{
#ifndef ONLINE_JUDGE
    freopen("in.txt", "r", stdin);
#endif
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin>>n>>k;
    for(int i=1;i<n;i++)
    {
        int u,v;
        cin>>u>>v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(1,0);
//    for(int i=1;i<=n;i++) ans+=dp[i][k]+f[i][k];
    cout<<ans<<endl;
    return 0;
}
posted @ 2020-04-14 16:16  Suiyue_Li  阅读(397)  评论(0编辑  收藏  举报