Description
小B最近正在玩一个寻宝游戏,这个游戏的地图中有N个村庄和N-1条道路,并且任何两个村庄之间有且仅有一条路径可达。游戏开始时,玩家可以任意选择一个村庄,瞬间转移到这个村庄,然后可以任意在地图的道路上行走,若走到某个村庄中有宝物,则视为找到该村庄内的宝物,直到找到所有宝物并返回到最初转移到的村庄为止。小B希望评测一下这个游戏的难度,因此他需要知道玩家找到所有宝物需要行走的最短路程。但是这个游戏中宝物经常变化,有时某个村庄中会突然出现宝物,有时某个村庄内的宝物会突然消失,因此小B需要不断地更新数据,但是小B太懒了,不愿意自己计算,因此他向你求助。为了简化问题,我们认为最开始时所有村庄内均没有宝物
Input
第一行,两个整数N、M,其中M为宝物的变动次数。
接下来的N-1行,每行三个整数x、y、z,表示村庄x、y之间有一条长度为z的道路。
接下来的M行,每行一个整数t,表示一个宝物变动的操作。若该操作前村庄t内没有宝物,则操作后村庄内有宝物;若该操作前村庄t内有宝物,则操作后村庄内没有宝物。
Output
M行,每行一个整数,其中第i行的整数表示第i次操作之后玩家找到所有宝物需要行走的最短路程。若只有一个村庄内有宝物,或者所有村庄内都没有宝物,则输出0。
zkw线段树维护有宝藏村庄在原树中的dfs序并支持查询前驱后继
dfs序相邻(包括dfs序最大和最小的点间)的点间距离之和即为最短路程
树链剖分求lca以及两点间距离
#include<cstdio> inline int input(){ int x=0,c=getchar(); while(c>57||c<48)c=getchar(); while(c>47&&c<58)x=x*10+c-48,c=getchar(); return x; } const int N=100055; typedef long long i64; int n,m; int es[N*2],enx[N*2],ev[N*2],e0[N],ep=2; int sz[N],son[N],dep[N],top[N],fa[N],id[N],rid[N],idp=1; int st[N]; i64 len[N],ans=0; void f1(int w,int pa){ fa[w]=pa; dep[w]=dep[pa]+1; sz[w]=1; for(int i=e0[w];i;i=enx[i]){ int u=es[i]; if(u==pa)continue; len[u]=len[w]+ev[i]; f1(u,w); sz[w]+=sz[u]; if(sz[u]>sz[son[w]])son[w]=u; } } void f2(int w,int tp){ top[w]=tp; rid[id[w]=idp++]=w; if(son[w])f2(son[w],tp); for(int i=e0[w];i;i=enx[i]){ int u=es[i]; if(u!=fa[w]&&u!=son[w])f2(u,u); } } i64 dist(int x,int y){ x=rid[x];y=rid[y]; int a=top[x],b=top[y],c; i64 s=len[x]+len[y]; while(a!=b){ if(dep[a]<dep[b])c=a,a=b,b=c,c=x,x=y,y=c; x=fa[a];a=top[x]; } s-=(dep[x]<dep[y]?len[x]:len[y])*2; return s; } inline int min(int a,int b){return a<b?a:b;} inline int max(int a,int b){return a>b?a:b;} inline void mins(int&a,int b){if(a>b)a=b;} inline void maxs(int&a,int b){if(a<b)a=b;} const int inf=0x3f3f3f3f; int mn[262144],mx[262144]; void ins(int x){ int w=x+131074; mn[w]=mx[w]=x; for(w>>=1;w;w>>=1){ mins(mn[w],x); maxs(mx[w],x); } } void del(int x){ int w=x+131074; mn[w]=inf; mx[w]=-inf; for(w>>=1;w;w>>=1){ int lc=w<<1,rc=lc^1; mn[w]=min(mn[lc],mn[rc]); mx[w]=max(mx[lc],mx[rc]); } } i64 cal(int x){ int pv=-inf,nx=inf; for(int w=x+131074;w!=1;w>>=1){ if(w&1)maxs(pv,mx[w^1]); else mins(nx,mn[w^1]); } if(nx==inf)nx=mn[1]; if(pv==-inf)pv=mx[1]; return dist(pv,x)+dist(nx,x)-dist(pv,nx); } int main(){ for(int i=1;i<262144;i++)mn[i]=inf,mx[i]=-inf; n=input();m=input(); for(int i=1,a,b,c;i<n;i++){ a=input();b=input();c=input(); es[ep]=b;enx[ep]=e0[a];ev[ep]=c;e0[a]=ep++; es[ep]=a;enx[ep]=e0[b];ev[ep]=c;e0[b]=ep++; } f1(1,0);f2(1,1); while(m--){ int w=input(); if(st[w]){ st[w]=0; ans-=cal(id[w]); del(id[w]); }else{ st[w]=1; ins(id[w]); ans+=cal(id[w]); } printf("%lld\n",ans); } return 0; }