POJ1741 Tree(树分治)
题意:
求树上距离小于等于K的点对有多少个
思路:
每次分治,我们首先算出重心,为了计算重心,需要进行两次dfs,第一次把以每个结点为根的子树大小求出来,第二次是从这些结点中找重心
找到重心后,需要统计所有结点到重心的距离,看其中有多少对小于等于K
但是这些求出来满足小于等于K的里面只有那些路径经过重心的点对才是有效的,也就是说在同一颗子树上的肯定不算数的,所以对每颗子树,把子树内部的满足条件的点对减去。
/* *********************************************** Author :devil ************************************************ */ #include <cstdio> #include <cstring> #include <iostream> #include <algorithm> #include <vector> #include <queue> #include <set> #include <stack> #include <map> #include <string> #include <cmath> #include <stdlib.h> #define LL long long #define rep(i,a,b) for(int i=a;i<=b;i++) #define dep(i,a,b) for(int i=a;i>=b;i--) #define ou(a) printf("%d\n",a) #define pb push_back #define mkp make_pair template<class T>inline void rd(T &x) { char c=getchar(); x=0; while(!isdigit(c))c=getchar(); while(isdigit(c)) { x=x*10+c-'0'; c=getchar(); } } #define IN freopen("in.txt","r",stdin); #define OUT freopen("out.txt","w",stdout); using namespace std; const int inf=0x3f3f3f3f; const int mod=1e9+7; const int N=1e4+10; int n,k,x,y,w,root,ans,cnt,m; int sum[N],ma[N],deep[N],d[N]; bool vis[N]; vector<pair<int,int> >eg[N]; void dfs(int u,int fa) { sum[u]=1,ma[u]=0; for(int i=0;i<eg[u].size();i++) { int v=eg[u][i].first; if(v==fa||vis[v]) continue; dfs(v,u); sum[u]+=sum[v]; ma[u]=max(ma[u],sum[v]); } ma[u]=max(ma[u],m-sum[u]); if(ma[u]<ma[root]) root=u; } void getdeep(int u,int fa) { deep[++deep[0]]=d[u]; for(int i=0;i<eg[u].size();i++) { int v=eg[u][i].first; if(v==fa||vis[v]) continue; d[v]=d[u]+eg[u][i].second; getdeep(v,u); } } int cal(int u,int now) { d[u]=now;deep[0]=0; getdeep(u,0); sort(deep+1,deep+1+deep[0]); int ret=0,l=1,r=deep[0]; while(l<r) { if(deep[l]+deep[r]<=k) { ret=ret+r-l; l++; } else r--; } return ret; } void work(int u) { ans+=cal(u,0); vis[u]=1; for(int i=0;i<eg[u].size();i++) { int v=eg[u][i].first; if(vis[v]) continue; ans-=cal(v,eg[u][i].second); m=sum[v]; root=0; dfs(v,0); work(root); } } int main() { #ifndef ONLINE_JUDGE //IN #endif ma[0]=inf; while(~scanf("%d%d",&n,&k)&&n) { for(int i=1; i<=n; i++) eg[i].clear(); root=0;ans=0;cnt=0;m=n; memset(vis,0,sizeof(vis)); for(int i=1; i<n; i++) { scanf("%d%d%d",&x,&y,&w); eg[x].pb(mkp(y,w)); eg[y].pb(mkp(x,w)); } dfs(1,0); work(root); printf("%d\n",ans); } return 0; }