BZOJ3784树上的路径

题目描述

给定一个N个结点的树,结点用正整数1..N编号。每条边有一个正整数权值。用d(a,b)表示从结点a到结点b路边上经过边的权值。其中要求a<b.将这n*(n-1)/2个距离从大到小排序,输出前M个距离值。

题解

把每次点分治时的dfs序写下来,假设我们在一个位置找能够和它拼成一条链的另一个位置,可以发现那些位置的顺序在dfs序上构成了一段连续区间,用ST表+堆维护。

注意在进队列之前先内啥一下。

代码

#include<iostream>
#include<cstdio>
#include<queue>
#include<cmath>
#define N 50002
#define M 16
using namespace std;
int tot,head[N],lo[N*M],st[M][N*M],size[N],dp[N],sum,now,deep[N],root,n,p[M][N*M];
bool vis[N];
int start,ed;
inline int rd(){
    int x=0;char c=getchar();bool f=0;
    while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    return f?-x:x;
}
struct edge{int n,to,l;}e[N<<1];
inline void add(int u,int v,int l){e[++tot].n=head[u];e[tot].to=v;head[u]=tot;e[tot].l=l;}
struct node{
    int now,l,r,sum;
    node(int nownum=0,int num1=0,int num2=0){
        now=nownum;l=num1;r=num2;
        int loo=lo[r-l+1];
        sum=now+max(st[loo][l],st[loo][r-(1<<loo)+1]);
    }
    int calc(){
        int loo=lo[r-l+1];
        if(st[loo][l]>=st[loo][r-(1<<loo)+1])return p[loo][l];else return p[loo][r-(1<<loo)+1];
    }
    bool operator <(const node &b)const{return sum<b.sum;}
}pa[N*M];
priority_queue<node>q;
void getroot(int u,int fa){
    size[u]=1;dp[u]=0;
    for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa&&!vis[e[i].to]){
        int v=e[i].to;
        getroot(v,u);
        size[u]+=size[v];dp[u]=max(dp[u],size[v]);
    }
    dp[u]=max(dp[u],sum-size[u]);
    if(dp[u]<dp[root])root=u;
}
void getsize(int u,int fa){
    size[u]=1;
    for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa&&!vis[e[i].to]){
        int v=e[i].to;
        getsize(v,u);
        size[u]+=size[v];
    }
}
void getdeep(int u,int fa){
    st[0][++now]=deep[u];p[0][now]=now;
    pa[now]=node(deep[u],start,ed);
    for(int i=head[u];i;i=e[i].n)if(!vis[e[i].to]&&e[i].to!=fa){
        int v=e[i].to;
        deep[v]=deep[u]+e[i].l;
        getdeep(v,u);
    }
}
inline void calc(int u){
    st[0][++now]=0;p[0][now]=now;
    pa[now]=node{0,now,now};
    start=now;ed=now;
    for(int i=head[u];i;i=e[i].n)if(!vis[e[i].to]){
        int v=e[i].to;
        deep[v]=e[i].l;
        getdeep(v,u);
        ed=now;
    }
}
void solve(int u){
    calc(u);vis[u]=1;
    for(int i=head[u];i;i=e[i].n)if(!vis[e[i].to]){
        int v=e[i].to;
        root=n+1;sum=size[v];
        getroot(v,u);getsize(root,0);
        solve(root);
    }
}
int main(){
    n=rd();int k=rd();int u,v,w;
    for(int i=1;i<n;++i){
        u=rd();v=rd();w=rd();
        add(u,v,w);add(v,u,w); 
    }
    dp[root=n+1]=n+1;sum=n;
    getroot(1,0);getsize(root,0);
    solve(root);
    for(int i=1;(1<<i)<=now&&i<M;++i)
      for(int j=1;j+(1<<i)-1<=now;++j)
        st[i][j]=max(st[i-1][j],st[i-1][j+(1<<i-1)]),p[i][j]=st[i-1][j]>=st[i-1][j+(1<<i-1)]?p[i-1][j]:p[i-1][j+(1<<i-1)];
    for(int i=2;i<=now;++i)lo[i]=lo[i>>1]+1;
    for(int i=1;i<=now;++i)pa[i]=node(pa[i].now,pa[i].l,pa[i].r),q.push(pa[i]);///care !!!!
    for(int i=1;i<=k;++i){
        node x=q.top();q.pop();
        printf("%d\n",x.sum);
        int mid=x.calc();
        if(x.l<mid)q.push(node(x.now,x.l,mid-1));
        if(x.r>mid)q.push(node(x.now,mid+1,x.r));
    }
    return 0;
} 
posted @ 2019-02-24 17:27  comld  阅读(262)  评论(0编辑  收藏  举报