[八分之三的男人] POJ - 1741 点分治 && 点分治笔记

题意:给出一棵带边权树,询问有多少点对的距离小于等于\(k\)

本题解参考lyd的算法竞赛进阶指南,讲解的十分清晰,比网上那些讲的乱七八糟的好多了

不过写起来还是困难重重(史诗巨作

打完多校更详细做法

对于所有路径,以某个节点u来看分为两种情况
1.经过u的路径
2.不经过u的路径

能对答案有贡献的肯定是1类型,对2我们处理完1后递归求解

于是条件变为
以u为根的树中,其联通块中的点对符合
距离u之和小于等于k ①
且位于各不同的u的子树(u单独处理) ②

不妨先维护①再维护②

每一次选定根u后预处理联通块内的点相对于u的距离dis,以及归属于哪个子树belong,
并且把它们放入数组vec中,清空存在的cnt(cnt[t]表示归属于子树t的点的个数有多少个,其中cnt[u]设为单独节点并不包含其它子树)
处理完后对vec数组的节点按dis排序,这样做的目的是为了\(O(n)\)处理贡献

贡献该怎么算?如果在vec中设两个指针L,R,满足最小的L和最大的R符合\(dis[vec[L]]+dis[vec[R]]≤k\),这就满足了①条件
那么②呢?显然是在统计完\([L+1,R]\)范围内的\(cnt[belong[vec[i]]]\)后求\(R-(L+1)-1-cnt[belong[vec[L]]]\)

这个时候vec对dis排序就显得很有用,k是恒定的,随着L指针的右移,R指针也一定是左移,直到\(L=R\)就表明以后无论怎么移动都不会符合条件①了
在指针移动的过程中我们顺便完成了对cnt的统计更新,所以只需L从头到尾扫一遍,每个元素至多被指针遍历4次,整个操作是\(O(n)\)的(虽然排序是nlogn的

由此我们完成了对u有关的第一个情况的所有路径统计
第二个情况只需把u删去(block标记,删去也意味着所有经过u的路径不会再遍历,因为不会再有任何贡献了)递归接下来的子树即可

为了避免单链型的\(O(n)\)次递归,我们选择将每一个子树的重心作为根进行处理,以达到\(O(logn)\)的最坏情况
具体的先dfs一下更新子树大小得出联通块大小V再不断比对删除节点后的最大联通块的最小值就能找到了

总而言之我们在\(O(nlog^2n)\)的时间完成了传说中的点分治啦

八分之三的男人达成√

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<string>
#include<vector>
#include<stack>
#include<queue>
#include<set>
#include<map>
//#include<unordered_map>
#define rep(i,j,k) for(register int i=j;i<=k;i++)
#define rrep(i,j,k) for(register int i=j;i>=k;i--)
#define erep(i,u) for(register int i=head[u];~i;i=nxt[i])
#define print(a) printf("%lld",(ll)a)
#define println(a) printf("%lld\n",(ll)a)
using namespace std;
const int MAXN = 1e5+11;
const int INF = 0x3f3f3f3f;
const double EPS = 1e-7;
typedef long long ll;
ll read(){
    ll x=0,f=1;register char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
int n,k;
bool vis[MAXN],block[MAXN];
int dis[MAXN],belong[MAXN],cnt[MAXN],sz[MAXN];
int to[MAXN<<1],nxt[MAXN<<1],cost[MAXN<<1],head[MAXN],tot;
vector<int> vec;
void init(){
    memset(head,-1,sizeof head);
    memset(vis,0,sizeof vis);
    memset(block,0,sizeof block);
    tot=0;
}
void add(int u,int v,int w){
    to[tot]=v;cost[tot]=w;
    nxt[tot]=head[u];head[u]=tot++;
}
int mxsize,mxson,V;//V==N // each V of subtrees can obtain using sz
int getroot(int u,int fa){
    sz[u]=1;
    int tmp=0;//max size of subtrees while u deleted
    erep(i,u){
        int v=to[i];
        if(v==fa) continue;
        if(block[v]) continue;
        getroot(v,u);
        sz[u]+=sz[v];
        tmp=max(tmp,sz[v]);
    }
    tmp=max(tmp,V-sz[u]);
    if(tmp<mxsize){
        mxsize=tmp;
        mxson=u;// top and down compared
    }
    return mxson;
}
bool cmp(int a,int b){return dis[a]<dis[b];}
void prepare(int u,int fa,int d,int son,int rt){
    dis[u]=d; vec.push_back(u);
    belong[u]=son; cnt[son]=0;
    erep(i,u){
        int v=to[i],w=cost[i];
        if(v==fa) continue;
        if(block[v]) continue;
        prepare(v,u,d+w,son==rt?v:son,rt);
    }
}
int getV(int u,int fa){
    sz[u]=1;
    erep(i,u){
        int v=to[i];
        if(v==fa) continue;
        if(block[v]) continue;
        getV(v,u);
        sz[u]+=sz[v];
    }
    return sz[u];
}

ll solve(int u,int fa){
    if(vis[u]||u<1) return 0;
    vis[u]=1;
    
    ll ans=0;
    vec.clear();
    prepare(u,fa,0,u,u);
    sort(vec.begin(),vec.end(),cmp);
    int L,R=vec.size();R--;
    for(int i=0;i<vec.size();i++) cnt[belong[vec[i]]]++;
    bool flag=0;
    for(int i=0;i+1<(int)vec.size();i++){//enum L
        L=i;cnt[belong[vec[L]]]--;
        while(dis[vec[L]]+dis[vec[R]]>k){
            if(L>=R) break;
            cnt[belong[vec[R]]]--;
            R--;
        }
        if(L>=R) break;
        ans+=(ll)R-L-cnt[belong[vec[L]]];
    }
    while(L<vec.size()) cnt[belong[vec[L++]]]=0;
    block[u]=1;
    erep(i,u){
        int v=to[i];
        if(v==fa) continue;
        if(block[v]) continue;
        V=getV(v,u);
        mxson=v,mxsize=INF;
        int rt=getroot(v,u);
        ans+=solve(rt,u);
    }
    return ans;
}
int main(){
    while(cin>>n>>k){
        if(n==0&&k==0) break;
        init();
        rep(i,1,n-1){
            int u=read();
            int v=read();
            int w=read();
            add(u,v,w);
            add(v,u,w);
        }
        mxsize=INF;mxson=1;V=n;
        int rt=getroot(1,-1);
        println(solve(rt,-1));
    }
    return 0;
}
posted @ 2018-07-21 11:45  Caturra  阅读(150)  评论(0编辑  收藏  举报