Luogu P3066 [USACO12DEC] Running Away From the Barn (segment tree merging)

Problem description
It's milking time at Farmer John's farm, but the cows have all run away! Farmer John needs to round them all up, and needs your help in the search.

FJ's farm is a series of N (1 <= N <= 200,000) pastures numbered 1...N connected by N - 1 bidirectional paths. The barn is located at pasture 1, and it is possible to reach any pasture from the barn.

FJ's cows were in their pastures this morning, but who knows where they ran to by now. FJ does know that the cows only run away from the barn, and they are too lazy to run a distance of more than L. For every pasture, FJ wants to know how many different pastures cows starting in that pasture could have ended up in.

Note: 64-bit integers (int64 in Pascal, long long in C/C++ and long in Java) are needed to store the distance values.

In short: given a tree rooted at vertex 1, for each vertex count how many vertices in its subtree are at distance at most L from it.

Input/output format
Input format:
* Line 1: 2 integers, N and L (1 <= N <= 200,000, 1 <= L <= 10^18)

* Lines 2..N: The ith line contains two integers p_i and l_i. p_i (1 <= p_i < i) is the first pasture on the shortest path between pasture i and the barn, and l_i (1 <= l_i <= 10^12) is the length of the road between pasture i and pasture p_i.

Output format:
* Lines 1..N: One number per line, the number on line i is the number of pastures that can be reached from pasture i by taking roads that lead strictly farther away from the barn (pasture 1) whose total length does not exceed L.

Sample input/output
Sample input #1:
4 5
1 4
2 3
1 5
Sample output #1:
3
2
1
1
Explanation
Cows from pasture 1 can hide at pastures 1, 2, and 4.

Cows from pasture 2 can hide at pastures 2 and 3.

Pastures 3 and 4 are as far from the barn as possible, so cows starting there can only hide where they started.
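The reformulation the solution relies on can be checked by hand on this sample. Writing dis(v) for the total road length from pasture 1 to v (notation assumed here, matching the `dis` array in the code), a vertex v in the subtree of u is at distance dis(v) - dis(u) from u, so:

```latex
\begin{aligned}
\mathrm{ans}(u) &= \bigl|\{\, v \in \operatorname{subtree}(u) : \mathrm{dis}(u) \le \mathrm{dis}(v) \le \mathrm{dis}(u) + L \,\}\bigr| \\
\mathrm{dis}(1) &= 0,\ \mathrm{dis}(2) = 4,\ \mathrm{dis}(3) = 4 + 3 = 7,\ \mathrm{dis}(4) = 5,\ L = 5 \\
u = 1 &: \{0, 4, 7, 5\} \cap [0, 5] = \{0, 4, 5\} \Rightarrow 3 \\
u = 2 &: \{4, 7\} \cap [4, 9] = \{4, 7\} \Rightarrow 2 \\
u = 3 &: \{7\} \cap [7, 12] \Rightarrow 1, \qquad u = 4 : \{5\} \cap [5, 10] \Rightarrow 1
\end{aligned}
```

Inside a subtree the lower bound dis(u) <= dis(v) always holds; the code queries it anyway, which costs nothing.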

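Yes, it's segment tree merging again.
All we need is to discretize the depths (the weighted distances from pasture 1), build a value-indexed segment tree for every vertex, and answer each vertex with a range query.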
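In detail: when the DFS reaches a vertex, insert that vertex's depth into the vertex's own segment tree, then, after each child has been processed, merge the child's segment tree into it. Finally, query how many stored depths lie in the range [depth of the vertex, depth of the vertex + L]. Since every vertex v in u's subtree is at distance depth(v) - depth(u) from u, this count is exactly the answer for u.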

The code is as follows:

#include<bits/stdc++.h>
#define lson tr[now].l
#define rson tr[now].r
#define pii pair<int,long long>
#define mp make_pair
using namespace std;

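// Dynamic-node value segment tree: l/r are child indices, sum = number of depths stored below.
// Each insertion adds at most one root-to-leaf path (about 20 nodes over N = 400000 leaves).
struct tree
{
    int l,r,sum;
}tr[5000020];
vector<pii> g[400010];   // g[u]: (neighbor, road length)
// rt[i]: root of vertex i's tree; deep[i]: rank of dis[i]; q[i]: rank of dis[i]+L
int n,cnt,cnt2,rt[400010],deep[400010],q[400010],ans[400010];
// gg = L; dis[i] = distance from pasture 1; tmp = values to discretize (dis[i] and dis[i]+L)
long long gg,tmp[400010],dis[400010];
int N=400000;            // value range [1, N]; there are at most 2n distinct values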

// Recompute a node's count from its two children
void push_up(int now)
{
    tr[now].sum=tr[lson].sum+tr[rson].sum;
}

// Add val at position pos, creating the nodes on the path if they do not exist yet
void insert(int &now,int l,int r,int pos,int val)
{
    if(!now) now=++cnt;
    if(l==r)
    {
        tr[now].sum+=val;
        return;
    }
    int mid=(l+r)>>1;
    if(pos<=mid)
    {
        insert(lson,l,mid,pos,val);
    }
    else
    {
        insert(rson,mid+1,r,pos,val);
    }
    push_up(now);
}

// Count stored depths whose rank lies in [ll, rr]; [ll, rr] always stays inside [l, r]
int query(int now,int l,int r,int ll,int rr)
{
    if(!now) return 0;   // empty subtree
    if(ll<=l&&r<=rr)
    {
        return tr[now].sum;
    }
    int mid=(l+r)>>1;
    if(rr<=mid)
    {
        return query(lson,l,mid,ll,rr);
    }
    else
    {
        if(mid<ll)
        {
            return query(rson,mid+1,r,ll,rr);
        }
        else
        {
            return query(lson,l,mid,ll,mid)+query(rson,mid+1,r,mid+1,rr);
        }
    }
}

// Merge tree b into tree a (b's nodes are reused) and return the merged root
int merge(int a,int b,int l,int r)
{
    if(!a) return b;
    if(!b) return a;
    if(l==r)
    {
        tr[a].sum+=tr[b].sum;
        return a;
    }
    int mid=(l+r)>>1;
    tr[a].l=merge(tr[a].l,tr[b].l,l,mid);
    tr[a].r=merge(tr[a].r,tr[b].r,mid+1,r);
    push_up(a);
    return a;
}

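// First pass: record dis[] and collect dis and dis+L for discretization.
// Node id `now` is reserved as the root of vertex now's tree, so cnt equals n afterwards.
void dfs(int now,int fa,long long dep)
{
    dis[now]=dep;
    tmp[++cnt2]=dep;
    tmp[++cnt2]=dep+gg;
    rt[now]=now;
    ++cnt;
    for(int i=0;i<g[now].size();i++)
    {
        if(g[now][i].first==fa) continue;
        dfs(g[now][i].first,now,dep+g[now][i].second);
    }
}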

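// Second pass: insert this vertex's rank, absorb each child's tree after the child is answered,
// then count the ranks in [deep[now], q[now]], i.e. depths in [dis[now], dis[now]+L]
void solve(int now,int fa)
{
    insert(rt[now],1,N,deep[now],1);
    for(int i=0;i<g[now].size();i++)
    {
        if(g[now][i].first==fa) continue;
        solve(g[now][i].first,now);
        rt[now]=merge(rt[now],rt[g[now][i].first],1,N);
    }
    ans[now]=query(rt[now],1,N,deep[now],q[now]);
}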

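// Discretize: deep[i] = rank of dis[i], q[i] = rank of dis[i]+L.
// Both values were put into tmp by dfs, so lower_bound lands on them exactly.
void init()
{
    sort(tmp+1,tmp+cnt2+1);
    int tot=unique(tmp+1,tmp+cnt2+1)-tmp-1;
    for(int i=1;i<=n;i++)
    {
        deep[i]=lower_bound(tmp+1,tmp+tot+1,dis[i])-tmp;
    }
    for(int i=1;i<=n;i++)
    {
        q[i]=lower_bound(tmp+1,tmp+tot+1,dis[i]+gg)-tmp;
    }
}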

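int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0);
    cin>>n>>gg;
    int from;
    long long to;
    // Line i (i = 2..n) gives the parent p_i and the length of the road between i and p_i
    for(int i=2;i<=n;i++)
    {
        cin>>from>>to;
        g[i].push_back(mp(from,to));
        g[from].push_back(mp(i,to));
    }
    dfs(1,0,0);
    init();
    solve(1,0);
    for(int i=1;i<=n;i++)
    {
        cout<<ans[i]<<'\n';   // '\n' avoids flushing the stream on every one of up to 200000 lines
    }
    return 0;
}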

posted @ 2018-10-17 19:36  Styx-ferryman