CodeForces 805F Expected diameter of a tree 期望
题意:
给出一个森林,有若干询问\(u, v\):
从\(u, v\)中所在子树中随机各选一个点连起来,构成一棵新树,求新树直径的期望。
分析:
回顾一下和树的直径有关的东西:
求树的直径
从树的任意一点出发搜到最远的一点\(x\),再从\(x\)出发搜到距\(x\)最远的一点\(y\),那么\(d(x,y)\)就是树的直径。时间复杂度为\(O(n)\)。
求构成新树的直径
假设原来两棵树的直径分别问\(d_1,d_2\)
令\(f_i\)为点\(i\)所在子树中距它最远的点的距离
新树的直径要么在原来两棵树中\(max(d_1,d_2)\),要么经过添加的边\(u \to v\)为\(f_u + f_v + 1\)
新的直径为两种情况取最大值
计算\(f_i\)
对于每个点\(i\)计算出距它最远的距离,只要分别从直径的两端各\(DFS\)一次即可,保存最大值。
也就是说,距离\(i\)最远的点是直径两个端点其中之一。
处理询问
只用考虑询问两点在不同子树中的情况:
枚举一棵子树中的\(f_u\),对另一棵树中的\(f\)排序。
二分或者尺取出\(f_u + f_v + 1 \leq max(d_1, d_2)\)的个数,分别统计出答案。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include <iostream>
#include <string>
using namespace std;
#define REP(i, a, b) for(int i = a; i < b; i++)
#define PER(i, a, b) for(int i = b - 1; i >= a; i--)
#define SZ(a) ((int)a.size())
#define MP make_pair
#define PB push_back
#define EB emplace_back
#define ALL(a) a.begin(), a.end()
typedef long long LL;
typedef pair<int, int> PII;
const int maxn = 100000 + 10;
int n, m, q;
vector<int> G[maxn], dis[maxn];
vector<LL> pre[maxn];
int f[maxn], d[maxn], cnt;
map<PII, LL> ans;
int pa[maxn], sz[maxn];
int findset(int x) { return x == pa[x] ? x : pa[x] = findset(pa[x]); }
void Union(int x, int y) {
int px = findset(x), py = findset(y);
if(px != py) {
pa[px] = py;
sz[py] += sz[px];
}
}
int id, depth, root, flag;
void dfs(int u, int fa = -1, int h = 0) {
if(h > depth) { depth = h; id = u; }
if(f[u] < h) f[u] = h;
if(flag) dis[root].PB(f[u]);
for(int v : G[u]) if(v != fa) dfs(v, u, h + 1);
}
int main() {
scanf("%d%d%d", &n, &m, &q);
REP(i, 0, n) pa[i] = i, sz[i] = 1;
while(m--) {
int u, v; scanf("%d%d", &u, &v);
u--; v--;
Union(u, v);
G[u].push_back(v);
G[v].push_back(u);
}
REP(i, 0, n) if(i == pa[i]) {
flag = false; root = i;
depth = 0; id = i; dfs(i);
depth = 0; dfs(id); d[i] = depth;
flag = true; dfs(id);
sort(ALL(dis[i]));
pre[i].resize(sz[i] + 1);
pre[i][0] = 0;
REP(j, 0, SZ(dis[i])) pre[i][j + 1] = pre[i][j] + dis[i][j];
}
while(q--) {
int u, v; scanf("%d%d", &u, &v);
u--; v--;
u = findset(u), v = findset(v);
if(u == v) { printf("-1\n"); continue; }
if(sz[u] > sz[v] || (sz[u] == sz[v] && u > v)) swap(u, v);
if(ans.count(MP(u, v))) { printf("%.10f\n", (double)ans[MP(u, v)] / sz[u] / sz[v]); continue; }
int maxd = max(d[u], d[v]);
int p = sz[v] - 1;
LL t = 0;
for(int x : dis[u]) {
while(p >= 0 && x + dis[v][p] + 1 > maxd) p--;
t += (LL)maxd * (p+1) + (LL)(x+1)*(sz[v]-1-p) + pre[v].back()-pre[v][p+1];
}
ans[MP(u, v)] = t;
printf("%.10f\n", (double)t / sz[u] / sz[v]);
}
return 0;
}