用dfs序把询问表示成询问dfs序的两个区间中的信息
拆成至多9个询问(询问dfs序的两个前缀),对这些询问用莫队处理,时间复杂度$O(n\sqrt{m})$
#include<bits/stdc++.h> typedef long long i64; const int N=1e5+77; char buf[N],*ptr=buf+100000,ob[N],*op=ob; int G(){ if(ptr-buf==100000)fread(ptr=buf,1,100000,stdin); return *ptr++; } int _(){ int x=0; if(ptr-buf<99900){ while(*ptr<48)++ptr; while(*ptr>47)x=x*10+*ptr++-48; }else{ int c=G(); while(c<48)c=G(); while(c>47)x=x*10+c-48,c=G(); } return x; } #define fl fwrite(ob,1,op-ob,stdout),op=ob void pr(i64 x){ if(op-ob>100000)fl; int ss[25],sp=0; if(!x)*op++=48; while(x)ss[++sp]=x%10,x/=10; while(sp)*op++=ss[sp--]+48; *op++=10; } i64 ans[N*5]; int n,m,rt=1; int v[N],vs[N],qc=0; int es[N*2],enx[N*2],e0[N],ep=2; int fa[N],sz[N],son[N],dep[N],top[N],id[N][2],vi[N],idp=0; void f1(int w,int pa){ dep[w]=dep[fa[w]=pa]+1; sz[w]=1; for(int i=e0[w];i;i=enx[i]){ int u=es[i]; if(u==pa)continue; f1(u,w); sz[w]+=sz[u]; if(sz[u]>sz[son[w]])son[w]=u; } } void f2(int w,int tp){ id[w][0]=++idp; vi[idp]=v[w]; top[w]=tp; if(son[w])f2(son[w],tp); for(int i=e0[w];i;i=enx[i]){ int u=es[i]; if(u!=fa[w]&&u!=son[w])f2(u,u); } id[w][1]=idp; } int up(int x,int y){ int a=top[x],b=top[y]; while(a!=b){ x=fa[a]; if(x==y)return a; a=top[x]; } return son[y]; } bool chk(int w){ return id[w][0]<id[rt][0]&&id[rt][0]<=id[w][1]; } int pos[N],B,qp=0; struct Q{ int l,r,sgn,id; }qs[N*20],qs1[N*20],*ls[N],*lp; int tr[N],tb[N]; i64 s0[N],_ans; int ts[N][2]; inline void inc0(int x){++ts[x][0],_ans+=ts[x][1];} inline void dec0(int x){--ts[x][0],_ans-=ts[x][1];} inline void inc1(int x){++ts[x][1],_ans+=ts[x][0];} inline void dec1(int x){--ts[x][1],_ans-=ts[x][0];} void cal(int w,int*a,int&p){ if(w==rt)a[0]=n,p=1; else if(id[w][0]<id[rt][0]&&id[rt][0]<=id[w][1]){ w=up(rt,w); a[0]=n; a[1]=-id[w][1]; a[2]=id[w][0]-1; p=3; }else{ a[0]=id[w][1]; a[1]=1-id[w][0]; p=2; } } void ins(int a,int b,int id){ if(!(a&&b))return; int c=1; if(a<0)a=-a,c=-c; if(b<0)b=-b,c=-c; if(a>b)std::swap(a,b); if(b==n)ans[id]+=c*s0[a]; else qs[qp++]=(Q){a,b,c,id}; } int main(){ n=_();m=_(); for(int i=1;i<=n;++i)v[i]=vs[i]=_(); std::sort(vs+1,vs+n+1); for(int i=1;i<=n;++i)v[i]=std::lower_bound(vs+1,vs+n+1,v[i])-vs; for(int i=1,a,b;i<n;++i){ a=_(),b=_(); es[ep]=b;enx[ep]=e0[a];e0[a]=ep++; es[ep]=a;enx[ep]=e0[b];e0[b]=ep++; } f1(1,0); f2(1,1); for(int i=1;i<=n;++i)inc0(vi[i]); for(int i=1;i<=n;++i)inc1(vi[i]),s0[i]=_ans; _ans=0; memset(ts,0,sizeof(ts)); for(int i=1;i<=m;++i){ if(_()==1)rt=_(); else{ ++qc; int x=_(),y=_(); int xv[4],xp,yv[4],yp; cal(x,xv,xp); cal(y,yv,yp); for(int a=0;a<xp;++a) for(int b=0;b<yp;++b)ins(xv[a],yv[b],qc); } } B=n/sqrt(qp+1)*2+1; for(int i=1;i<=n;++i)pos[i]=i/B; for(int i=0;i<qp;++i)++tr[qs[i].r],++tb[pos[qs[i].l]]; lp=qs1; for(int i=1;i<=n;++i)ls[i]=lp,lp+=tr[i]; for(int i=0;i<qp;++i)*ls[qs[i].r]++=qs[i]; lp=qs; for(int i=0;i<=pos[n];++i)ls[i]=lp,lp+=tb[i]; for(int i=0;i<qp;++i)*ls[pos[qs1[i].l]]++=qs1[i]; for(int i=0;i<pos[n];i+=2)std::reverse(ls[i],ls[i+1]); int L=0,R=0; for(int i=0;i<qp;++i){ int l=qs[i].l,r=qs[i].r; while(L<l)inc0(vi[++L]); while(L>l)dec0(vi[L--]); while(R<r)inc1(vi[++R]); while(R>r)dec1(vi[R--]); ans[qs[i].id]+=qs[i].sgn*_ans; } for(int i=1;i<=qc;++i)pr(ans[i]); return fl,0; }