BZOJ4539 [Hnoi2016]树 【倍增 + 主席树】

题目链接

BZOJ4539

题解

我们把每次复制出来的树看做一个点,那么大树实际上也就是一棵\(O(M)\)个点的树
所以我们只需求两遍树上距离:
大树上求距离,进入同一个点后在模板树上再求一次距离
讨论好一些情况即可

然后,求模板树某棵子树内编号第\(k\)小的点要用主席树(按 dfs 序建可持久化线段树,用子树对应的 dfs 序区间两端的版本相减)

没了

#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<map>
// Iterate over the adjacency list of u. Expects the arrays `h` and `ed` to be
// in scope; declares `k` (current edge index) and `to` (scratch int that the
// loop body may assign the neighbour to). The list ends at edge index 0.
#define Redge(u) for (int k = h[u],to; k; k = ed[k].nxt)
// Loop i over 1..n inclusive.
#define REP(i,n) for (int i = 1; i <= (n); i++)
// Build a pair<int,int> from two values. The pair is constructed directly
// because make_pair<int,int>(a,b) only accepts rvalues since C++11 and so
// fails to compile when a or b is a named variable.
#define mp(a,b) pair<int,int>((a),(b))
// Zero-fill a whole array (sizeof must see the array type, not a pointer).
#define cls(s) memset(s,0,sizeof(s))
// Shorthand type names.
#define cp pair<int,int>
#define LL long long int
using namespace std;
// maxn bounds the per-node / per-block arrays, maxm bounds the persistent
// segment tree node pool; INF is a generic large sentinel.
const int maxn = 100005,maxm = 6000005,INF = 1000000000;
// Reads the next decimal integer from stdin and returns it.
// Every non-digit character before the number is skipped; if any of the
// skipped characters is '-', the result is negated. The character that
// terminates the digit run is consumed.
// Returns 0 if end of input is reached before any digit is found (the
// character is kept in an int so that EOF can be told apart from data and the
// skip loop cannot spin forever). The return type is the same type as the
// file's LL macro (long long int).
inline long long read(){
	long long out = 0,flag = 1;
	int c = getchar();
	while (c < '0' || c > '9'){
		if (c == EOF) return 0;
		if (c == '-') flag = -1;
		c = getchar();
	}
	while (c >= '0' && c <= '9'){
		out = out * 10 + (c - '0');
		c = getchar();
	}
	return out * flag;
}
// S[i]: running total of block sizes up to block i (Build sets S[1] = n and
//       S[N] = S[N - 1] + siz[u]); lower_bound over S maps a big-tree node id
//       to the block that contains it.
// R[i]: template-tree node stored as the root of block i.
LL S[maxn],R[maxn];
// n, m, Q: the three integers read first in main.
// bin[i]: 1 << i, filled in main for i = 0..25.
int n,m,Q,bin[50];
// h[u]: index of the first edge in u's adjacency list (0 means empty).
// ne:   index of the last edge slot handed out by build().
int h[maxn],ne = 1;
// Edge pool for the template tree: `to` is the far endpoint, `nxt` the next
// edge index in the same adjacency list.
struct EDGE{int to,nxt;}ed[maxn << 1];
// Adds the undirected edge u-v to the template tree by pushing one directed
// edge onto the front of each endpoint's adjacency list.
// The edge slots are filled member by member instead of through a
// `(EDGE){...}` compound literal, which is a GNU extension and not standard C++.
inline void build(int u,int v){
	// u -> v
	++ne;
	ed[ne].to = v;
	ed[ne].nxt = h[u];
	h[u] = ne;
	// v -> u
	++ne;
	ed[ne].to = u;
	ed[ne].nxt = h[v];
	h[v] = ne;
}
// Persistent segment tree node pool: sum[x] is the number of values stored
// below node x, ls[x] / rs[x] are its left / right children (node 0 acts as
// the empty tree). rt[i] is the root of version i, tot the last allocated node.
int sum[maxm],ls[maxm],rs[maxm],rt[maxn],tot;
// Creates a new version of the persistent segment tree that contains one more
// occurrence of `pos` than version `pre`.
//   u      - receives the root of the new version
//   pre    - root of the version being extended (0 = empty tree)
//   [l, r] - value range covered by the root
// Only the nodes on the root-to-leaf path of `pos` are allocated; every other
// child pointer is shared with the previous version.
void modify(int& u,int pre,int l,int r,int pos){
	int* slot = &u;               // where the id of the next fresh node goes
	for (;;){
		const int cur = ++tot;
		*slot = cur;
		sum[cur] = sum[pre] + 1;
		ls[cur] = ls[pre];
		rs[cur] = rs[pre];
		if (l == r) break;
		const int mid = (l + r) >> 1;
		if (pos <= mid){
			slot = &ls[cur];
			pre = ls[pre];
			r = mid;
		}
		else {
			slot = &rs[cur];
			pre = rs[pre];
			l = mid + 1;
		}
	}
}
// Returns the k-th smallest value among those present in version `u` but not
// in version `v` of the persistent segment tree.
//   [l, r] - value range covered by both roots
//   k      - 1-based rank inside the difference of the two versions
int query(int u,int v,int l,int r,LL k){
	while (l != r){
		const int mid = (l + r) >> 1;
		// number of values in the left half that belong to the difference
		const int leftCount = sum[ls[u]] - sum[ls[v]];
		if (k <= leftCount){
			u = ls[u];
			v = ls[v];
			r = mid;
		}
		else {
			k -= leftCount;
			u = rs[u];
			v = rs[v];
			l = mid + 1;
		}
	}
	return l;
}
// Template-tree data filled by dfs(): dfn[u] is the preorder index of u,
// siz[u] its subtree size, dep[u] its depth, Fa[u][i] its 2^i-th ancestor
// (0 when it does not exist); cnt is the preorder counter.
int dfn[maxn],siz[maxn],dep[maxn],Fa[maxn][18],cnt;
void dfs(int u){
	dfn[u] = ++cnt; siz[u] = 1;
	modify(rt[cnt],rt[cnt - 1],1,n,u);
	REP(i,17) Fa[u][i] = Fa[Fa[u][i - 1]][i - 1];
	Redge(u) if ((to = ed[k].to) != Fa[u][0]){
		Fa[to][0] = u; dep[to] = dep[u] + 1;
		dfs(to);
		siz[u] += siz[to];
	}
}
// Child lists of the block tree (one vertex per copied subtree): hh[b] is the
// first list entry of block b, Nxt / To form the linked list, nne is the last
// used entry.
int hh[maxn],nne = 1,Nxt[maxn << 1],To[maxn << 1];
// fa[i][j]: 2^j-th ancestor block of block i; N: number of blocks so far;
// lk[i]: template node in the parent block that block i hangs from;
// Dep[i]: depth of block i in the block tree.
int fa[maxn][18],N,lk[maxn],Dep[maxn];
// d[i][j]: accumulated length charged for the jump from block i to fa[i][j]
// (Build sets d[i][0] = 1 + dep[lk[i]] - dep[root of the parent block]).
LL d[maxn][18];
// Walks the block tree below block u and sets Dep[child] = Dep[parent] + 1
// for every block reached.
void DFS(int u){
	int e = hh[u];
	while (e){
		const int child = To[e];
		Dep[child] = Dep[u] + 1;
		DFS(child);
		e = Nxt[e];
	}
}
void Build(){
	LL u,b,x,r,v;
	S[1] = n; R[1] = 1; N = 1;
	for (int i = 1; i <= m; i++){
		u = read(); v = read();
		b = lower_bound(S + 1,S + 1 + N,v) - S; r = R[b];
		x = query(rt[dfn[r] + siz[r] - 1],rt[dfn[r] - 1],1,n,v - S[b - 1]);
		N++;
		fa[N][0] = b; d[N][0] = 1 + dep[x] - dep[r];
		S[N] = S[N - 1] + siz[u]; R[N] = u; lk[N] = x;
		nne++;
		Nxt[nne] = hh[b]; To[nne] = N; hh[b] = nne;
	}
	DFS(1);
	//REP(i,N) printf("block %d   rt = %lld lk = %d  d = %lld Dep = %d  total = %lld\n",i,R[i],lk[i],d[i][0],Dep[i],S[i]);
	REP(j,17) REP(i,N){
		fa[i][j] = fa[fa[i][j - 1]][j - 1];
		d[i][j] = d[i][j - 1] + d[fa[i][j - 1]][j - 1];
	}
}
// Number of edges on the path between template-tree nodes u and v, found by
// binary lifting: the deeper node is raised to the other's depth, then both
// are raised together to just below their lowest common ancestor.
LL dis(int u,int v){
	if (dep[u] < dep[v]) std::swap(u,v);
	LL steps = 0;
	const int gap = dep[u] - dep[v];
	for (int j = 0; bin[j] <= gap; j++){
		if (!(gap & bin[j])) continue;
		steps += bin[j];
		u = Fa[u][j];
	}
	if (u == v) return steps;
	for (int j = 17; j >= 0; j--){
		if (Fa[u][j] == Fa[v][j]) continue;
		u = Fa[u][j];
		v = Fa[v][j];
		steps += bin[j + 1];   // both nodes moved 2^j edges
	}
	// u and v are now distinct children of the common ancestor.
	return steps + 2;
}
LL Dis(LL a,LL b,LL x,LL y){
	if (Dep[a] < Dep[b]){
		swap(a,b);
		swap(x,y);
	}
	LL re = dep[x] - dep[R[a]] + dep[y] - dep[R[b]];
	int D = Dep[a] - Dep[b],u = a,v = b;
	for (int i = 0; bin[i] <= D; i++)
		if (D & bin[i]){
			re += d[u][i];
			u = fa[u][i];
		}
	if (u == v){
		D = Dep[a] - Dep[b] - 1; u = a;
		re = dep[x] - dep[R[a]];
		for (int i = 0; bin[i] <= D; i++)
			if (D & bin[i]){
				re += d[u][i];
				u = fa[u][i];
			}
		re += 1;
		x = lk[u];
		re += dis(x,y);
	}
	else {
		for (int i = 17; ~i; i--)
			if (fa[u][i] != fa[v][i]){
				re += d[u][i] + d[v][i];
				u = fa[u][i];
				v = fa[v][i];
			}
		re += 2;
		x = lk[u]; y = lk[v];
		re += dis(x,y);
	}
	return re;
}
void solve(){
	LL u,v,a,b,x,y;
	while (Q--){
		u = read(); v = read();
		a = lower_bound(S + 1,S + 1 + N,u) - S;
		b = lower_bound(S + 1,S + 1 + N,v) - S;
		x = query(rt[dfn[R[a]] + siz[R[a]] - 1],rt[dfn[R[a]] - 1],1,n,u - S[a - 1]);
		y = query(rt[dfn[R[b]] + siz[R[b]] - 1],rt[dfn[R[b]] - 1],1,n,v - S[b - 1]);
		//printf("(%lld,%lld)   (%lld,%lld)\n",a,x,b,y);
		if (a == b) printf("%lld\n",dis(x,y));
		else printf("%lld\n",Dis(a,b,x,y));
	}
}
// Entry point: fills the power-of-two table, reads the template tree, runs the
// preprocessing passes and answers the queries.
int main(){
	bin[0] = 1;
	for (int i = 1; i <= 25; i++) bin[i] = bin[i - 1] << 1;
	n = read(); m = read(); Q = read();
	for (int i = 1; i < n; i++){
		// Read the endpoints in separate statements: the evaluation order of
		// the arguments in build(read(),read()) is unspecified, so which input
		// token became u or v used to depend on the compiler.
		const int u = read();
		const int v = read();
		build(u,v);
	}
	dfs(1);
	Build();
	solve();
	return 0;
}

posted @ 2018-05-29 09:24  Mychael  阅读(139)  评论(0)  编辑  收藏  举报