BZOJ4285 : 使者
假设询问两点中d[x]<d[y]。
若x是y的祖先,那么就是求起点不在x到y方向的第一个点的子树中,且终点在y子树中的跳跃点个数。
若x不是y的祖先,那么就是求起点在x子树中,且终点在y子树中的跳跃点个数。
利用DFS序可以将转化成矩形,于是利用CDQ分治+扫描线+树状数组即可做到$O(n\log^2n)$。
#include<cstdio> #include<algorithm> #define N 200010 using namespace std; int n,m,q,i,op,x,y,T,pos[N],bit[N],ans[N],cb,cc; int g[N],v[N],nxt[N],ed,f[N],size[N],son[N],top[N],d[N],st[N],en[N],dfn; struct P{int x,y,t;P(){}P(int _x,int _y,int _t){x=_x,y=_y,t=_t;}}a[N],b[N]; inline bool cmpb(const P&a,const P&b){return a.x<b.x;} struct C{int x,l,r,t,p;C(){}C(int _x,int _l,int _r,int _t,int _p){x=_x,l=_l,r=_r,t=_t,p=_p;}}c[N]; inline bool cmpc(const C&a,const C&b){return a.x<b.x;} inline void read(int&a){char c;while(!(((c=getchar())>='0')&&(c<='9')));a=c-'0';while(((c=getchar())>='0')&&(c<='9'))(a*=10)+=c-'0';} inline void add(int x,int y){v[++ed]=y;nxt[ed]=g[x];g[x]=ed;} void dfs(int x){ size[x]=1; for(int i=g[x];i;i=nxt[i])if(v[i]!=f[x]){ d[v[i]]=d[f[v[i]]=x]+1;dfs(v[i]);size[x]+=size[v[i]]; if(size[v[i]]>size[son[x]])son[x]=v[i]; } } void dfs2(int x,int y){ st[x]=++dfn;top[x]=y; if(son[x])dfs2(son[x],y); for(int i=g[x];i;i=nxt[i])if(v[i]!=son[x]&&v[i]!=f[x])dfs2(v[i],v[i]); en[x]=dfn; } inline int lca2(int x,int y){ int t; while(top[x]!=top[y])t=top[y],y=f[top[y]]; return x==y?t:son[x]; } inline void ins(int x,int y){for(;x<=n;x+=x&-x)if(pos[x]<T)pos[x]=T,bit[x]=y;else bit[x]+=y;} inline int ask(int x){int t=0;for(;x;x-=x&-x)if(pos[x]==T)t+=bit[x];return t;} void solve(int l,int r){ if(l==r)return; int mid=(l+r)>>1; solve(l,mid),solve(mid+1,r); int i,j; cb=cc=0; for(i=l;i<=mid;i++){ if(a[i].t<0){ b[cb++]=P(a[i].x,a[i].y,-1); b[cb++]=P(a[i].y,a[i].x,-1); } if(!a[i].t){ b[cb++]=P(a[i].x,a[i].y,1); b[cb++]=P(a[i].y,a[i].x,1); } } if(!cb)return; for(i=r;i>mid;i--){ if(a[i].t==1){ c[cc++]=C(n,st[a[i].y],en[a[i].y],1,i); c[cc++]=C(st[a[i].x]-1,st[a[i].y],en[a[i].y],1,i); c[cc++]=C(en[a[i].x],st[a[i].y],en[a[i].y],-1,i); } if(a[i].t==2){ c[cc++]=C(st[a[i].x]-1,st[a[i].y],en[a[i].y],-1,i); c[cc++]=C(en[a[i].x],st[a[i].y],en[a[i].y],1,i); } } if(!cc)return; if(cb>1)sort(b,b+cb,cmpb); if(cc>1)sort(c,c+cc,cmpc); for(T++,i=j=0;i<cc;i++){ while(j<cb&&b[j].x<=c[i].x)ins(b[j].y,b[j].t),j++; ans[c[i].p]+=c[i].t*(ask(c[i].r)-ask(c[i].l-1)); } } int main(){ read(n); for(i=1;i<n;i++)read(x),read(y),add(x,y),add(y,x); dfs(1),dfs2(1,1); read(q); while(q--)read(x),read(y),a[++m]=P(st[x],st[y],0); read(q); while(q--){ read(op),read(x),read(y); if(op==1)a[++m]=P(st[x],st[y],0); if(op==2)a[++m]=P(st[x],st[y],-1); if(op==3){ if(d[x]>d[y])swap(x,y); if(st[x]<=st[y]&&en[y]<=en[x])a[++m]=P(lca2(x,y),y,1);else a[++m]=P(x,y,2); } } solve(1,m); for(i=1;i<=m;i++)if(a[i].t>0)printf("%d\n",ans[i]); return 0; }