BZOJ4285 : 使者

假设询问两点中d[x]<d[y]。

若x是y的祖先,那么就是求起点不在x到y方向的第一个点的子树中,且终点在y子树中的跳跃点个数。

若x不是y的祖先,那么就是求起点在x子树中,且终点在y子树中的跳跃点个数。

利用DFS序可以将转化成矩形,于是利用CDQ分治+扫描线+树状数组即可做到$O(n\log^2n)$。

 

#include<cstdio>
#include<algorithm>
#define N 200010
using namespace std;
int n,m,q,i,op,x,y,T,pos[N],bit[N],ans[N],cb,cc;
int g[N],v[N],nxt[N],ed,f[N],size[N],son[N],top[N],d[N],st[N],en[N],dfn;
struct P{int x,y,t;P(){}P(int _x,int _y,int _t){x=_x,y=_y,t=_t;}}a[N],b[N];
inline bool cmpb(const P&a,const P&b){return a.x<b.x;}
struct C{int x,l,r,t,p;C(){}C(int _x,int _l,int _r,int _t,int _p){x=_x,l=_l,r=_r,t=_t,p=_p;}}c[N];
inline bool cmpc(const C&a,const C&b){return a.x<b.x;}
inline void read(int&a){char c;while(!(((c=getchar())>='0')&&(c<='9')));a=c-'0';while(((c=getchar())>='0')&&(c<='9'))(a*=10)+=c-'0';}
inline void add(int x,int y){v[++ed]=y;nxt[ed]=g[x];g[x]=ed;}
void dfs(int x){
  size[x]=1;
  for(int i=g[x];i;i=nxt[i])if(v[i]!=f[x]){
    d[v[i]]=d[f[v[i]]=x]+1;dfs(v[i]);size[x]+=size[v[i]];
    if(size[v[i]]>size[son[x]])son[x]=v[i];
  }
}
void dfs2(int x,int y){
  st[x]=++dfn;top[x]=y;
  if(son[x])dfs2(son[x],y);
  for(int i=g[x];i;i=nxt[i])if(v[i]!=son[x]&&v[i]!=f[x])dfs2(v[i],v[i]);
  en[x]=dfn;
}
inline int lca2(int x,int y){
  int t;
  while(top[x]!=top[y])t=top[y],y=f[top[y]];
  return x==y?t:son[x];
}
inline void ins(int x,int y){for(;x<=n;x+=x&-x)if(pos[x]<T)pos[x]=T,bit[x]=y;else bit[x]+=y;}
inline int ask(int x){int t=0;for(;x;x-=x&-x)if(pos[x]==T)t+=bit[x];return t;}
void solve(int l,int r){
  if(l==r)return;
  int mid=(l+r)>>1;
  solve(l,mid),solve(mid+1,r);
  int i,j;
  cb=cc=0;
  for(i=l;i<=mid;i++){
    if(a[i].t<0){
      b[cb++]=P(a[i].x,a[i].y,-1);
      b[cb++]=P(a[i].y,a[i].x,-1);
    }
    if(!a[i].t){
      b[cb++]=P(a[i].x,a[i].y,1);
      b[cb++]=P(a[i].y,a[i].x,1);
    }
  }
  if(!cb)return;
  for(i=r;i>mid;i--){
    if(a[i].t==1){
      c[cc++]=C(n,st[a[i].y],en[a[i].y],1,i);
      c[cc++]=C(st[a[i].x]-1,st[a[i].y],en[a[i].y],1,i);
      c[cc++]=C(en[a[i].x],st[a[i].y],en[a[i].y],-1,i);
    }
    if(a[i].t==2){
      c[cc++]=C(st[a[i].x]-1,st[a[i].y],en[a[i].y],-1,i);
      c[cc++]=C(en[a[i].x],st[a[i].y],en[a[i].y],1,i);
    }
  }
  if(!cc)return;
  if(cb>1)sort(b,b+cb,cmpb);
  if(cc>1)sort(c,c+cc,cmpc);
  for(T++,i=j=0;i<cc;i++){
    while(j<cb&&b[j].x<=c[i].x)ins(b[j].y,b[j].t),j++;
    ans[c[i].p]+=c[i].t*(ask(c[i].r)-ask(c[i].l-1));
  }
}
int main(){
  read(n);
  for(i=1;i<n;i++)read(x),read(y),add(x,y),add(y,x);
  dfs(1),dfs2(1,1);
  read(q);
  while(q--)read(x),read(y),a[++m]=P(st[x],st[y],0);
  read(q);
  while(q--){
    read(op),read(x),read(y);
    if(op==1)a[++m]=P(st[x],st[y],0);
    if(op==2)a[++m]=P(st[x],st[y],-1);
    if(op==3){
      if(d[x]>d[y])swap(x,y);
      if(st[x]<=st[y]&&en[y]<=en[x])a[++m]=P(lca2(x,y),y,1);else a[++m]=P(x,y,2);
    }
  }
  solve(1,m);
  for(i=1;i<=m;i++)if(a[i].t>0)printf("%d\n",ans[i]);
  return 0;
}

  

posted @ 2015-10-01 01:25  Claris  阅读(384)  评论(0编辑  收藏  举报