【HNOI】 攻城略池 tree-dp

  【题目大意】

    给定一棵树,边有边权,每个节点有一些兵,现在叶子节点在0时刻被占领,并且任意节点在x被占领,那么从x+1开始,每单位时间产生一个兵,兵会顺着父亲节点一直走到根(1),其中每经过一个节点,该节点的兵储量减少1,问所有节点都被攻陷的最晚时间。

  【数据范围】

    n<=10^5.

  首先我们可以设每个节点被攻占的时间为w[i],那么对于一个父节点x,我们可以二分一个答案,来判断这个答案是否合法,那么假设我们分到的值为time,那么对于x子树中所有的节点p,每个节点的贡献为max(0,time-w[p]-d[x][p])。那么我们只需要判断贡献和与x节点的兵的数量就可以了。

  那么对于一个点,我们可以用一颗平衡树来存这个点为根节点的子树中所有节点的w[p]+d[x][p]值,那么对于一个二分到的答案,我们只需要判断树中小于time的和,再用time*size减去就好了。对于一个节点,我们只需要启发式合并他所有的子节点的平衡树就好了。

  反思:开始写的时候没看long long,结果发现连第二个测试点都过不去,然后开了之后就出了各种各样的问题,开始是return 0没有改成return 1LL,后来发现平衡树中维护的sum值有的时候没有被更新,我是在插入和删除的时候修改的这个,和size一起修改,也不知道哪儿错了,后来就直接在询问的时候维护sum,然后还是不行,后来能加上的地方都加上了才过了。。。。

  

//By BLADEVIL
#include <cstdio>
#include <algorithm>
#define maxn 100010
#define LL long long

using namespace std;

LL n,l,tot,save;
LL a[maxn],pre[maxn<<1],other[maxn<<1],last[maxn],len[maxn<<1];
LL w[maxn],flag[maxn],que[maxn],dis[maxn],rot[maxn];
LL left[maxn<<5],right[maxn<<5],key[maxn<<5],size[maxn<<5],sum[maxn<<5];

void connect(LL x,LL y,LL z) {
    pre[++l]=last[x];
    last[x]=l;
    other[l]=y;
    len[l]=z;
}

void left_rotate(LL &t) {
    LL k=right[t];
    right[t]=left[k];
    left[k]=t;
    size[k]=size[t];
    sum[k]=sum[t];
    size[t]=size[left[t]]+size[right[t]]+1LL;
    sum[t]=sum[left[t]]+sum[right[t]]+key[t];
    t=k;
}

void right_rotate(LL &t) {
    LL k=left[t];
    left[t]=right[k];
    right[k]=t;
    size[k]=size[t];
    sum[k]=sum[t];
    size[t]=size[left[t]]+size[right[t]]+1LL;
    sum[t]=sum[left[t]]+sum[right[t]]+key[t];
    t=k;
}

void maintain(LL &t,int flag) {
    if (!flag) {
        if (size[left[left[t]]]>size[right[t]])
            right_rotate(t); else 
        if (size[right[left[t]]]>size[right[t]])
            left_rotate(left[t]),right_rotate(t); else return ;
    } else {
        if (size[right[right[t]]]>size[left[t]])
            left_rotate(t); else 
        if (size[left[right[t]]]>size[left[t]])
            right_rotate(right[t]),left_rotate(t); else return ;
    }
    maintain(left[t],0); maintain(right[t],1);
    maintain(t,1); maintain(t,0);
    //sum[t]=sum[left[t]]+sum[right[t]]+key[t];
}

void t_insert(LL &t,LL v) {
    if (!t) {
        t=++tot;
        left[t]=right[t]=0LL;
        size[t]=1LL;
        key[t]=sum[t]=v;
    } else {
        size[t]++; sum[t]+=v;
        if (v<key[t]) t_insert(left[t],v); else t_insert(right[t],v);
        maintain(t,v>=key[t]);
    }
}
/*
LL t_delete(LL &t,LL v) {
    size[t]--;
    if ((v==key[t])||((v>key[t])&&(!right[t]))||((v<key[t])&&(!left[t]))) {
        save=key[t];
        if ((!left[t])||(!right[t]))
            t=left[t]+right[t]; else key[t]=t_delete(left[t],v+1LL);
    } else {
        if (v<key[t]) return t_delete(left[t],v); else return t_delete(right[t],v);
    }
    sum[t]=sum[left[t]]+sum[right[t]]+key[t];
    return save;
}
*/

LL t_delete(LL &t,LL v) {
    if ((v==key[t])||((v>key[t])&&(!right[t]))||((v<key[t])&&(!left[t]))) {
        save=key[t];
        if ((!left[t])||(!right[t])) {
            t=left[t]+right[t]; 
            sum[t]=sum[left[t]]+sum[right[t]]+key[t];
        }else key[t]=t_delete(left[t],v+1LL);
        //tmp = key[t];
    } else {
        if (v<key[t]) save =  t_delete(left[t],v); else save = t_delete(right[t],v);
    }
    //size[t]=size[left[t]]+size[right[t]]+1;
    sum[t]=sum[left[t]]+sum[right[t]]+key[t];
    return save;
}

void combine(LL &t1,LL &flag1,LL t2,LL flag2) {
    if (size[t1]<size[t2]) swap(t1,t2),swap(flag1,flag2);
    while (t2) {
        t_insert(t1,key[t2]+flag2-flag1);
        t_delete(t2,key[t2]);
    }
}

LL judge(LL t,LL time){
    sum[t]=sum[left[t]]+sum[right[t]]+key[t];
    if (!t) return 0LL;
    if (key[t]<=time) 
        return judge(right[t],time)+(size[left[t]]+1LL)*time-sum[left[t]]-key[t]; else 
        return judge(left[t],time);
}

void work() {
    LL h=0LL,t=1LL;
    que[1]=1LL; dis[1]=1LL;
    while (h<t) {
        LL cur=que[++h];
        for (LL p=last[cur];p;p=pre[p]) {
            if (dis[other[p]]) continue;
            que[++t]=other[p];
            dis[other[p]]=dis[cur]+1LL;
        }
    }
    //for (LL i=1;i<=n;i++) printf("%d ",que[i]); printf("\n");
    for (LL i=n;i;i--) {
        LL cur=que[i];
        for (LL p=last[cur];p;p=pre[p]) {
            if (dis[other[p]]<dis[cur]) continue;
            combine(rot[cur],flag[cur],rot[other[p]],flag[other[p]]+len[p]);
        }
        if ((!rot[cur])||(!a[cur])) {
            t_insert(rot[cur],-flag[cur]);
            //printf("%lld %lld %lld\n",cur,rot[cur],flag[cur]);
            //for (LL i=1;i<=20;i++) printf("%lld %lld %lld %lld %lld %lld\n",i,left[i],right[i],size[i],key[i],sum[i]);
            continue;
        }
        //printf("%lld %lld %lld\n",cur,rot[cur],flag[cur]);
        //for (LL i=1;i<=20;i++) printf("%lld %lld %lld %lld %lld %lld\n",i,left[i],right[i],size[i],key[i],sum[i]);
        LL l=1LL,r=1LL<<30,mid,ans;
        while (l<=r) {
            //printf("%d %d\n",l,r);    
            mid=l+r>>1LL;
            //if (cur==1) printf("%lld %lld\n",l,r);
            if (judge(rot[cur],mid-flag[cur])>=a[cur]) r=mid-1LL,ans=mid; else l=mid+1LL;
        }
        w[cur]=ans;
        //if (cur==1) printf("|%lld\n",judge(rot[cur],-6));
        t_insert(rot[cur],w[cur]-flag[cur]);
        //printf("%lld %lld %lld\n",cur,rot[cur],flag[cur]);
        //for (LL i=1;i<=20;i++) printf("%lld %lld %lld %lld %lld %lld\n",i,left[i],right[i],size[i],key[i],sum[i]);
    }
    //printf("%lld %lld\n",rot[1],flag[1]);
    //for (LL i=1;i<=20;i++) printf("%lld %lld %lld %lld %lld %lld\n",i,left[i],right[i],size[i],key[i],sum[i]);
    //for (LL i=1;i<=n;i++) printf("%lld ",w[i]); printf("\n");
    LL ans=0LL;
    for (int i=1;i<=n;i++) ans=max(ans,w[i]);
    printf("%lld\n",ans);
}

void check() {
    LL t1=0,t2=0,flag1=0,flag2=0;
    for (LL i=1;i<=5;i++) t_insert(t1,i),t_insert(t2,i); t_insert(t2,6);
    combine(t1,flag1,t2,flag2);
    printf("%lld\n",t1);
    for (LL i=1;i<=20;i++) printf("%ld %lld %lld %lld %lld %lld\n",i,left[i],right[i],size[i],key[i],sum[i]);
    return ;
    LL t=0;
    for (LL i=1;i<=10;i++) t_insert(t,i);
    t_delete(t,7); printf("%lld\n",t);
    for (LL i=1;i<=10;i++) printf("%lld %lld %lld %lld %lld %lld\n",i,left[i],right[i],size[i],key[i],sum[i]);
    printf("%lld\n",judge(t,9));
}

int main() {
    //check(); return 0;
    freopen("conquer.in","r",stdin); freopen("conquer.out","w",stdout);
    scanf("%lld",&n);
    for (LL i=1;i<=n;i++) scanf("%lld",&a[i]);
    for (LL i=1;i<n;i++) {
        LL x,y,z; scanf("%lld%lld%lld",&x,&y,&z);
        connect(x,y,z); connect(y,x,z);
    }
    work();
    fclose(stdin); fclose(stdout);
    return 0;
}

 

 

posted on 2014-04-01 11:03  BLADEVIL  阅读(340)  评论(0编辑  收藏  举报