BZOJ1468 - Tree

Description

给出一棵n(n4×104)个节点的带边权树和k,求有多少距离不超过k的点对。

Solution

点分治。
分治(Divide and Conquer),也就是分而治之。具体来说就是将一个问题划分为若干个规模更小的同样的问题直到子问题可以简单求解,并将这些子问题的结果进行合并。在树的分治上,合并的就是子树的结果啦。
首先我们要找到树的重心,然后将重心作为根。树的重心是树中的一个点,其所有的子树中最大的子树节点数最少。这样可以让这棵树尽可能平衡,而且其最大的子树大小不超过n/2——因此分治的层数是O(logn)
接下来以此为基础说明本题的解法。
记树的根节点为rt,树中节点urt的距离为dst[u],则有

k=dst[u]+dst[v]ku,v(u,v)
+k
子树总距离不超过k的点对就是分治中的子问题,递归解决就好。对于另外那部分,我们设计函数cal(u)表示以u为根的树中满足dst[u]+dst[v]k的点对数(注意其中的dst是对于rtdst,而不是对于u的),则这部分的点对数=cal(rt)usonrtcal(u)cal(u)可以通过DFS出u的所有下属节点的dst再排序在O(nlogn)的时间复杂度内实现,具体可以看代码。

时间复杂度为O(nlog2n)

Code

//Tree
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
inline char gc()
{
    static char now[1<<16],*S,*T;
    if(S==T) {T=(S=now)+fread(now,1,1<<16,stdin); if(S==T) return EOF;}
    return *S++;
}
inline int read()
{
    int x=0; char ch=gc();
    while(ch<'0'||'9'<ch) ch=gc();
    while('0'<=ch&&ch<='9') x=x*10+ch-'0',ch=gc();
    return x;
}
int const N=4e4+10;
int const INF=0x7FFFFFFF;
int n,k;
int cnt,h[N];
struct edge{int v,w,nxt;} ed[N<<1];
void edAdd(int u,int v,int w)
{
    ++cnt; ed[cnt].v=v,ed[cnt].w=w,ed[cnt].nxt=h[u],h[u]=cnt;
    ++cnt; ed[cnt].v=u,ed[cnt].w=w,ed[cnt].nxt=h[v],h[v]=cnt;
}
int ans;
int fa[N],siz[N];
bool vst[N];
void dfs(int u)
{
    siz[u]=1;
    for(int i=h[u];i;i=ed[i].nxt)
    {
        int v=ed[i].v;
        if(v!=fa[u]&&!vst[v]) fa[v]=u,dfs(v),siz[u]+=siz[v];
    }
}
int G,maxS,sumS;
void getG(int u)
{
    int mx=0;
    for(int i=h[u];i;i=ed[i].nxt)
    {
        int v=ed[i].v;
        if(v!=fa[u]&&!vst[v]) mx=max(mx,siz[v]),getG(v);
    }
    mx=max(mx,sumS-siz[u]);
    if(mx<maxS) maxS=mx,G=u;
}
int dst[N],tCnt,t[N];
void getDst(int u)
{
    t[++tCnt]=dst[u];
    for(int i=h[u];i;i=ed[i].nxt)
    {
        int v=ed[i].v;
        if(v!=fa[u]&&!vst[v]) dst[v]=dst[u]+ed[i].w,getDst(v);
    }
}
int cal(int u,int d0)
{
    tCnt=0,dst[u]=0,getDst(u);
    sort(t+1,t+tCnt+1);
    int res=0,i=1,j=tCnt;
    while(i<j)
        if(t[i]+t[j]>k-d0-d0) j--;
        else res+=j-i,i++;
    return res;
}
void solve(int u)
{
    vst[u]=true; fa[u]=0,dfs(u);
    ans+=cal(u,0);
    for(int i=h[u];i;i=ed[i].nxt)
    {
        int v=ed[i].v;
        if(vst[v]) continue;
        ans-=cal(v,ed[i].w);
        maxS=INF,sumS=siz[v],getG(v); solve(G);
    }
}
int main()
{
    n=read();
    cnt=0; memset(h,0,sizeof h);
    for(int i=1;i<=n-1;i++)
    {
        int u=read(),v=read(),w=read();
        edAdd(u,v,w);
    }
    k=read();
    ans=0;
    fa[1]=0,dfs(1);
    maxS=INF,sumS=n,getG(1); solve(G);
    printf("%d",ans);
    return 0;
}

P.S.

WA了两发居然是因为读入优化read()的返回值类型写成了char…吐血

posted @ 2018-01-18 23:49  VisJiao  阅读(73)  评论(0编辑  收藏  举报