【Luogu】P3320寻宝游戏(Splay)
其实这题直接用 std::set 维护按 DFS 序排序的点集就够了，但我当时不会用 set。
智商-=inf
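所求即为虚树（连接所有标记点的最小连通子树）上所有边权之和的两倍。
做法：用一个按 DFS 序排序的有序集合（平衡树，如 Splay）维护当前所有标记点。设标记点按 DFS 序排好后为 $a_1, a_2, \dots, a_k$，则答案为 $\sum_{i=1}^{k-1} \operatorname{dist}(a_i, a_{i+1}) + \operatorname{dist}(a_k, a_1)$。插入点 $x$ 时，设其在集合中的前驱为 $p$、后继为 $s$，则前 $k-1$ 项之和增加 $\operatorname{dist}(p, x) + \operatorname{dist}(x, s) - \operatorname{dist}(p, s)$（前驱或后继不存在时只加存在的那一项，也不减最后一项）；删除时减去同样的量。每次输出时再加上首尾两点的距离 $\operatorname{dist}(a_k, a_1)$ 即可。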
// Luogu P3320 [SDOI2015] treasure hunt.
//
// Vertices of a weighted tree are toggled between marked and unmarked. After
// each toggle, print twice the total edge weight of the minimal subtree that
// spans all marked vertices (0 when fewer than two vertices are marked).
//
// Identity used: if the marked vertices sorted by DFS preorder are
// a_1, ..., a_k, the answer equals
//     dist(a_1,a_2) + ... + dist(a_{k-1},a_k) + dist(a_k,a_1).
// `chain` maintains the first k-1 terms incrementally; the closing term
// dist(a_k, a_1) is added when printing.
//
// The marked set only needs ordered insert/erase plus predecessor/successor
// lookups, so it is stored as a std::set of DFS indices.
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cctype>
#include<cstdlib>
#include<map>
#include<set>
#include<vector>

namespace {

const int kMaxN = 200020;  // capacity for 1-based vertex ids
const int kLog = 21;       // binary-lifting levels (2^20 > kMaxN)

// Reads one optionally negative decimal integer from stdin, skipping any
// non-digit characters before it. Returns 0 if EOF is reached first.
// `ch` is an int so EOF and bytes >= 0x80 are passed to isdigit safely.
long long readInt() {
    long long num = 0;
    long long sign = 1;
    int ch = std::getchar();
    while (ch != EOF && !std::isdigit(ch)) {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    while (ch != EOF && std::isdigit(ch)) {
        num = num * 10 + (ch - '0');
        ch = std::getchar();
    }
    return num * sign;
}

// Adjacency lists stored as singly linked edge chains (index 0 = end).
struct Edge {
    int next;
    int to;
    long long val;
};
Edge edge[kMaxN * 2];
int head[kMaxN];
int edgeCount;

// Prepends a directed edge from -> to with weight `val`.
void addEdge(int from, int to, long long val) {
    Edge e = {head[from], to, val};
    edge[++edgeCount] = e;
    head[from] = edgeCount;
}

int up[kMaxN][kLog];        // up[v][j]: 2^j-th ancestor of v, 0 above the root
int depth[kMaxN];           // root has depth 1; vertex 0 has depth 0
long long distRoot[kMaxN];  // weighted distance from the root
int dfn[kMaxN];             // DFS preorder index of a vertex (1-based)
int byDfn[kMaxN];           // inverse of dfn: preorder index -> vertex
bool marked[kMaxN];         // current marked state of each vertex

// Fills dfn/byDfn, depth, distRoot and direct parents (up[v][0]) for the
// tree rooted at `root`, visiting children in adjacency-list order. It is
// iterative so a path-shaped tree cannot overflow the call stack.
void dfs(int root, int n) {
    std::vector<int> nextEdge(head, head + n + 1);  // per-vertex edge cursor
    std::vector<int> path(1, root);                 // current root-to-vertex path
    int timer = 0;
    up[root][0] = 0;
    depth[root] = 1;
    distRoot[root] = 0;
    dfn[root] = ++timer;
    byDfn[timer] = root;
    while (!path.empty()) {
        const int x = path.back();
        int &cursor = nextEdge[x];
        // Never walk back to the parent (the root's parent is 0, not a vertex).
        while (cursor != 0 && edge[cursor].to == up[x][0]) cursor = edge[cursor].next;
        if (cursor == 0) {
            path.pop_back();
            continue;
        }
        const int child = edge[cursor].to;
        const long long weight = edge[cursor].val;
        cursor = edge[cursor].next;
        up[child][0] = x;
        depth[child] = depth[x] + 1;
        distRoot[child] = distRoot[x] + weight;
        dfn[child] = ++timer;
        byDfn[timer] = child;
        path.push_back(child);
    }
}

// Lowest common ancestor of x and y via binary lifting.
int lca(int x, int y) {
    if (depth[x] < depth[y]) std::swap(x, y);
    for (int j = 0, diff = depth[x] - depth[y]; diff > 0; ++j, diff >>= 1) {
        if (diff & 1) x = up[x][j];
    }
    if (x == y) return x;
    for (int j = kLog - 1; j >= 0; --j) {
        if (up[x][j] != up[y][j]) {
            x = up[x][j];
            y = up[y][j];
        }
    }
    return up[x][0];
}

// Weighted tree distance between vertices x and y.
long long treeDist(int x, int y) {
    return distRoot[x] + distRoot[y] - 2 * distRoot[lca(x, y)];
}

// Amount by which `chain` grows when vertex v is spliced between its
// DFS-order neighbours in `active`. dfn[v] must not currently be in `active`.
long long spliceGain(const std::set<int> &active, int v) {
    std::set<int>::const_iterator succ = active.lower_bound(dfn[v]);
    long long gain = 0;
    int predV = 0;
    int succV = 0;
    if (succ != active.end()) {
        succV = byDfn[*succ];
        gain += treeDist(succV, v);
    }
    if (succ != active.begin()) {
        std::set<int>::const_iterator pred = succ;
        --pred;
        predV = byDfn[*pred];
        gain += treeDist(predV, v);
    }
    // The pred-succ link is replaced by pred-v-succ.
    if (predV != 0 && succV != 0) gain -= treeDist(predV, succV);
    return gain;
}

}  // namespace

int main() {
    const int n = static_cast<int>(readInt());
    int m = static_cast<int>(readInt());
    for (int i = 1; i < n; ++i) {
        const int from = static_cast<int>(readInt());
        const int to = static_cast<int>(readInt());
        const long long val = readInt();
        addEdge(from, to, val);
        addEdge(to, from, val);
    }

    dfs(1, n);
    for (int j = 1; j < kLog; ++j) {
        for (int v = 1; v <= n; ++v) up[v][j] = up[up[v][j - 1]][j - 1];
    }

    std::set<int> active;  // DFS indices of the marked vertices
    long long chain = 0;   // sum of distances between DFS-order neighbours
    while (m-- > 0) {
        const int v = static_cast<int>(readInt());
        if (!marked[v]) {
            marked[v] = true;
            chain += spliceGain(active, v);
            active.insert(dfn[v]);
        } else {
            marked[v] = false;
            active.erase(dfn[v]);
            chain -= spliceGain(active, v);
        }

        long long closing = 0;  // dist(last, first) closes the cyclic tour
        if (!active.empty()) {
            closing = treeDist(byDfn[*active.rbegin()], byDfn[*active.begin()]);
        }
        std::printf("%lld\n", chain + closing);
    }
    return 0;
}