树上距离(树形DP)
问题 A: 树上距离
懒惰的温温今天上班也在偷懒。盯着窗外发呆的温温发现,透过窗户正巧能看到一棵n个节点的树。一棵n个节点的树包含n-1条边,且n个节点是联通的。树上两点之间的距离即两点之间的最短路径包含的边数。
突发奇想的温温想要知道,树上有多少个不同的点对,满足两点之间的距离恰好等于k。
注意:(u, v)和(v, u)视作同一个点对,只计算一次答案。
输入
第一行两个整数n和k。
接下来n-1行每行两个整数ai, bi,表示节点ai和bi之间存在一条边。
1 ≤ k ≤ 500
2 ≤ n ≤ 500 for 40%
2 ≤ n ≤ 50000 for 100%
输出
输出一个整数,表示满足条件的点对数量。
样例输入
[样例1]
5 2
1 2
2 3
3 4
2 5
[样例2]
5 3
1 2
2 3
3 4
4 5
样例输出
[样例1]
4
[样例2]
2
思路:
简单树形DP,\(dp[u][i]\)记录\(u\)的子树中到\(u\)距离为\(i\)的路径数。统计答案时直接计算经过\(u \to v\)且长度为\(k\)的路径条数即可。
转移很显然:
\(dp[u][i]=\sum_{fa[v]=u}dp[v][i-1]\)
时间复杂度\(O(nk)\).
#include "iostream"
#include "stdio.h"
#include "string.h"
#include "algorithm"
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
const int N=1e6+5;
const ll mod=998244353;
const double eps=1e-5;
//const double pi=acos(-1);
#define ls p<<1
#define rs p<<1|1
int dp[50005][505];
ll ans=0;
int n,k;
vector<int>g[N];
void dfs(int u,int fa)
{
dp[u][0]=1;
for(auto v:g[u])
{
if(v==fa) continue;
dfs(v,u);
for(int i=0;i<k;i++)
ans+=1ll*dp[u][i]*dp[v][k-i-1];
// for(int i=1;i<=k;i++) f[v][i]+=dp[u][i-1];
for(int i=1;i<=k;i++) dp[u][i]+=dp[v][i-1];
}
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("in.txt", "r", stdin);
#endif
ios::sync_with_stdio(false);
cin.tie(0);
cin>>n>>k;
for(int i=1;i<n;i++)
{
int u,v;
cin>>u>>v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1,0);
// for(int i=1;i<=n;i++) ans+=dp[i][k]+f[i][k];
cout<<ans<<endl;
return 0;
}