// BZOJ4539: [HNOI2016] 树 (Tree)

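// 复制的树缩点,主席树查k小,毫无技术含量,纯码农题。
// (Contract each copied subtree into a single node of a "tree of copies", and
//  use a persistent segment tree (主席树) to find the k-th smallest label in a
//  template subtree. No real trick involved - purely an implementation exercise.)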

#include<bits/stdc++.h>
// Shorthand macros.
// WARNING: `u` and `v` are rewritten to `first` and `second` everywhere below.
// Every identifier spelled u/v in this file (function parameters, the global
// scratch variables, the `edge` member) is therefore really first/second, and
// `x.u` / `x.v` on a pair read its first/second member.
#define u first
#define v second
// F: shorthand for lower_bound (used as nu.F(...) on the std::map below).
#define F lower_bound
// For a segment [i, j]: J = (i+j)/2 is the last index of the left half and
// I = J+1 is the first index of the right half (so I-1 == J).
#define I (i+j+2>>1)
#define J (i+j>>1)
using namespace std;
// n1: number of template-tree nodes.
// n2: number of copies. main() increments it so that copy 1 is the initial tree
//     and copies 2..n2 come from the copy operations.
// m:  number of distance queries.
// n4: DFS timestamp counter advanced by dfs().
int n1,n2,m,n4;
typedef long long ll;
// nu: last big-tree label of each copy -> copy index. Each copy owns a
// contiguous label range, so lower_bound(label) finds the owning copy.
map<ll,int>nu;
const int N=1e5+5;
// Template-tree adjacency-list edge: target node v and next edge s in the same list.
// NOTE: because of `#define v second`, the member is actually named `second`.
struct edge{
	int v;edge*s;
}z[N*2];
// z: static edge pool (two directed entries per undirected edge);
// a: next free slot in z; h[x]: head of node x's adjacency list (null if empty).
edge*a=z,*h[N];
void ins(int u,int v){
	edge s={v,h[u]};
	*(h[u]=a++)=s;
}
typedef int arr[N];
// Template tree (rooted at 1):
//   l[x] = DFS entry time of x, id[t] = node whose entry time is t,
//   r[x] = subtree size of x, d[0][x] = depth, p[0][j][x] = 2^j-th ancestor.
// Tree of copies (copy 1 is the root):
//   d[1][k] = depth, p[1][j][k] = 2^j-th ancestor,
//   po[k] = template root of copy k,
//   b[k] = template node (inside the parent copy) that copy k hangs from.
arr l,r,b,po,id,d[2],p[2][17];
// n3: running total of big-tree labels handed out so far.
// u, v: input scratch values (renamed to first/second by the macros).
// c[k]: depth in the big tree of the root of copy k (c[1] == 0).
ll n3,u,v,c[N];
void dfs(int u){
	r[id[l[u]=++n4]=u]=1;
	for(edge*i=h[u];i;i=i->s)
		if(i->v^p[0][0][u]){
			d[0][i->v]=d[0][p[0][0][i->v]=u]+1;
			dfs(i->v);
			r[u]+=r[i->v];
		}
}
// Node of the persistent segment tree over template labels [1, n1].
//   i / j: left / right child;
//   s:     number of inserted values in the LEFT subtree only. The insert
//          routine bumps it only when descending left, and the k-th query
//          compares k against the difference of two versions' s.
typedef struct node*ptr;
// e[k] is the pool row for version k: its root is e[k][0], followed by the
// other freshly copied path nodes. Version k contains the labels id[1..k].
// e[0][0] is set up in main() as the empty tree.
struct node{
	ptr i,j;int s;
}e[N][17];
void ins(int i,int j,int s,ptr u,ptr v){
	while(i<j){
		*v=*u;
		if(s>J)u=u->j,v=v->j=v+1,i=J+1;
		else
			++v->s,u=u->i,v=v->i=v+1,j=I-1;
	}
}
int ask(int i,int j,int k,ptr u,ptr v){
	while(i<j){
		int s=v->s-u->s;
		if(k<=s)u=u->i,v=v->i,j=I-1;
		else
			k-=s,u=u->j,v=v->j,i=J+1;
	}
	return i;
}
int lca(int i,int s,int t){
	if(d[i][s]<d[i][t])swap(s,t);
	int k=d[i][s]-d[i][t];
	for(int j=16;~j;--j)
		if(k>>j&1)s=p[i][j][s];
	if(s==t)return s;
	for(int j=16;~j;--j)
		if(p[i][j][s]^p[i][j][t])
			s=p[i][j][s],t=p[i][j][t];
	return p[i][0][s];
}
// vec: (template node, template root of the copy it lies in).
// tri: (vec, index of that copy).
typedef pair<int,int>vec;
typedef pair<vec,int>tri;
// Resolve big-tree label v to ((template node, template root of its copy), copy index).
// nu maps the last label of each copy to the copy's index. Because copies own
// contiguous label ranges, the first key >= v identifies the copy holding v.
// Inside that copy (template root s, r[s] nodes) v is the k-th smallest label,
// with k = v - (last - r[s]). The persistent segment tree restricted to s's
// DFS range [l[s], l[s]+r[s]-1] maps that rank to the template node.
// Assumes 1 <= v <= total number of labels (valid input). Otherwise the
// lower_bound result is end() and must not be dereferenced.
// Uses std::map<...>::iterator instead of the GNU-only `typeof`, which is
// unavailable when compiling with strict -std=c++NN.
tri ask(ll v){
	std::map<ll,int>::iterator it=nu.lower_bound(v);
	const int blk=it->second;
	const int s=po[blk];
	const int k=static_cast<int>(v-it->first+r[s]);
	const int tnode=ask(1,n1,k,e[l[s]-1],e[l[s]+r[s]-1]);
	return tri(vec(tnode,s),blk);
}
// Return the k-th ancestor of copy s in the tree of copies, by binary lifting
// over p[1] (one jump per set bit of k, bits 0..16).
int ask(int s,int k){
	int cur=s;
	for(int bit=0;bit<17;++bit)
		if(k&(1<<bit))
			cur=p[1][bit][cur];
	return cur;
}
// Reads the template tree, the copy operations and the queries from stdin,
// and prints one big-tree distance per query.
int main(){
	// n1 template nodes, n2 copy operations, m queries. Copy 1 is the initial
	// big tree, so after ++n2 the operations become copies 2..n2.
	scanf("%d%d%d",&n1,&n2,&m),++n2;
	// n1-1 undirected template edges.
	for(int i=2;i<=n1;++i)
		scanf("%lld%lld",&u,&v),ins(u,v),ins(v,u);
	// Copy 1 is the whole template (root 1) and owns labels 1..n1.
	dfs(po[nu[n3=n1]=1]=1);
	// Version 0: empty tree whose children point back to itself.
	// NOTE(review): `(node){...}` is a GNU compound literal, not ISO C++.
	e[0][0]=(node){e[0],e[0]};
	// Version i adds the template node with DFS time i, so the nodes of a
	// subtree are exactly the values added between versions l[x]-1 and l[x]+r[x]-1.
	for(int i=1;i<=n1;++i)
		ins(1,n1,id[i],e[i-1],e[i]);
	// Each operation reads a template node u (root of the copied subtree) and
	// a big-tree label v (attach point).
	for(int i=2;i<=n2;++i){
		scanf("%lld%lld",&u,&v);
		tri s=ask(v);
		// Record, in order: the parent copy and depth in the tree of copies;
		// the attach node b[i] and the big-tree depth c[i] of the new root
		// (depth of the attach label + 1); and the label range of the new copy,
		// which ends at n3 + r[u].
		d[1][i]=d[1][p[1][0][i]=s.v]+1,c[i]=c[s.v]+d[0][b[i]=s.u.u]-d[0][s.u.v]+1,po[nu[n3+=r[u]]=i]=u;
	}
	// Binary-lifting tables for the template tree (0) and the tree of copies (1).
	for(int i=1;i<17;++i){
		for(int j=1;j<=n1;++j)
			p[0][i][j]=p[0][i-1][p[0][i-1][j]];
		for(int j=2;j<=n2;++j)
			p[1][i][j]=p[1][i-1][p[1][i-1][j]];
	}
	while(m--){
		scanf("%lld%lld",&u,&v);
		tri s1=ask(u),t1=ask(v);
		// Copy in which the two big-tree paths meet.
		int l1=lca(1,s1.v,t1.v);
		int s2=s1.u.u,t2=t1.u.u;
		ll l3=0;
		// If u's copy lies below l1, climb to s3, the child copy of l1 on the
		// path. Add the distance from u to s3's root, plus 1 for the edge into
		// l1, then continue from s3's attach node inside l1.
		if(s1.v^l1){
			int s3=ask(s1.v,d[1][s1.v]-d[1][l1]-1);
			l3+=d[0][s2]-d[0][s1.u.v]+c[s1.v]-c[s3]+1,s2=b[s3];
		}
		// Same climb for v's side.
		if(t1.v^l1){
			int t3=ask(t1.v,d[1][t1.v]-d[1][l1]-1);
			l3+=d[0][t2]-d[0][t1.u.v]+c[t1.v]-c[t3]+1,t2=b[t3];
		}
		// Remaining distance inside copy l1, measured on the template tree.
		int l2=lca(0,s2,t2);
		l3+=d[0][s2]+d[0][t2]-d[0][l2]*2;
		printf("%lld\n",l3);
	}
}
// posted @ 2016-12-04 23:57  f321dd  阅读(119)  评论(0)  编辑  收藏  举报