题目链接
https://www.lydsy.com/JudgeOnline/problem.php?id=3488
题解
考虑每条边对答案的贡献。设询问为到。
- 若和不是祖先关系,那么能产生贡献的就是两端点分别在两个子树内的情况。
- 若和是祖先关系,假设是的祖先,那么能产生贡献的就是一个端点在的子树内,另一个端点不在到路径上第一个点的子树内的情况。
因此,将边转化成二维平面上的点,询问转化成二维平面上的矩形,离线处理并树状数组维护即可。
代码
#include <cmath>
#include <cstdio>
#include <algorithm>
int read()
{
int x=0,f=1;
char ch=getchar();
while((ch<'0')||(ch>'9'))
{
if(ch=='-')
{
f=-f;
}
ch=getchar();
}
while((ch>='0')&&(ch<='9'))
{
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
const int maxn=100000;
const int maxq=500000;
const int maxm=maxn*2+maxq*3;
struct data
{
int x,l,r,op,id;
data(int _x=0,int _l=0,int _r=0,int _op=0,int _id=0):x(_x),l(_l),r(_r),op(_op),id(_id){}
bool operator <(const data &other) const
{
if(x==other.x)
{
return abs(op)<abs(other.op);
}
return x<other.x;
}
};
int n,m,q,tot,dfn[maxn+10],fa[20][maxn+10],cnt,pre[maxn*2+10],now[maxn+10],son[maxn*2+10],tote,deep[maxn+10],ans[maxq+10],size[maxn+10];
data d[maxm+10];
namespace st
{
int v[maxn+10];
int lowbit(int x)
{
return x&(-x);
}
int modify(int x,int val)
{
while(x<=n)
{
v[x]+=val;
x+=lowbit(x);
}
return 0;
}
int getsum(int x)
{
int res=0;
while(x)
{
res+=v[x];
x-=lowbit(x);
}
return res;
}
}
int ins(int a,int b)
{
pre[++tote]=now[a];
now[a]=tote;
son[tote]=b;
return 0;
}
int dfs(int u,int f)
{
dfn[u]=++cnt;
fa[0][u]=f;
deep[u]=deep[f]+1;
size[u]=1;
for(int i=now[u]; i; i=pre[i])
{
int v=son[i];
if(v!=f)
{
dfs(v,u);
size[u]+=size[v];
}
}
return 0;
}
int getfa()
{
for(int k=1; k<=18; ++k)
{
for(int i=1; i<=n; ++i)
{
fa[k][i]=fa[k-1][fa[k-1][i]];
}
}
return 0;
}
int getlca(int x,int y)
{
if(deep[x]<deep[y])
{
std::swap(x,y);
}
for(int k=18; k>=0; --k)
{
if(deep[fa[k][x]]>=deep[y])
{
x=fa[k][x];
}
}
if(x==y)
{
return y;
}
for(int k=18; k>=0; --k)
{
if(fa[k][x]!=fa[k][y])
{
x=fa[k][x];
y=fa[k][y];
}
}
return fa[0][y];
}
int getson(int a,int b)
{
for(int k=18; k>=0; --k)
{
if(deep[fa[k][a]]>deep[b])
{
a=fa[k][a];
}
}
return a;
}
int main()
{
n=read();
for(int i=1; i<n; ++i)
{
int a=read(),b=read();
ins(a,b);
ins(b,a);
}
dfs(1,0);
getfa();
m=read();
for(int i=1; i<=m; ++i)
{
int a=read(),b=read();
d[++tot]=data(dfn[a],dfn[b],0,0,i);
d[++tot]=data(dfn[b],dfn[a],0,0,i);
}
q=read();
for(int i=1; i<=q; ++i)
{
int a=read(),b=read(),lca=getlca(a,b);
if(dfn[a]>dfn[b])
{
std::swap(a,b);
}
if(a==b)
{
ans[i]=m<<1;
}
else if(a==lca)
{
d[++tot]=data(n,dfn[b],dfn[b]+size[b]-1,1,i);
int c=getson(b,a);
d[++tot]=data(dfn[c]-1,dfn[b],dfn[b]+size[b]-1,1,i);
d[++tot]=data(dfn[c]+size[c]-1,dfn[b],dfn[b]+size[b]-1,-1,i);
}
else
{
d[++tot]=data(dfn[a]+size[a]-1,dfn[b],dfn[b]+size[b]-1,1,i);
d[++tot]=data(dfn[a]-1,dfn[b],dfn[b]+size[b]-1,-1,i);
}
}
std::sort(d+1,d+tot+1);
for(int i=1; i<=tot; ++i)
{
if(d[i].op==0)
{
st::modify(d[i].l,1);
}
else
{
ans[d[i].id]+=d[i].op*(st::getsum(d[i].r)-st::getsum(d[i].l-1));
}
}
for(int i=1; i<=q; ++i)
{
printf("%d\n",ans[i]+1);
}
return 0;
}