[BZOJ3451]normal 点分治,NTT
[BZOJ3451]normal 点分治,NTT
好久没更博了,咕咕咕。
BZOJ3451权限题,上darkbzoj交吧。
一句话题意,求随机点分治的期望复杂度。
考虑计算每个点对的贡献:如果一个点在点分树上是另一个点的祖先,那么这个点对另一个点的贡献就是1,这样的话,这个点就必须是这两个点之间的链上的点中在点分树上深度最浅的点,由于链上每个点成为点分树上最浅的点的概率都是相等的,所以这个点对对最终的期望的贡献就是\(\frac{1}{dis(i, j) + 1}\),这里的\(dis(i, j)\)习惯上认为是边的条数,\(+1\)就变成了点的个数。现在我们要求的就是\(\sum \limits _{i = 1} ^{n} \sum \limits _{j = 1} ^{n} \frac {1} {dis(i, j) + 1}\)。
考虑怎样在树上统计,明显需要点分治,每次统计从当前分治中心出发的所有长度的路径条数,跟先前统计过的子树合并一下就好了。观察合并的式子,设\(f[i]\)表示这次合并统计的长度为\(i\)的路径条数,\(p[i]\)表示以前统计过的子树中长度为\(i\)的路径条数,\(q[i]\)表示这次统计的长度为\(i\)的路径条数,显然有\(f[k] = \sum \limits _{i + j = k} p[i] * q[j]\),明显是一个卷积的形式,可以FFT,但是发现\(f\)数组的值肯定不会超过NTT模数,于是直接NTT。时间复杂度\(O(n (\log n) ^ 2)\)。
实现的时候要注意,每次从分支中心出发统计时,必须先把子树按深度或大小排一遍序,从小往大处理,不然一个扫把图就能把你卡成\(n ^ 2 \log n\),可以自己画图手玩。
#include <cstdio>
#include <cctype>
#include <vector>
#include <cstring>
#include <algorithm>
#define R register
#define I inline
#define D double
#define L long long
#define B 1000000
using namespace std;
const int N = 32777, P = 998244353, G = 3, H = 332748118;
char buf[B], *p1, *p2;
I char gc() { return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, B, stdin), p1 == p2) ? EOF : *p1++; }
I int rd() {
R int f = 0;
R char c = gc();
while (c < 48 || c > 57)
c = gc();
while (c > 47 && c < 58)
f = f * 10 + (c ^ 48), c = gc();
return f;
}
int s[N], t[N], v[N], p[N], q[N], a[N], b[N], d[N], c[N], h[N], u, r, S, e, F, M;
D o;
vector <int> g[N];
I int max(int x, int y) { return x > y ? x : y; }
I void swp(int &x, int &y) { x ^= y, y ^= x, x ^= y; }
I int cmp(int x, int y) { return t[x] < t[y]; }
void gsz(int x, int f) {
t[x] = 1;
for (R int i = 0, y; i < s[x]; ++i)
if (!v[y = g[x][i]] && y ^ f)
gsz(y, x), t[x] += t[y];
}
void grt(int x, int f, int a) {
R int m = 0, i, y;
for (i = 0; i < s[x]; ++i)
if (!v[y = g[x][i]] && y ^ f)
m = max(m, t[y]), grt(y, x, a);
m = max(m, a - t[x]);
if (m < u)
u = m, r = x;
}
void dfs(int x, int f, int d) {
c[++e] = d;
for (R int i = 0, y; i < s[x]; ++i)
if (!v[y = g[x][i]] && y ^ f)
dfs(y, x, d + 1);
}
I L pwr(L a, L b) {
L r = 1;
for (; b; b >>= 1, a = a * a % P)
if (b & 1)
r = r * a % P;
return r;
}
void ntt(int *f, int v) {
R int i, j, k, t;
L p, q, o;
for (i = 0; i < M; ++i)
if (i < d[i])
swp(f[i], f[d[i]]);
for (i = 1; i < M; i <<= 1) {
t = i << 1, p = pwr(v ? G : H, (P - 1) / t);
for (j = 0; j < M; j += t)
for (q = 1, k = 0; k < i; ++k)
o = q * f[i + j + k] % P, f[i + j + k] = (f[j + k] - o + P) % P, f[j + k] = (f[j + k] + o + P) % P, q = q * p % P;
}
}
void dac(int x) {
R int i, j, y, z;
p[0] = 1, h[0] = 0, u = S, gsz(x, 0), grt(x, 0, t[x]), v[r] = 1, sort(&g[r][0], s[r] + &g[r][0], cmp);
for (i = 0; i < s[r]; ++i)
if (!v[y = g[r][i]]) {
for (M = 1; M <= t[x]; M <<= 1) {}
e = 0, dfs(y, r, 1), F = pwr(M, P - 2), d[0] = 0;
for (z = M >> 1, j = 0; j < M; ++j)
d[j] = (d[j >> 1] >> 1)|((j & 1) ? z : 0);
for (j = 1; j <= e; ++j)
++q[c[j]], h[++h[0]] = c[j];
memcpy(a, p, M * 4), memcpy(b, q, M * 4);
ntt(a, 1), ntt(b, 1);
for (j = 0; j < M; ++j)
a[j] = 1ll * a[j] * b[j] % P;
ntt(a, 0);
for (j = 0; j < M; ++j)
p[j] += q[j], a[j] = 1ll * a[j] * F * 2 % P, o += (D)a[j] / (j + 1);
for (j = 1; j <= e; ++j)
q[c[j]] = 0;
}
for (i = 1; i <= h[0]; ++i)
p[h[i]] = 0;
for (x = r, i = 0; i < s[x]; ++i)
if(!v[y = g[x][i]])
dac(y);
}
int main() {
R int n = rd(), i, x, y;
for (S = 1; S <= n; S <<= 1) {}
for (i = 1; i < n; ++i)
x = rd() + 1, y = rd() + 1, g[x].push_back(y), g[y].push_back(x);
for (i = 1; i <= n; ++i)
s[i] = g[i].size();
o = n, dac(1), printf("%.4lf", o);
return 0;
}
码风也变了。。。