bzoj 4539 [Hnoi2016]树——主席树+倍增

题目:https://www.lydsy.com/JudgeOnline/problem.php?id=4539

明明就是把每次复制的一个子树当作一个点,这样能连出一个树的结构,自己竟然都没想到。思维有待加强。

找编号为 k 的点,可以通过给 dfs 序建立对于编号的主席树。可以做一个 s[ i ] 表示第 i 步操作之后一共有多少个点,二分得知编号第 k 大的点在哪一步操作建出的大点里,然后用主席树查一下具体是哪个小点即可。每个大点记录一下自己的根,还有连向父亲中的哪个小点。

处理出每个小点在原树种的倍增数组和每个大点在新树中的倍增数组,就能查距离了。每个询问用刚才的二分和主席树找到询问的是哪个大点中的哪个小点。

在大点上倍增的时候先别跳最后一步,看看跳了之后是不是在同一个大点里,如果是,直接查询小点之间的距离即可。

数组开 n*17 好像有些不够。 n*20 可以。

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ls Ls[cr]
#define rs Rs[cr]
#define ll long long
using namespace std;
ll rdn()
{
  ll ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
const int N=1e5+5,K=17,M=2e6+5;//,M=N*K;
int n,m,Q,hd[N],xnt,to[N<<1],nxt[N<<1],pre[N][K+5],dep[N],siz[N];
int h2[N],xt2,t2[N<<1],nt2[N<<1],pr2[N][K+5],dp2[N];
int tim,dfn[N],rt[N],tot,Ls[M],Rs[M],sm[M];
ll s[N],dis[N][K+5]; int bin[K+5];
struct Node{
  int x,y;
  Node(int x=0,int y=0):x(x),y(y) {}
}a[N];
void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;}
void ad2(int x,int y){t2[++xt2]=y;nt2[xt2]=h2[x];h2[x]=xt2;}
void ins(int l,int r,int &cr,int pr,int k)
{
  cr=++tot; ls=Ls[pr]; rs=Rs[pr]; sm[cr]=sm[pr]+1;
  if(l==r)return; int mid=l+r>>1;
  if(k<=mid)ins(l,mid,ls,Ls[pr],k);
  else ins(mid+1,r,rs,Rs[pr],k);
}
int qry(int l,int r,int cr,int pr,int k)
{
  if(l==r)return l; int mid=l+r>>1;
  int s=sm[ls]-sm[Ls[pr]];
  if(s>=k)return qry(l,mid,ls,Ls[pr],k);
  else return qry(mid+1,r,rs,Rs[pr],k-s);
}
void ini_dfs(int cr,int fa)
{
  siz[cr]=1; dep[cr]=dep[fa]+1;
  dfn[cr]=++tim; ins(1,n,rt[tim],rt[tim-1],cr);
  pre[cr][0]=fa;
  for(int t=1,d;(d=pre[pre[cr][t-1]][t-1]);t++)
    pre[cr][t]=d;
  for(int i=hd[cr],v;i;i=nxt[i])
    if((v=to[i])!=fa) ini_dfs(v,cr),siz[cr]+=siz[v];
}
int Dis(int x,int y)
{
  if(!x||!y)return 0; int ret=0;//
  for(int t=K,lm=dep[x];t>=0;t--)
      if(dep[pre[y][t]]>=lm)
    ret+=bin[t], y=pre[y][t];
  return ret;
}
void Ini_dfs(int cr,int fa)
{
  dp2[cr]=dp2[fa]+1; pr2[cr][0]=fa;
  dis[cr][0]=Dis(a[fa].x,a[cr].y)+1;
  for(int t=1,d;pr2[d=pr2[cr][t-1]][t-1];t++)
    {
      pr2[cr][t]=pr2[d][t-1];
      dis[cr][t]=dis[cr][t-1]+dis[d][t-1];
    }
  for(int i=h2[cr],v;i;i=nt2[i])
    if((v=t2[i])!=fa) Ini_dfs(v,cr);
}
Node fnd(ll k,int r)
{
  int l=1,bh=0;Node ret;
  while(l<=r)
    {
      int mid=l+r>>1;
      if(s[mid]>=k)bh=mid,r=mid-1;
      else l=mid+1;
    }
  ret.x=bh; k-=s[bh-1]; bh=a[bh].x;
  ret.y=qry(1,n,rt[dfn[bh]+siz[bh]-1],rt[dfn[bh]-1],k);
  return ret;
}
int Dis2(int x,int y)
{
  int ret=0;
  if(dep[x]!=dep[y])
    {
      if(dep[x]<dep[y])swap(x,y);
      for(int t=K,lm=dep[y];t>=0;t--)
    if(dep[pre[x][t]]>=lm)
      ret+=bin[t], x=pre[x][t];
    }
  if(x!=y)
    {
      for(int t=K;t>=0;t--)
    if(pre[x][t]!=pre[y][t])
      { ret+=bin[t]<<1; x=pre[x][t]; y=pre[y][t];}
      ret+=2;//
    }
  return ret;
}
int main()
{
  n=rdn();m=rdn();Q=rdn(); ll u,v;
  for(int i=1;i<n;i++)
    u=rdn(),v=rdn(),add(u,v),add(v,u);
  bin[0]=1;for(int i=1;i<=K;i++)bin[i]=bin[i-1]<<1;
  ini_dfs(1,0); tot=0;s[1]=n;a[1]=Node(1,0); m++;
  for(int i=2;i<=m;i++)
    {
      u=rdn();v=rdn();s[i]=s[i-1]+siz[u];
      Node bh=fnd(v,i-1); a[i]=Node(u,bh.y);
      ad2(i,bh.x); ad2(bh.x,i);
    }
  Ini_dfs(1,0);
  for(int i=1;i<=Q;i++)
    {
      u=rdn(); v=rdn(); ll ans=0;
      Node x=fnd(u,m), y=fnd(v,m);//*.x:blk, *.y:point
      if(x.x==y.x){printf("%d\n",Dis2(x.y,y.y));continue;}
      if(dp2[x.x]!=dp2[y.x])
    {
      if(dp2[x.x]<dp2[y.x])swap(x,y);
      ans+=Dis(a[x.x].x,x.y);
      for(int t=K,lm=dp2[y.x]+1;t>=0;t--)
        if(dp2[pr2[x.x][t]]>=lm)
          ans+=dis[x.x][t], x.x=pr2[x.x][t];
      if(pr2[x.x][0]==y.x)
        {printf("%lld\n",ans+Dis2(y.y,a[x.x].y)+1);continue;}
      ans+=dis[x.x][0]; x.x=pr2[x.x][0]; x.y=a[x.x].x;
    }
      ans+=Dis(a[y.x].x,y.y); ans+=Dis(a[x.x].x,x.y);
      for(int t=K;t>=0;t--)
    if(pr2[x.x][t]!=pr2[y.x][t])
      {
        ans+=dis[x.x][t]+dis[y.x][t];
        x.x=pr2[x.x][t]; y.x=pr2[y.x][t];
      }
      printf("%lld\n",ans+Dis2(a[x.x].y,a[y.x].y)+2);
    }
  return 0;
}

 

posted on 2019-02-15 09:18  Narh  阅读(248)  评论(0编辑  收藏  举报

导航