252. 树

题目链接

252. 树

给定一个有 \(N\) 个点(编号 \(0,1,…,N-1\))的树,每条边都有一个权值(不超过 \(1000\))。

树上两个节点 \(x\)\(y\) 之间的路径长度就是路径上各条边的权值之和。

求长度不超过 \(K\) 的路径有多少条。

输入格式

输入包含多组测试用例。

每组测试用例的第一行包含两个整数 \(N\)\(K\)

接下来 \(N-1\) 行,每行包含三个整数 \(u,v,l\),表示节点 \(u\)\(v\) 之间存在一条边,且边的权值为 \(l\)

当输入用例 \(N=0,K=0\) 时,表示输入终止,且该用例无需处理。

输出格式

每个测试用例输出一个结果。

每个结果占一行。

数据范围

\(1 \le N \le 10^4\),
\(1 \le K \le 5 \times 10^6\),
\(0 \le l \le 10^3\)

输入样例:

5 4
0 1 3
0 2 1
0 3 2
2 4 1
0 0

输出样例:

8

解题思路

点分治

点分治主要利用树的重心的一条性质:所有子树大小不超过 \(n/2\),利用该性质,\(dfs\) 的递归深度为 \(O(logn)\),利用重心点分治操作分为 \(递归+归并\)

本题利用点分治求解,先找到重心 \(x\),分为三种情况:

  1. 路径在子树中,递归处理
  2. 路径端点包含重心,求出所有这样的路径,找出满足要求的点
  3. 路径横跨子树,先找出所有的端点包含重心的路径,然后再组合,但由于组合的可能是同一棵子树的路径,需要容斥减去这部分
  • 时间复杂度:\(O(nlog^2n)\)

代码

// Problem: 树
// Contest: AcWing
// URL: https://www.acwing.com/problem/content/description/254/
// Memory Limit: 10 MB
// Time Limit: 3000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

// %%%Skyqwq
#include <bits/stdc++.h>
 
//#define int long long
#define help {cin.tie(NULL); cout.tie(NULL);}
#define pb push_back
#define fi first
#define se second
#define mkp make_pair
using namespace std;
 
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<LL, LL> PLL;
 
template <typename T> bool chkMax(T &x, T y) { return (y > x) ? x = y, 1 : 0; }
template <typename T> bool chkMin(T &x, T y) { return (y < x) ? x = y, 1 : 0; }
 
template <typename T> void inline read(T &x) {
    int f = 1; x = 0; char s = getchar();
    while (s < '0' || s > '9') { if (s == '-') f = -1; s = getchar(); }
    while (s <= '9' && s >= '0') x = x * 10 + (s ^ 48), s = getchar();
    x *= f;
}

const int N=1e4+5;
int n,k,res;
vector<PII> adj[N];
bool st[N];
int p[N],q[N];
int get_wc(int x,int fa,int sum,int &wc)
{
	if(st[x])return 0;
	int max_part=0;
	int sz=1;
	for(auto t:adj[x])
	{
		int y=t.fi;
		if(y==fa)continue;
		int y_sz=get_wc(y,x,sum,wc);
		sz+=y_sz;
		max_part=max(max_part,y_sz);
	}
	max_part=max(max_part,sum-sz);
	if(max_part<res)res=max_part,wc=x;
	return sz;
}
int get_sz(int x,int fa)
{
	if(st[x])return 0;
	int res=1;
	for(auto t:adj[x])
	{
		int y=t.fi;
		if(y!=fa)res+=get_sz(y,x);
	}
	return res;
}
void get_dist(int x,int fa,int dist,int &qt)
{
	if(st[x])return ;
	q[++qt]=dist;
	for(auto t:adj[x])
	{
		int y=t.fi,w=t.se;
		if(y==fa)continue;
		get_dist(y,x,dist+w,qt);
	}
}
int get(int a[],int n)
{
	int res=0;
	sort(a+1,a+1+n);
	for(int i=n,j=0;i>=0;i--)
	{
		while(j+1<i&&a[j+1]+a[i]<=k)j++;
		if(j>=i)j=max(0,i-1);
		res+=j;
	}
	return res;
}
int cal(int x)
{
	if(st[x])return 0;
	int ans=0;
	res=0x3f3f3f3f;
	get_wc(x,0,get_sz(x,0),x);
	st[x]=true;
	int pt=0;
	for(auto t:adj[x])
	{
		int y=t.fi,w=t.se;
		int qt=0;
		get_dist(y,x,w,qt);
		for(int i=1;i<=qt;i++)
			if(q[i]<=k)ans++,p[++pt]=q[i];
		ans-=get(q,qt);
	}
	ans+=get(p,pt);
	for(auto t:adj[x])ans+=cal(t.fi);
	return ans;
}
int main()
{
    while(cin>>n>>k,n||k)
    {
    	memset(st,0,sizeof st);
    	for(int i=1;i<=n;i++)adj[i].clear();
	    for(int i=1;i<n;i++)
	    {
	    	int u,v,l;
	    	cin>>u>>v>>l;
	    	u++,v++;
	    	adj[u].pb({v,l}),adj[v].pb({u,l});
	    }
	    cout<<cal(1)<<'\n';
    }
    return 0;
}
posted @ 2022-10-09 22:33  zyy2001  阅读(20)  评论(0编辑  收藏  举报