BZOJ4539: [Hnoi2016]树
把每次复制出的子树（连同原模板树本身）各缩成一个点，在缩点后的树上倍增求 LCA；用主席树在模板树的 DFS 序上查询子树内第 k 小的编号，从而把大树节点编号定位到对应的模板节点。毫无技术含量，纯码农题。
// BZOJ 4539 / HNOI2016 "Tree".
//
// Input : N M Q, then N-1 edges of the template tree (rooted at 1), then M
//         lines "a b" (copy the template subtree of a and hang the copy under
//         big-tree label b), then Q lines "x y".
// Output: for every query, the distance between big-tree labels x and y.
//
// The template itself is block 1 (labels 1..N); copy number i becomes block
// i+1 and receives the next sz[a] consecutive labels, assigned to the copied
// template nodes in increasing order of their template label.  Each block is
// contracted into one node of a "block tree"; a big-tree label is mapped to
// (block, template node) with a persistent segment tree answering "k-th
// smallest template label inside a template subtree".
//
// Written in plain ISO C++17 (no GNU `typeof` / compound literals), and the
// template DFS is iterative so a path-shaped template cannot overflow the stack.
#include <cstdio>
#include <map>
#include <utility>
#include <vector>

namespace {

using ll = long long;

constexpr int kMaxN = 100000 + 5;  // capacity for template nodes and blocks
constexpr int kLog = 17;           // 2^17 > 1e5: covers depths and seg-tree height
constexpr int kSegPool = kMaxN * (kLog + 1) + 1;  // <= kLog + 1 nodes per insert, plus node 0

int templateSize;  // N
int blockCount;    // M + 1 (block 1 is the template itself)

std::vector<int> adj[kMaxN];  // template adjacency lists
int tin[kMaxN];               // 1-based preorder index of each template node
int byTin[kMaxN];             // inverse of tin
int sz[kMaxN];                // template subtree sizes
int dep[2][kMaxN];            // [0]: template depth, [1]: block-tree depth
int up[2][kLog][kMaxN];       // binary-lifting ancestors ([0] template, [1] blocks)

// Persistent segment tree over template labels 1..templateSize.
// Node 0 is the shared empty tree (children 0, count 0).
struct SegNode {
  int left, right, cnt;
};
SegNode seg[kSegPool];
int segUsed = 1;
int segRoot[kMaxN];  // segRoot[k] contains labels byTin[1..k]

std::map<ll, int> blockByLastLabel;  // largest label in a block -> block id
int blockRoot[kMaxN];                // template node whose subtree the block copies
int attachAt[kMaxN];                 // template node (inside parent block) the block hangs under
ll distRoot[kMaxN];                  // big-tree depth of the block's root node

// Where a big-tree label lives.
struct Location {
  int node;   // corresponding template node
  int root;   // template root of the containing block
  int block;  // containing block id
};

// Fills tin/byTin, sz, dep[0] and up[0][0] for the template rooted at 1.
void buildTemplate() {
  std::vector<int> pending{1};
  int timer = 0;
  while (!pending.empty()) {
    const int x = pending.back();
    pending.pop_back();
    tin[x] = ++timer;
    byTin[timer] = x;
    for (const int y : adj[x]) {
      if (y == up[0][0][x]) continue;  // skip the edge back to the parent
      up[0][0][y] = x;
      dep[0][y] = dep[0][x] + 1;
      pending.push_back(y);
    }
  }
  // Descendants receive larger preorder indices, so a reverse sweep finishes
  // every child before its parent.
  for (int t = timer; t >= 1; --t) {
    const int x = byTin[t];
    sz[x] += 1;
    if (const int parent = up[0][0][x]) sz[parent] += sz[x];
  }
}

// Returns the root of a new version equal to `prev` plus one copy of `pos`.
int segInsert(int prev, int pos) {
  int cur = segUsed++;
  const int newRoot = cur;
  seg[cur] = seg[prev];
  ++seg[cur].cnt;
  int lo = 1, hi = templateSize;
  while (lo < hi) {
    const int mid = (lo + hi) / 2;
    const int next = segUsed++;
    if (pos <= mid) {
      prev = seg[prev].left;
      seg[cur].left = next;
      hi = mid;
    } else {
      prev = seg[prev].right;
      seg[cur].right = next;
      lo = mid + 1;
    }
    seg[next] = seg[prev];
    ++seg[next].cnt;
    cur = next;
  }
  return newRoot;
}

// k-th smallest label present in version `newer` but not in version `older`.
int segKth(int older, int newer, int k) {
  int lo = 1, hi = templateSize;
  while (lo < hi) {
    const int mid = (lo + hi) / 2;
    const int leftCnt = seg[seg[newer].left].cnt - seg[seg[older].left].cnt;
    if (k <= leftCnt) {
      older = seg[older].left;
      newer = seg[newer].left;
      hi = mid;
    } else {
      k -= leftCnt;
      older = seg[older].right;
      newer = seg[newer].right;
      lo = mid + 1;
    }
  }
  return lo;
}

// Ancestor `k` levels above x in tree `t` (0 = template, 1 = block tree).
int climb(int t, int x, int k) {
  for (int j = 0; j < kLog; ++j)
    if (k >> j & 1) x = up[t][j][x];
  return x;
}

// Lowest common ancestor of a and b in tree `t` (0 = template, 1 = block tree).
int lca(int t, int a, int b) {
  if (dep[t][a] < dep[t][b]) std::swap(a, b);
  a = climb(t, a, dep[t][a] - dep[t][b]);
  if (a == b) return a;
  for (int j = kLog - 1; j >= 0; --j) {
    if (up[t][j][a] != up[t][j][b]) {
      a = up[t][j][a];
      b = up[t][j][b];
    }
  }
  return up[t][0][a];
}

// Maps a big-tree label to its block and template node.  Assumes the label
// already exists; otherwise lower_bound would return end().
Location locate(ll label) {
  const auto it = blockByLastLabel.lower_bound(label);
  const int block = it->second;
  const int r = blockRoot[block];
  // The block owns labels it->first - sz[r] + 1 .. it->first, in increasing
  // template-label order, so `label` is the k-th smallest label under r.
  const int k = static_cast<int>(label - it->first + sz[r]);
  const int node = segKth(segRoot[tin[r] - 1], segRoot[tin[r] + sz[r] - 1], k);
  return {node, r, block};
}

}  // namespace

int main() {
  int copies = 0;
  int queries = 0;
  if (std::scanf("%d%d%d", &templateSize, &copies, &queries) != 3) return 0;
  blockCount = copies + 1;

  for (int i = 1; i < templateSize; ++i) {
    int a = 0, b = 0;
    std::scanf("%d%d", &a, &b);
    adj[a].push_back(b);
    adj[b].push_back(a);
  }
  buildTemplate();
  for (int t = 1; t <= templateSize; ++t)
    segRoot[t] = segInsert(segRoot[t - 1], byTin[t]);

  // Block 1 is the template itself, holding labels 1..N.
  ll lastLabel = templateSize;
  blockByLastLabel[lastLabel] = 1;
  blockRoot[1] = 1;
  for (int blk = 2; blk <= blockCount; ++blk) {
    int a = 0;
    ll b = 0;
    std::scanf("%d%lld", &a, &b);
    const Location at = locate(b);
    up[1][0][blk] = at.block;
    dep[1][blk] = dep[1][at.block] + 1;
    attachAt[blk] = at.node;
    // One edge below the attach node, which itself sits below its block root.
    distRoot[blk] = distRoot[at.block] + (dep[0][at.node] - dep[0][at.root]) + 1;
    lastLabel += sz[a];
    blockByLastLabel[lastLabel] = blk;
    blockRoot[blk] = a;
  }

  for (int j = 1; j < kLog; ++j) {
    for (int x = 1; x <= templateSize; ++x) up[0][j][x] = up[0][j - 1][up[0][j - 1][x]];
    for (int x = 1; x <= blockCount; ++x) up[1][j][x] = up[1][j - 1][up[1][j - 1][x]];
  }

  while (queries-- > 0) {
    ll x = 0, y = 0;
    std::scanf("%lld%lld", &x, &y);
    const Location s = locate(x);
    const Location t = locate(y);
    const int meet = lca(1, s.block, t.block);
    ll dist = 0;
    // Moves an endpoint up to the template node of block `meet` where its
    // branch enters, adding the distance travelled; returns that node.
    auto enterMeet = [&](const Location& p) {
      if (p.block == meet) return p.node;
      const int child = climb(1, p.block, dep[1][p.block] - dep[1][meet] - 1);
      dist += (dep[0][p.node] - dep[0][p.root]) + (distRoot[p.block] - distRoot[child]) + 1;
      return attachAt[child];
    };
    const int sNode = enterMeet(s);
    const int tNode = enterMeet(t);
    // Inside block `meet` distances equal template distances.
    const int l = lca(0, sNode, tNode);
    dist += dep[0][sNode] + dep[0][tNode] - 2 * dep[0][l];
    std::printf("%lld\n", dist);
  }
  return 0;
}