题目链接

https://www.lydsy.com/JudgeOnline/problem.php?id=3488

题解

考虑每条边对答案的贡献。设询问为llrr

  1. llrr不是祖先关系,那么能产生贡献的就是两端点分别在两个子树内的情况。
  2. llrr是祖先关系,假设llrr的祖先,那么能产生贡献的就是一个端点在rr的子树内,另一个端点不在llrr路径上第一个点的子树内的情况。

因此,将边转化成二维平面上的点,询问转化成二维平面上的矩形,离线处理并树状数组维护即可。

代码

#include <cmath>
#include <cstdio>
#include <algorithm>

int read()
{
  int x=0,f=1;
  char ch=getchar();
  while((ch<'0')||(ch>'9'))
    {
      if(ch=='-')
        {
          f=-f;
        }
      ch=getchar();
    }
  while((ch>='0')&&(ch<='9'))
    {
      x=x*10+ch-'0';
      ch=getchar();
    }
  return x*f;
}

const int maxn=100000;
const int maxq=500000;
const int maxm=maxn*2+maxq*3;

struct data
{
  int x,l,r,op,id;

  data(int _x=0,int _l=0,int _r=0,int _op=0,int _id=0):x(_x),l(_l),r(_r),op(_op),id(_id){}

  bool operator <(const data &other) const
  {
    if(x==other.x)
      {
        return abs(op)<abs(other.op);
      }
    return x<other.x;
  }
};

int n,m,q,tot,dfn[maxn+10],fa[20][maxn+10],cnt,pre[maxn*2+10],now[maxn+10],son[maxn*2+10],tote,deep[maxn+10],ans[maxq+10],size[maxn+10];
data d[maxm+10];

namespace st
{
  int v[maxn+10];

  int lowbit(int x)
  {
    return x&(-x);
  }

  int modify(int x,int val)
  {
    while(x<=n)
      {
        v[x]+=val;
        x+=lowbit(x);
      }
    return 0;
  }

  int getsum(int x)
  {
    int res=0;
    while(x)
      {
        res+=v[x];
        x-=lowbit(x);
      }
    return res;
  }
}

int ins(int a,int b)
{
  pre[++tote]=now[a];
  now[a]=tote;
  son[tote]=b;
  return 0;
}

int dfs(int u,int f)
{
  dfn[u]=++cnt;
  fa[0][u]=f;
  deep[u]=deep[f]+1;
  size[u]=1;
  for(int i=now[u]; i; i=pre[i])
    {
      int v=son[i];
      if(v!=f)
        {
          dfs(v,u);
          size[u]+=size[v];
        }
    }
  return 0;
}

int getfa()
{
  for(int k=1; k<=18; ++k)
    {
      for(int i=1; i<=n; ++i)
        {
          fa[k][i]=fa[k-1][fa[k-1][i]];
        }
    }
  return 0;
}

int getlca(int x,int y)
{
  if(deep[x]<deep[y])
    {
      std::swap(x,y);
    }
  for(int k=18; k>=0; --k)
    {
      if(deep[fa[k][x]]>=deep[y])
        {
          x=fa[k][x];
        }
    }
  if(x==y)
    {
      return y;
    }
  for(int k=18; k>=0; --k)
    {
      if(fa[k][x]!=fa[k][y])
        {
          x=fa[k][x];
          y=fa[k][y];
        }
    }
  return fa[0][y];
}

int getson(int a,int b)
{
  for(int k=18; k>=0; --k)
    {
      if(deep[fa[k][a]]>deep[b])
        {
          a=fa[k][a];
        }
    }
  return a;
}

int main()
{
  n=read();
  for(int i=1; i<n; ++i)
    {
      int a=read(),b=read();
      ins(a,b);
      ins(b,a);
    }
  dfs(1,0);
  getfa();
  m=read();
  for(int i=1; i<=m; ++i)
    {
      int a=read(),b=read();
      d[++tot]=data(dfn[a],dfn[b],0,0,i);
      d[++tot]=data(dfn[b],dfn[a],0,0,i);
    }
  q=read();
  for(int i=1; i<=q; ++i)
    {
      int a=read(),b=read(),lca=getlca(a,b);
      if(dfn[a]>dfn[b])
        {
          std::swap(a,b);
        }
      if(a==b)
        {
          ans[i]=m<<1;
        }
      else if(a==lca)
        {
          d[++tot]=data(n,dfn[b],dfn[b]+size[b]-1,1,i);
          int c=getson(b,a);
          d[++tot]=data(dfn[c]-1,dfn[b],dfn[b]+size[b]-1,1,i);
          d[++tot]=data(dfn[c]+size[c]-1,dfn[b],dfn[b]+size[b]-1,-1,i);
        }
      else
        {
          d[++tot]=data(dfn[a]+size[a]-1,dfn[b],dfn[b]+size[b]-1,1,i);
          d[++tot]=data(dfn[a]-1,dfn[b],dfn[b]+size[b]-1,-1,i);
        }
    }
  std::sort(d+1,d+tot+1);
  for(int i=1; i<=tot; ++i)
    {
      if(d[i].op==0)
        {
          st::modify(d[i].l,1);
        }
      else
        {
          ans[d[i].id]+=d[i].op*(st::getsum(d[i].r)-st::getsum(d[i].l-1));
        }
    }
  for(int i=1; i<=q; ++i)
    {
      printf("%d\n",ans[i]+1);
    }
  return 0;
}