题目大意 : 有一棵大小为 $n$ 的树,$m$ 次询问,每一次询问给出点对$(a, b)$ ,求树上到 $a, b $ 距离相同的点有多少个
$1≤ n, m≤ 10^5$
[$>Codeforces \space 519 E. A and B and Lecture Rooms<$](http://codeforces.com/contest/519/problem/E)
题目大意 : 有一棵大小为 \(n\) 的树,\(m\) 次询问,每一次询问给出点对\((a, b)\) ,求树上到 $a, b $ 距离相同的点有多少个
\(1≤ n, m≤ 10^5\)
解题思路 :
观察发现,所有满足条件的点到 \(a, b\) 的路径都包含路径 \((a, b)\) 的中点 \(mid\)
证明:如果有一点 \(x\) 要满足到 \(a, b\) 距离相等,如果 \(x\) 处于 $a $ 的子树且 \(dep_a\geq dep_b\).
那么有 \(dis(x, b) > dis(x, a) \ b\)的情况也同理,所以 \(x\) 到 \(a, b\) 的路径必然与路径 \((a, b)\) 相交
假设 \(x\) 到\(a, b\) 中的一点的路径不包含 \(mid\),设 \(c\) 为\((x, a)\ (x, b)\) 在路径 \((a, b)\) 上的分界点,因为路径有一点不包含 \(mid\),所以 \(c \neq mid\) ,\(dis(c, a) \neq dis(c, b)\)
因此可以推导出 \(dis(x, c) + dis(c, a) \neq dis(x, c) + dis(c, b)\),
所以 \(dis(x, a) \neq dis(x, b)\) 所以满足条件的点到\(a, b\) 的路径必然都包含 \(mid\)
可以先求出 \(Lca\) ,如果路径上的点数为偶数那么答案就是 \(0\)
否则通过倍增求出中点 \(mid\), 如果 \(mid\) 在路径 \((a, lca)\) 上,\(a \neq lca\) ,那么答案就是 \(sz_{mid} - sz_x\) 其中 \(x\) 是 \(mid\) 在路上的儿子,当 \(mid\) 在路径 \((b, lca)\) 上的情况同理
如果 \(mid\) 在 \(Lca\) 上,那么答案就是 \(n - sz_{x} - sz_{y}\), 其中\(x, y\) 分别是 \(Lca\) 在路径上的儿子
```cpp
/*program by mangoyang*/
#include
#define inf (0x7f7f7f7f)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
template
inline void read(T &x){
int f = 0, ch = 0; x = 0;
for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
if(f) x = -x;
}
#define N (1000005)
#define fi first
#define se second
int a[N], nxt[N], head[N], cnt;
int sz[N], dep[N], f[N][24], n, m;
inline void add(int x, int y){
a[++cnt] = y, nxt[cnt] = head[x], head[x] = cnt;
}
inline void dfs(int u, int fa){
sz[u] = 1, dep[u] = dep[fa] + 1, f[u][0] = fa;
for(int p = head[u]; p; p = nxt[p]){
int v = a[p];
if(v != fa) dfs(v, u), sz[u] += sz[v];
}
}
inline int Lca(int x, int y){
if(dep[x] < dep[y]) swap(x, y);
for(int i = 22; i >= 0; i--){
if(dep[f[x][i]] >= dep[y]) x = f[x][i];
if(dep[x] == dep[y]) break;
}
if(x == y) return x;
for(int i = 22; i >= 0; i--)
if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}
inline pair pp(int x, int y){
if(dep[x] < dep[y]) swap(x, y);
for(int i = 22; i >= 0; i--){
if(dep[f[x][i]] >= dep[y]) x = f[x][i];
if(dep[x] == dep[y]) break;
}
for(int i = 22; i >= 0; i--)
if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return make_pair(x, y);
}
int main(){
read(n);
for(int i = 1, x, y; i < n; i++)
read(x), read(y), add(x, y), add(y, x);
dfs(1, 0);
for(int j = 1; j <= 22; j++)
for(int i = 1; i <= n; i++) f[i][j] = f[f[i][j-1]][j-1];
read(m);
while(m--){
int x, y; read(x), read(y);
if(x == y){ printf("%d\n", n); continue; }
int lca = Lca(x, y), dis = dep[x] + dep[y] - 2 * dep[lca];
if(dis & 1){ puts("0"); continue; };
if(dep[x] < dep[y]) swap(x, y); int u = x;
for(int i = 22; i >= 0; i--)
if(dep[x] - dep[f[u][i]] < dis / 2 && dep[f[u][i]] > dep[lca]) u = f[u][i];
if(dep[x] - dep[f[u][0]] == dis / 2){
if(f[u][0] == lca){
pair p = pp(x, y);
printf("%d\n", n - sz[p.fi] - sz[p.se] );
}
else printf("%d\n", sz[f[u][0]] - sz[u]);
}
else puts("0");
}
return 0;
}