161D - Distance in Tree
链接
https://codeforces.com/problemset/problem/161/D
题目
思路
点分治的板子。但是得改改。 改的地方就是增加一个桶,然后相和为k的两个数量乘一下。 主要还是理解点分治代码为主代码
#define _CRT_SECURE_NO_WARNINGS
#include<iostream>
#include<vector>
#include<algorithm>
#include<math.h>
#include<sstream>
#include<string>
#include<string.h>
#include<iomanip>
#include<stdlib.h>
#include<map>
#include<queue>
#include<limits.h>
#include<climits>
#include<fstream>
#include<stack>
#define IOS ios::sync_with_stdio(false), cin.tie(0) ,cout.tie(0)
using namespace std;
#define int long long
const int N = 5e4 + 10;
int n, k, root, cnte, cntd, ans;
//cnte:edge的cnt计数器
int mxp[N], vis[N], sz[N], hd[N], dis[N];
//mxp:子树最大节点数量,判断重心,sz:总结点;hd:链式前向星
struct Edge { int to, next,val; }edge[N<<1];
void addedge(int u, int v, int w)
{
cnte++;
edge[cnte].to = v;
edge[cnte].val = w;
edge[cnte].next = hd[u];
hd[u] = cnte;
}
void getroot(int u, int father, int n_part)//n_part与总点数n区分,
{
sz[u] = 1;//sz:数量,子树的节点数量
mxp[u] = 0;
for (int i = hd[u]; i; i = edge[i].next)
{
int v = edge[i].to;
if (v == father or vis[v])continue;
getroot(v, u, n_part);
sz[u] += sz[v];
mxp[u] = max(mxp[u], sz[v]);
}
mxp[u] = max(mxp[u], n_part - sz[u]);
if (mxp[u] < mxp[root])root = u;//重心代码
}
void getdis(int u, int d, int father)//逐个判断子树上的节点到目前根节点的距离(根节点是树的重心)
{
dis[++cntd] = d;
for (int i = hd[u]; i; i = edge[i].next)
{
int v = edge[i].to; int w = edge[i].val;
if (v == father or vis[v])continue;
getdis(v, d + w, u);
}
}
int calc(int u, int d)
{
cntd = 0;
getdis(u, d, 0);
map<int, int>tms;
for (int i = 1; i <= cntd; i++)tms[dis[i]]++;
int sum = 0;
for (map<int, int>::iterator it = tms.begin(); it != tms.end(); ++it)
{
if (it->first * 2 != k)
sum += tms[it->first] * tms[k - it->first];
else sum += tms[it->first] * (tms[it->first] - 1) / 2;
tms[it->first] = 0;
}
return sum;
}
void solve(int u)//solve给的是根节点
{
ans += calc(u, 0); //
vis[u] = 1;
for (int i = hd[u]; i; i = edge[i].next)
{
int v = edge[i].to, w = edge[i].val;
if (vis[v])continue;
ans -= calc(v, w);//去重,去掉同一边树中的重复路径,如图一
root = 0;
mxp[0] = 1e9;
getroot(v, 0, sz[v]);
solve(root);
}
}
signed main()
{
IOS;
cin >> n;cin >> k;
for (int i = 1; i < n; i++)
{
int u, v; cin >> u >> v;
addedge(u, v, 1); addedge(v, u, 1);
}
root = 0;
mxp[0] = 1e9;
getroot(1, 0, n);
solve(root);
cout << ans;
return 0;
}