【BZOJ4539】树(HNOI2016)-主席树+LCA
测试地址:树
做法:本题需要用到主席树+LCA。
要求两点间的距离,显然要维护每个点的深度,以及要求两个点的LCA。
我们把一开始的树看成一块,然后每次操作,都是在某一块下面挂一个新的块,每个块都是模板树的一棵子树。这样我们可以先把块缩成点,那么缩块后整棵大树就变成了一棵更小的树。考虑求一个点的深度,这个深度等于它到它所在块的根的距离,加上块根到整棵树根节点的距离,显然前面的部分可以直接在模板树上求出,处理出深度后就可以询问,那我们又要维护块根在模板树中对应的点编号,可以直接在构造过程中维护,而后面的部分就可以直接在构造时维护。
求深度的问题解决了,现在要解决求LCA的问题了。注意到,两个点的LCA一定在它们所属块在缩块树上的LCA所对应的块中,于是我们倍增跳到LCA的下方,这时候我们需要知道这棵树挂在了哪个点上才能进入到LCA块中,于是我们在构造时维护块根上面的点在模板树中对应的点编号。那么最后我们就得到了在同一块中的两个点,直接在模板树中倍增求出LCA即可。这样一次询问的时间复杂度就是常数稍大的了。注意某些特殊情况,例如一开始两个点就在同一块中,或者其中一块是另一块的祖先。
上面这一通操作看上去完美,但实际上还有一个问题:因为点数可能达到,不能直接存储,那我们如何快速定位一个点在哪个块中,并且它在模板树中对应哪个点呢?注意到每块中节点的编号是连续的,因此我们可以二分定位该点所在块,时间复杂度为,而每一块中节点的编号顺序和模板树中编号顺序相同,因此要求该块中第个编号的点在模板树中对应的点编号,就是求在模板树的对应子树中第小的编号,我们知道树上的子树第小可以转化为DFS序上的区间第小,这就是主席树的经典应用了,于是我们做到了一次定位的时间复杂度。
那么我们就完成了这一题,时间复杂度为。
我傻逼的地方:漫长的四个小时告诉我们,永远都不要用相似的名字命名不同的东西……
以下是本人代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,m,q,first[100010]={0},tot=0,fa[100010][21]={0};
int in[100010],out[100010],tim=0,pos[100010];
int rt[100010]={0},seg[2000010]={0},ch[2000010][2]={0};
int blocktop[100010],blockfa[100010][21]={0};
int blockup[100010]={0},blockdepth[100010];
ll dep[100010],siz[100010],blockl[100010],blockdep[100010],totsiz;
struct edge
{
int v,next;
}e[200010];
void insert(int a,int b)
{
e[++tot].v=b;
e[tot].next=first[a];
first[a]=tot;
}
void dfs(int v)
{
in[v]=++tim;
pos[tim]=v;
siz[v]=1;
for(int i=first[v];i;i=e[i].next)
if (e[i].v!=fa[v][0])
{
dep[e[i].v]=dep[v]+1;
fa[e[i].v][0]=v;
dfs(e[i].v);
siz[v]+=siz[e[i].v];
}
out[v]=tim;
}
void buildtree(int &v,int l,int r)
{
v=++tot;
if (l==r) return;
int mid=(l+r)>>1;
buildtree(ch[v][0],l,mid);
buildtree(ch[v][1],mid+1,r);
}
void add(int &v,int last,int l,int r,int x)
{
v=++tot;
seg[v]=seg[last];
ch[v][0]=ch[last][0];
ch[v][1]=ch[last][1];
if (l==r)
{
seg[v]++;
return;
}
int mid=(l+r)>>1;
if (x<=mid) add(ch[v][0],ch[last][0],l,mid,x);
else add(ch[v][1],ch[last][1],mid+1,r,x);
seg[v]=seg[ch[v][0]]+seg[ch[v][1]];
}
int findkth(int v,int last,int l,int r,int k)
{
if (l==r) return l;
int mid=(l+r)>>1;
if (seg[ch[v][0]]-seg[ch[last][0]]<k)
{
k-=seg[ch[v][0]]-seg[ch[last][0]];
return findkth(ch[v][1],ch[last][1],mid+1,r,k);
}
else return findkth(ch[v][0],ch[last][0],l,mid,k);
}
void init()
{
scanf("%d%d%d",&n,&m,&q);
for(int i=1;i<n;i++)
{
int a,b;
scanf("%d%d",&a,&b);
insert(a,b),insert(b,a);
}
dep[0]=-1;
dep[1]=0;
dfs(1);
for(int i=1;i<=18;i++)
for(int j=1;j<=n;j++)
fa[j][i]=fa[fa[j][i-1]][i-1];
tot=0;
buildtree(rt[0],1,n);
for(int i=1;i<=n;i++)
add(rt[i],rt[i-1],1,n,pos[i]);
}
int findblock(ll v,int limit)
{
int l=1,r=limit;
while(l<r)
{
int mid=(l+r)>>1;
if (v>=blockl[mid+1]) l=mid+1;
else r=mid;
}
return l;
}
int findpoint(int now,int k)
{
int p=blocktop[now];
return findkth(rt[out[p]],rt[in[p]-1],1,n,k);
}
void buildblock()
{
totsiz=n;
blockfa[1][0]=0;
blockdep[1]=0;
blocktop[1]=1;
blockl[1]=1;
blockup[1]=0;
blockdepth[1]=0;
blockdepth[0]=-1;
for(int i=2;i<=m+1;i++)
{
int a,now,p;
ll b,depth;
scanf("%d%lld",&a,&b);
now=findblock(b,i-1);
p=findpoint(now,b-blockl[now]+1);
depth=blockdep[now]+dep[p]-dep[blocktop[now]];
blockfa[i][0]=now;
blockdep[i]=depth+1;
blocktop[i]=a;
blockup[i]=p;
blockl[i]=totsiz+1;
blockdepth[i]=blockdepth[now]+1;
totsiz+=siz[a];
}
for(int i=1;i<=18;i++)
for(int j=1;j<=m+1;j++)
blockfa[j][i]=blockfa[blockfa[j][i-1]][i-1];
}
int lca(int a,int b)
{
if (dep[a]<dep[b]) swap(a,b);
for(int i=18;i>=0;i--)
if (dep[fa[a][i]]>=dep[b]) a=fa[a][i];
if (a==b) return a;
for(int i=18;i>=0;i--)
if (fa[a][i]!=fa[b][i])
a=fa[a][i],b=fa[b][i];
return fa[a][0];
}
void work()
{
for(int i=1;i<=q;i++)
{
ll a,b,ans=0;
int nowa,nowb,pa,pb;
int ansa,ansb,ansc;
scanf("%lld%lld",&a,&b);
nowa=findblock(a,m+1),nowb=findblock(b,m+1);
if (blockdepth[nowa]<blockdepth[nowb])
swap(nowa,nowb),swap(a,b);
int x=nowa,y=nowb;
for(int j=18;j>=0;j--)
if (blockdepth[blockfa[x][j]]>blockdepth[y])
x=blockfa[x][j];
if (x==y||blockfa[x][0]==y)
{
if (x==y) pa=findpoint(nowa,a-blockl[nowa]+1);
else pa=blockup[x];
pb=findpoint(nowb,b-blockl[nowb]+1);
if (x!=y) x=blockfa[x][0];
}
else
{
if (blockdepth[x]>blockdepth[y]) x=blockfa[x][0];
for(int j=18;j>=0;j--)
if (blockfa[x][j]!=blockfa[y][j])
x=blockfa[x][j],y=blockfa[y][j];
pa=blockup[x];
pb=blockup[y];
x=blockfa[x][0];
}
ansa=findpoint(nowa,a-blockl[nowa]+1);
ansb=findpoint(nowb,b-blockl[nowb]+1);
ansc=lca(pa,pb);
ans+=blockdep[nowa]+dep[ansa]-dep[blocktop[nowa]];
ans+=blockdep[nowb]+dep[ansb]-dep[blocktop[nowb]];
ans-=(blockdep[x]+dep[ansc]-dep[blocktop[x]])<<1;
printf("%lld\n",ans);
}
}
int main()
{
init();
buildblock();
work();
return 0;
}