pair

这是一道非常经典的树分治（点分治）题目，60 分的做法参见 POJ 1741。

按照树分治的惯例，先以重心为根对整个连通块做全局统计，再减去两端落在同一棵子树内、被重复统计的点对。

那么如何计算贡献呢？

我们沿用 POJ 1741 的方法：把所有点按边权和排序，用双指针维护与当前点边权和相加不超过 W 的那些点，并把它们的边数放进树状数组，这样另一维（边数之和不超过 L）就只需要做一维的前缀统计了。

复杂度 $O(n\log^2 n)$。

代码写得也比较凑合……

#include<cstdlib>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<deque>
// printf conversion for long long. Older Windows C runtimes (msvcrt) only
// understand %I64d. _WIN32 is the macro the compiler predefines on Windows;
// the un-prefixed WIN32 is conventional but may be absent (e.g. MinGW in
// strict -std=c++XX modes), so test both.
#if defined(_WIN32) || defined(WIN32)
#define fmt64 "%I64d"
#else
#define fmt64 "%lld"
#endif
using namespace std;
// Capacity of every per-vertex array, and a large "infinity" sentinel.
const int maxn = 150000;
const int inf = 0x3f3f3f3f;

// Adjacency-list arc: target vertex t, edge weight w, and the index of the
// previously added arc leaving the same vertex (0 ends the list).
struct E {
    int t;
    int w;
    int last;
};
E e[maxn * 2];

// One collected vertex: w = accumulated edge weight, l = number of edges.
struct Q {
    int w;
    int l;
};
// NOTE(review): under C++17 with `using namespace std;`, an unqualified use of
// `data` (and of `size` below) is ambiguous with std::data / std::size.
Q data[maxn];

// Orders Q by w, breaking ties by l; non-zero means "x goes before y".
int cmp(Q x, Q y){
    return x.w != y.w ? x.w < y.w : x.l < y.l;
}

int last[maxn];   // head arc index of each vertex's adjacency list
int cnt;          // number of arcs stored in e[]
int node;         // component size that getroot works with
int tmax;         // smallest "largest piece" found so far by getroot
int n;            // vertex count
int L;            // bound on the number of edges of a path
int W;            // bound on the total weight of a path
int size;         // number of entries currently in data[]
// 1-based Fenwick tree over positions 1..n (n is the global vertex count).
// process() uses it to count collected vertices by path length l.
struct BIT{
    int bt[maxn];
    int lowbit(int x){
        return x & (-x);
    }
    // Adds val at position pos. Requires 1 <= pos <= n: lowbit(0) == 0, so
    // pos == 0 would never terminate. Updates stop at n, so cells above n stay 0.
    void ins(int pos,int val){
        while(pos <= n){
            bt[pos] += val;
            pos += lowbit(pos);
        }
    }
    // Returns the sum of the values stored at positions 1..pos (0 if pos <= 0).
    int qer(int pos){
        // ins never touches cells above n. Walking down from pos > n would read
        // those zero cells in place of the ones that cover (pos - lowbit(pos), n]
        // (e.g. n = 5, pos = 6 reads bt[6] == 0 and misses position 5), and a
        // large pos would index past bt[]. Every stored position is <= n, so
        // clamping yields the full total.
        if(pos > n) pos = n;
        int ret = 0;
        while(pos > 0){
            ret += bt[pos];
            pos -= lowbit(pos);
        }
        return ret;
    }
}bit;
// Adds the undirected edge x--y of weight w as two arcs, x->y first and then
// y->x, each pushed onto the front of its source vertex's list.
// Plain member assignment is used instead of a `(E){...}` compound literal,
// which is a C99 feature that ISO C++ does not have.
void add(int x,int y,int w){
    ++cnt;
    e[cnt].t = y;
    e[cnt].w = w;
    e[cnt].last = last[x];
    last[x] = cnt;

    ++cnt;
    e[cnt].t = x;
    e[cnt].w = w;
    e[cnt].last = last[y];
    last[y] = cnt;
}
int sz[maxn];    // getroot: number of descendants of a vertex, itself excluded
int vis[maxn];   // set to 1 once a vertex has been processed by solve()
int root;        // vertex chosen by the most recent getroot() call
// Looks for a centroid of the still-active component containing x.
// sz[v] receives the number of descendants of v in this DFS (v excluded).
// For every visited vertex it computes the largest piece left after removing
// it: the biggest child piece (sz[child] + 1) or the part above it
// (node - sz[v] - 1, i.e. `node` is taken as the component's vertex count).
// The vertex whose largest piece is smallest goes to `root`, that size to
// `tmax`; on ties the vertex finishing last in post-order wins (<=).
// Callers reset tmax to inf beforehand.
void getroot(int x,int fa){
    sz[x] = 0;
    int worst = 0;
    for(int i = last[x]; i != 0; i = e[i].last){
        const int y = e[i].t;
        if(y == fa || vis[y])
            continue;
        getroot(y, x);
        const int piece = sz[y] + 1;
        if(piece > worst)
            worst = piece;
        sz[x] += piece;
    }
    const int above = node - sz[x] - 1;
    if(above > worst)
        worst = above;
    if(worst <= tmax){
        root = x;
        tmax = worst;
    }
}
// Offsets that getdata adds to every collected (weight, length) pair; set by calc.
int add1;
int add2;
// Running count of qualifying vertex pairs.
long long ans;
void getdata(int x,int fa,int sumw,int suml){                    
    int cost1 = sumw + add1, cost2 = suml + add2;
    if(cost1 <= W && cost2 <= L)
        data[++size] = (Q){cost1,cost2};
    for(int i = last[x]; i; i = e[i].last)
        if(e[i].t != fa && !vis[e[i].t])
            getdata(e[i].t, x, sumw + e[i].w, suml + 1);
}
long long process(){
    sort(data + 1, data + size + 1, cmp);
    int h = 1,t = size;
    long long ret = 0;
    deque<Q>q;
    deque<int>qq;
    while(h < t || !q.empty()){
        while(!q.empty() && qq.back() <= h){
            bit.ins(q.back().l, -1);
            q.pop_back(), qq.pop_back();
        }
        while(!q.empty() && q.front().w + data[h].w > W){
            bit.ins(q.front().l, -1);
            q.pop_front(), qq.pop_front();
        }
        if(data[t].w + data[h].w > W)
            { t--; continue; }
        else{
            while(h < t && data[t].w + data[h].w <= W){
                bit.ins(data[t].l, 1);
                q.push_back(data[t]); qq.push_back(t--);
            }
            ret += bit.qer(L - data[h++].l);
        }        
    }
    return ret;
}
long long calc(int x,int ad1,int ad2){    
    size = 0;
    add1 = ad1; add2 = ad2;
    getdata(x,0,0,0);
    return process();
}
// Recomputes sz[] for the still-active component containing x, rooted at x:
// sz[v] = number of descendants of v (v itself excluded), the convention
// getroot also uses.
static void calcsize(int x,int fa){
    sz[x] = 0;
    for(int i = last[x]; i; i = e[i].last)
        if(e[i].t != fa && !vis[e[i].t]){
            calcsize(e[i].t, x);
            sz[x] += sz[e[i].t] + 1;
        }
}
// One centroid-decomposition step at centre x.
// Adds the pairs counted over x's whole component (distances measured through x).
// Subtracts, for every child, the pairs counted again inside that child's
// subtree (calc(child, w, 1) measures them from x across the child edge).
// Then marks x removed and recurses on the centroid of each child component.
void solve(int x){
    ans += calc(x, 0, 0);
    // sz[] left by getroot comes from a DFS rooted elsewhere, so it is wrong for
    // the neighbour on that root's side. Recompute it from x.
    calcsize(x, 0);
    vis[x] = 1;
    for(int i = last[x]; i; i = e[i].last)
        if(!vis[e[i].t]){
            ans -= calc(e[i].t, e[i].w, 1);
            // getroot treats `node` as the component's vertex count (work()
            // passes n). sz[] excludes the child itself, hence + 1.
            node = sz[e[i].t] + 1;
            tmax = inf;
            root = 0;
            getroot(e[i].t, 0);
            solve(root);
        }
}
// Starts the decomposition: locate the centroid of the whole tree (all n
// vertices, DFS entered at vertex 1) and solve from it.
void work(){
    tmax = inf;
    node = n;
    getroot(1, 0);
    solve(root);
}
int main()
{
    freopen("pair.in","r",stdin);
    freopen("pair.out","w",stdout);
    scanf("%d %d %d",&n,&L,&W);
    for(int i = 2; i <= n; ++i){
        int p,w; scanf("%d %d",&p,&w);
        add(i,p,w);
    }
    work();
    printf(fmt64"\n", ans);
    return 0;
}
View Code

 

posted @ 2015-01-04 21:49  Mr.Ren~  阅读(209)  评论(0编辑  收藏  举报