【HNOI】 攻城略池 tree-dp
【题目大意】
给定一棵树,边有边权,每个节点有一些兵,现在叶子节点在0时刻被占领,并且任意节点在x被占领,那么从x+1开始,每单位时间产生一个兵,兵会顺着父亲节点一直走到根(1),其中每经过一个节点,该节点的兵储量减少1,问所有节点都被攻陷的最晚时间。
【数据范围】
n<=10^5.
首先我们可以设每个节点被攻占的时间为w[i],那么对于一个父节点x,我们可以二分一个答案,来判断这个答案是否合法,那么假设我们分到的值为time,那么对于x子树中所有的节点p,每个节点的贡献为max(0,time-w[p]-d[x][p])。那么我们只需要判断贡献和与x节点的兵的数量就可以了。
那么对于一个点,我们可以用一颗平衡树来存这个点为根节点的子树中所有节点的w[p]+d[x][p]值,那么对于一个二分到的答案,我们只需要判断树中小于time的和,再用time*size减去就好了。对于一个节点,我们只需要启发式合并他所有的子节点的平衡树就好了。
反思:开始写的时候没看long long,结果发现连第二个测试点都过不去,然后开了之后就出了各种各样的问题,开始是return 0没有改成return 1LL,后来发现平衡树中维护的sum值有的时候没有被更新,我是在插入和删除的时候修改的这个,和size一起修改,也不知道哪儿错了,后来就直接在询问的时候维护sum,然后还是不行,后来能加上的地方都加上了才过了。。。。
//By BLADEVIL #include <cstdio> #include <algorithm> #define maxn 100010 #define LL long long using namespace std; LL n,l,tot,save; LL a[maxn],pre[maxn<<1],other[maxn<<1],last[maxn],len[maxn<<1]; LL w[maxn],flag[maxn],que[maxn],dis[maxn],rot[maxn]; LL left[maxn<<5],right[maxn<<5],key[maxn<<5],size[maxn<<5],sum[maxn<<5]; void connect(LL x,LL y,LL z) { pre[++l]=last[x]; last[x]=l; other[l]=y; len[l]=z; } void left_rotate(LL &t) { LL k=right[t]; right[t]=left[k]; left[k]=t; size[k]=size[t]; sum[k]=sum[t]; size[t]=size[left[t]]+size[right[t]]+1LL; sum[t]=sum[left[t]]+sum[right[t]]+key[t]; t=k; } void right_rotate(LL &t) { LL k=left[t]; left[t]=right[k]; right[k]=t; size[k]=size[t]; sum[k]=sum[t]; size[t]=size[left[t]]+size[right[t]]+1LL; sum[t]=sum[left[t]]+sum[right[t]]+key[t]; t=k; } void maintain(LL &t,int flag) { if (!flag) { if (size[left[left[t]]]>size[right[t]]) right_rotate(t); else if (size[right[left[t]]]>size[right[t]]) left_rotate(left[t]),right_rotate(t); else return ; } else { if (size[right[right[t]]]>size[left[t]]) left_rotate(t); else if (size[left[right[t]]]>size[left[t]]) right_rotate(right[t]),left_rotate(t); else return ; } maintain(left[t],0); maintain(right[t],1); maintain(t,1); maintain(t,0); //sum[t]=sum[left[t]]+sum[right[t]]+key[t]; } void t_insert(LL &t,LL v) { if (!t) { t=++tot; left[t]=right[t]=0LL; size[t]=1LL; key[t]=sum[t]=v; } else { size[t]++; sum[t]+=v; if (v<key[t]) t_insert(left[t],v); else t_insert(right[t],v); maintain(t,v>=key[t]); } } /* LL t_delete(LL &t,LL v) { size[t]--; if ((v==key[t])||((v>key[t])&&(!right[t]))||((v<key[t])&&(!left[t]))) { save=key[t]; if ((!left[t])||(!right[t])) t=left[t]+right[t]; else key[t]=t_delete(left[t],v+1LL); } else { if (v<key[t]) return t_delete(left[t],v); else return t_delete(right[t],v); } sum[t]=sum[left[t]]+sum[right[t]]+key[t]; return save; } */ LL t_delete(LL &t,LL v) { if ((v==key[t])||((v>key[t])&&(!right[t]))||((v<key[t])&&(!left[t]))) { save=key[t]; if ((!left[t])||(!right[t])) { t=left[t]+right[t]; sum[t]=sum[left[t]]+sum[right[t]]+key[t]; }else key[t]=t_delete(left[t],v+1LL); //tmp = key[t]; } else { if (v<key[t]) save = t_delete(left[t],v); else save = t_delete(right[t],v); } //size[t]=size[left[t]]+size[right[t]]+1; sum[t]=sum[left[t]]+sum[right[t]]+key[t]; return save; } void combine(LL &t1,LL &flag1,LL t2,LL flag2) { if (size[t1]<size[t2]) swap(t1,t2),swap(flag1,flag2); while (t2) { t_insert(t1,key[t2]+flag2-flag1); t_delete(t2,key[t2]); } } LL judge(LL t,LL time){ sum[t]=sum[left[t]]+sum[right[t]]+key[t]; if (!t) return 0LL; if (key[t]<=time) return judge(right[t],time)+(size[left[t]]+1LL)*time-sum[left[t]]-key[t]; else return judge(left[t],time); } void work() { LL h=0LL,t=1LL; que[1]=1LL; dis[1]=1LL; while (h<t) { LL cur=que[++h]; for (LL p=last[cur];p;p=pre[p]) { if (dis[other[p]]) continue; que[++t]=other[p]; dis[other[p]]=dis[cur]+1LL; } } //for (LL i=1;i<=n;i++) printf("%d ",que[i]); printf("\n"); for (LL i=n;i;i--) { LL cur=que[i]; for (LL p=last[cur];p;p=pre[p]) { if (dis[other[p]]<dis[cur]) continue; combine(rot[cur],flag[cur],rot[other[p]],flag[other[p]]+len[p]); } if ((!rot[cur])||(!a[cur])) { t_insert(rot[cur],-flag[cur]); //printf("%lld %lld %lld\n",cur,rot[cur],flag[cur]); //for (LL i=1;i<=20;i++) printf("%lld %lld %lld %lld %lld %lld\n",i,left[i],right[i],size[i],key[i],sum[i]); continue; } //printf("%lld %lld %lld\n",cur,rot[cur],flag[cur]); //for (LL i=1;i<=20;i++) printf("%lld %lld %lld %lld %lld %lld\n",i,left[i],right[i],size[i],key[i],sum[i]); LL l=1LL,r=1LL<<30,mid,ans; while (l<=r) { //printf("%d %d\n",l,r); mid=l+r>>1LL; //if (cur==1) printf("%lld %lld\n",l,r); if (judge(rot[cur],mid-flag[cur])>=a[cur]) r=mid-1LL,ans=mid; else l=mid+1LL; } w[cur]=ans; //if (cur==1) printf("|%lld\n",judge(rot[cur],-6)); t_insert(rot[cur],w[cur]-flag[cur]); //printf("%lld %lld %lld\n",cur,rot[cur],flag[cur]); //for (LL i=1;i<=20;i++) printf("%lld %lld %lld %lld %lld %lld\n",i,left[i],right[i],size[i],key[i],sum[i]); } //printf("%lld %lld\n",rot[1],flag[1]); //for (LL i=1;i<=20;i++) printf("%lld %lld %lld %lld %lld %lld\n",i,left[i],right[i],size[i],key[i],sum[i]); //for (LL i=1;i<=n;i++) printf("%lld ",w[i]); printf("\n"); LL ans=0LL; for (int i=1;i<=n;i++) ans=max(ans,w[i]); printf("%lld\n",ans); } void check() { LL t1=0,t2=0,flag1=0,flag2=0; for (LL i=1;i<=5;i++) t_insert(t1,i),t_insert(t2,i); t_insert(t2,6); combine(t1,flag1,t2,flag2); printf("%lld\n",t1); for (LL i=1;i<=20;i++) printf("%ld %lld %lld %lld %lld %lld\n",i,left[i],right[i],size[i],key[i],sum[i]); return ; LL t=0; for (LL i=1;i<=10;i++) t_insert(t,i); t_delete(t,7); printf("%lld\n",t); for (LL i=1;i<=10;i++) printf("%lld %lld %lld %lld %lld %lld\n",i,left[i],right[i],size[i],key[i],sum[i]); printf("%lld\n",judge(t,9)); } int main() { //check(); return 0; freopen("conquer.in","r",stdin); freopen("conquer.out","w",stdout); scanf("%lld",&n); for (LL i=1;i<=n;i++) scanf("%lld",&a[i]); for (LL i=1;i<n;i++) { LL x,y,z; scanf("%lld%lld%lld",&x,&y,&z); connect(x,y,z); connect(y,x,z); } work(); fclose(stdin); fclose(stdout); return 0; }