题目描述
输入
输出
样例输入
4
1 2
1 3
2 4
1
2 3
样例输出
1
求lca,但是要用树上倍增来求,if(dis&1) return 0;当距离为奇数时,没有地点满足要求,如果lca到两点的距离刚好相等
![]()
ans=n-sz[fx]-sz[fy]
如果到lca的距离不相等,假设x为深度较大的,那么x需要往上爬dis/2-1个深度,此时爬到的父节点记为z,ans=sz[fa[z]]-sz[z]
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 100005
using namespace std;
inline int read()
{
int x=0;char ch=getchar();
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9') {x=x*10+ch-'0';ch=getchar();}
return x;
}
int n,Maxdeep,m;
struct edge
{
int to,ne;
}b[maxn*4];
int k=0,head[maxn];
void swap(int &x,int &y) { int z=x;x=y;y=z; }
void add(int u,int v)
{
k++;
b[k].to=v;b[k].ne=head[u];head[u]=k;
}
int f[maxn][35],d[maxn],sz[maxn],fa[maxn];
void dfs(int x)
{
sz[x]=1;
for(int i=head[x];i!=-1;i=b[i].ne)
if(b[i].to!=fa[x]){
fa[b[i].to]=x; d[b[i].to]=d[x]+1;
dfs(b[i].to);
sz[x]+=sz[b[i].to];
}
}
void init()
{
memset(f,-1,sizeof(f));
for(int i=1;i<=n;i++) f[i][0]=fa[i];
for(int j=1;(1<<j)<=n;j++)
for(int i=1;i<=n;i++)
if(f[i][j-1]!=1) f[i][j]=f[f[i][j-1]][j-1];
}
int lca(int x,int y)
{
int ti=0;
for(ti=0;(1<<ti)<=d[x];ti++); ti--;
for(int i=ti;i>=0;i--)
if(d[x]-(1<<i)>=d[y]) x=f[x][i];
if(x==y) return x;
for(int i=ti;i>=0;i--)
if(f[x][i]!=-1&&f[x][i]!=f[y][i]){
x=f[x][i];
y=f[y][i];
}
return fa[x];
}
int quiry(int x,int y)
{
int ans=0;
if(d[x]<d[y]) swap(x,y);
int anc=lca(x,y);
int dis=d[x]+d[y]-2*d[anc];
if((dis&1)) return 0;
else{
if(d[x]==d[y]){
int t=d[anc];
ans=n;
int op=0;
for(int i=head[anc];i!=-1;i=b[i].ne)
if(d[b[i].to]==t+1){
int ff=b[i].to;
if(lca(x,ff)==ff){ ans-=sz[ff]; op++; }
if(lca(y,ff)==ff){ ans-=sz[ff]; op++; }
if(op==2) break;
}
return ans;
}
else{
int mid=dis/2;
int t=mid-1;
for(int i=0;t;i++)
if(t&(1<<i)){
t^=(1<<i);
x=f[x][i];
}
ans=sz[fa[x]]-sz[x];
return ans;
}
}
}
int main()
{
memset(head,-1,sizeof(head));
n=read();
int x,y;
for(int i=1;i<n;i++){
x=read(); y=read();
add(x,y);add(y,x);
}
dfs(1); init();
m=read();
if(Maxdeep==n-1){
for(int i=1;i<=m;i++){
x=read();y=read();
if((abs(d[x]-d[y])&1)) printf("0\n");
else{
if(x==y) printf("%d\n",n);
else printf("1\n");
}
}
return 0;
}
for(int i=1;i<=m;i++){
x=read();y=read();
if(x==0||y==0){printf("1\n");continue;}
if(x==y) printf("%d\n",n);
else printf("%d\n",quiry(x,y));
}
return 0;
}