淀粉质模板 Tree

Tree

题目描述

给你一棵TREE,以及这棵树上边的距离.问有多少对点它们两者间的距离小于等于K

输入输出格式

输入格式:

N(n<=40000) 接下来n-1行边描述管道,按照题目中写的输入 接下来是k

输出格式:

一行,有多少对点之间的距离小于等于k


淀粉质感觉怎么写都不好看啊,迷。。

实现方法非常多。

大概思路:
对每一个子树的二层子节点进行遍历,处理每个点所属的二层子节点和到根节点的距离

以到根节点的距离为关键字排序,从两边进行扫描

如果当前满足,答案就加上\(r-l-\)\(l\)属于同一颗二层子节点的数的数量,后者可以直接拿一个桶边扫描边维护

这个每次可以统计答案的区间是逐渐缩小的,有单调性。每次统计时候的意义是对\(l\)位置的节点,有多个点可以跨过根和它配对。

然后递归处理子树的答案。注意每次选择重心作为根节点保证复杂度。


Code:

#include <cstdio>
#include <algorithm>
const int N=4e4+10;
const int inf=0x3f3f3f3f;
int head[N],to[N<<1],Next[N<<1],edge[N<<1],cnt;
void add(int u,int v,int w)
{
    to[++cnt]=v,Next[cnt]=head[u],edge[cnt]=w,head[u]=cnt;
}
struct node
{
    int b,d;
    node(){}
    node(int b,int d){this->b=b,this->d=d;}
    bool friend operator <(node n1,node n2){return n1.d<n2.d;}
}a[N];
int mi,id,ans,n,m,k,coun[N],siz[N],del[N];
int max(int x,int y){return x>y?x:y;}
void get_g(int now,int fa,int su)
{
    siz[now]=1;int mx=0;
    for(int i=head[now];i;i=Next[i])
    {
        int v=to[i];
        if(!del[v]&&v!=fa)
        {
            get_g(v,now,su);
            mx=max(mx,siz[v]);
            siz[now]+=siz[v];
        }
    }
    mx=max(mx,su-siz[now]);
    if(mx<mi) mi=mx,id=now;
}
void dfs(int now,int fa,int anc,int dis)
{
    a[++cnt]=node(anc,dis);
    siz[now]=1;
    for(int i=head[now];i;i=Next[i])
    {
        int v=to[i];
        if(!del[v]&&v!=fa)
            dfs(v,now,anc,dis+edge[i]),siz[now]+=siz[v];
    }
}
void divide(int now,int su)
{
    mi=inf,cnt=0;
    get_g(now,0,su);
    now=id;del[now]=1;
    a[++cnt]=node(now,0),coun[now]=1;
    for(int i=head[now];i;i=Next[i])
    {
        int v=to[i];
        if(!del[v])
            dfs(v,now,v,edge[i]),coun[v]=siz[v];
    }
    std::sort(a+1,a+1+cnt);
    int l=1,r=cnt;
    while(l<r)
    {
        while(l<r&&a[r].d+a[l].d>k) --coun[a[r--].b];
        if(a[r].d+a[l].d<=k) ans+=r-l-coun[a[l].b]+1;
        --coun[a[l++].b];
    }
    for(int i=head[now];i;i=Next[i])
    {
        int v=to[i];
        if(!del[v])
            divide(v,siz[v]);
    }
}
int main()
{
    scanf("%d",&n);
    for(int u,v,w,i=1;i<n;i++)
    {
        scanf("%d%d%d",&u,&v,&w);
        add(u,v,w),add(v,u,w);
    }
    scanf("%d",&k);
    divide(1,n);
    printf("%d\n",ans);
    return 0;
}


2018.9.15

posted @ 2018-09-15 15:02  露迭月  阅读(196)  评论(0编辑  收藏  举报