[ BZOJ 3451 ] Normal

Description

题目链接

定义一次点分治的复杂度是所有分治中心分治时的子树大小之和。

给定一棵树,问所有点等概率被选做重心,点分治的期望复杂度。

Solution

根据期望的线性性,答案等价于每个点在点分树上的深度期望之和。

思路是从点对的角度考虑某一个点是否会产生贡献。

\[E(depth[x])=\sum_{y=1}^n P(x\in subtree[y]) \]

也就是 \(x\) 在点分树上在 \(1\dots n\) 的子树中的概率和。

考虑点分树上 \(y\)\(x\) 的祖先的条件,要求 \(x\)\(y\) 构成的这条链上第一个在点分治过程中被删除的点是 \(y\) ,由于链上被选中的概率相等,因此这个概率为 \(\frac{1}{dist(x,y) + 1}\)

所以答案为

\[\sum_{x=1}^n\sum_{j=1}^n \frac{1}{dis(i,j) + 1}=\sum_{len = 0}^n \frac{cnt[i]}{i + 1} \]

因此需要点分治求长度为 \(i\) 的路径条数 \(cnt[i]\) ,注意到合并的时候是卷积的形式。

容斥做法

不考虑重复路径,把子树 dfs 一遍,直接自己进行卷积,再去掉子树内重复计数的路径即可。

每一层最差以自己的 \(size\) 作为长度进行卷积,因此复杂度为 \(\mathcal O(n\log^2 n)\)

#include <cmath>
#include <cstdio>
#include <cctype>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define N 65537
#define mod 998244353
using namespace std;
typedef long long ll;
 
inline int rd() {
  int x = 0;
  char c = getchar();
  while (!isdigit(c)) c = getchar();
  while (isdigit(c)) {
    x = x * 10 + (c ^ 48); c = getchar();
  }
  return x;
}
 
inline void print(ll x) {
  int y = 10, len = 1;
  while(y <= x) {y *= 10; ++len;}
  while(len--) {y /= 10; putchar(x / y + 48); x %= y;}
  putchar('\n');
}
 
inline int fpow(int x, int t = mod - 2) {
  int res = 1;
  while (t) {
    if (t & 1) res = 1ll * res * x % mod;
    x = 1ll * x * x % mod; t >>= 1;
  }
  return res;
}
 
int mxlen = (1 << 16), w[2][N], rev[N];
 
inline int mo(int x) {
  return x >= mod ? x - mod : x;
}
 
inline void init() {
  int per = fpow(3, (mod - 1) / mxlen);
  int invper = fpow(per);
  w[0][0] = w[1][0] = 1;
  for (int i = 1; i < mxlen; ++i) {
    w[0][i] = 1ll * w[0][i - 1] * per % mod;
    w[1][i] = 1ll * w[1][i - 1] * invper % mod;
  }
}
 
inline int Rev(int n) {
  int len = 1, bit = 0;
  while (len <= n) len <<= 1, ++bit;
  for (int i = 0; i < len; ++i)
    rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (bit - 1)));
  return len;
}
 
inline void NTT(int *f, int len, int o) {
  for (int i = 0; i < len; ++i)
    if (i > rev[i]) swap(f[i], f[rev[i]]);
  for (int i = 1; i < len; i <<= 1) {
    int wn = mxlen / (i << 1);
    for (int j = 0; j < len; j += (i << 1)) {
      int nw = 0, x, y;
      for (int k = 0; k < i; ++k, nw += wn) {
        x = f[j + k];
        y = 1ll * w[o][nw] * f[i + j + k] % mod;
        f[j + k] = mo(x + y);
        f[i + j + k] = mo(x - y + mod);
      }
    }
  }
  if (o == 1) {
    int invl = fpow(len);
    for (int i = 0; i < len; ++i) f[i] = 1ll * f[i] * invl % mod;
  }
}
 
bool vis[N];
 
int n, m, tot, totn, mx, rt, mxd;
 
int bkt[N], cnt[N], sz[N], hd[N];
 
struct edge{int to, nxt;} e[N << 1];
 
inline void add(int u, int v) {
  e[++tot].to = v; e[tot].nxt = hd[u]; hd[u] = tot;
  e[++tot].to = u; e[tot].nxt = hd[v]; hd[v] = tot;
}
 
void getrt(int u, int fa) {
  sz[u] = 1;
  int mxs = 0;
  for (int i = hd[u], v; i; i = e[i].nxt)
    if ((v = e[i].to) != fa && !vis[v]) {
      getrt(v, u);
      sz[u] += sz[v];
      mxs = max(mxs, sz[v]);
    }
  mxs = max(mxs, totn - sz[u]);
  if (mxs < mx) {mx = mxs; rt = u;}
}
 
void getsz(int u, int fa) {
  sz[u] =  1;
  for (int i = hd[u], v; i; i = e[i].nxt)
    if ((v = e[i].to) != fa && !vis[v]) {
      getsz(v, u); sz[u] += sz[v];
    }
}
 
void dfs(int u, int fa, int dep) {
  ++bkt[dep]; mxd = max(mxd, dep);
  for (int i = hd[u], v; i; i = e[i].nxt)
    if ((v = e[i].to) != fa && !vis[v]) dfs(v, u, dep + 1);
}
 
inline void mul(int *a, int len, int o) {
  len = Rev(len << 1);
  NTT(a, len, 0);
  for (int i = 0; i < len; ++i) a[i] = 1ll * a[i] * a[i] % mod;
  NTT(a, len, 1);
  if (o > 0) for (int i = 0; i < len; ++i) cnt[i + 1] += a[i];
  else for (int i = 0; i < len; ++i) cnt[i + 3] -= a[i];
  for (int i = 0; i < len; ++i) a[i] = 0;
}
 
inline void calc(int u, int o) {
  mxd = 0;
  dfs(u, 0, 0);
  mul(bkt, mxd, o);
}
 
void divide(int u) {
  vis[u] = 1;
  calc(u, 1);
  for (int i = hd[u], v; i; i = e[i].nxt)
    if (!vis[v = e[i].to]) {
      calc(v, -1);
      getsz(v, u);
      totn = mx = sz[v]; rt = v;
      getrt(v, 0); divide(rt);
    }
}
 
int main() {
  init();
  n = rd();
  for (int i = 1; i < n; ++i) add(rd() + 1, rd() + 1);
  mx = totn = n;
  getrt(1, 0); divide(rt);
  double ans = 0.0;
  for (int i = 1; i <= n + 1; ++i) ans += (double) cnt[i] / i;
  printf("%.4lf", ans);
  return 0;
}

子树按秩合并做法

在点分治求路径条数时,我们尝试用按秩合并的思路去搞,也就是将子树按照最深深度排序,然后逐个合并计算答案。

开始的时候只有 \(bkt[0]=1\),然后按顺序卷每一个子树求出来的计数数组 \(bktson\)

把贡献直接计算,然后再将 \(bktson\) 按位加到 \(bkt\) 上。

考虑复杂度,将子树按照深度从小到大排序后,每次卷积得到的新的链长不会超过新合并的子树深度的二倍,所以每次卷积的数组长度为 \(mxdep[v]\) 的,且每个位置只会和其父节点卷积一次,因此总复杂度为 \(\mathcal O(n\log^2 n)\)

#include <cmath>
#include <cstdio>
#include <cctype>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define N 65537
#define mod 998244353
using namespace std;
typedef long long ll;
 
inline int rd() {
  int x = 0;
  char c = getchar();
  while (!isdigit(c)) c = getchar();
  while (isdigit(c)) {
    x = x * 10 + (c ^ 48); c = getchar();
  }
  return x;
}
 
inline void print(ll x) {
  int y = 10, len = 1;
  while(y <= x) {y *= 10; ++len;}
  while(len--) {y /= 10; putchar(x / y + 48); x %= y;}
  putchar('\n');
}
 
inline int fpow(int x, int t = mod - 2) {
  int res = 1;
  while (t) {
    if (t & 1) res = 1ll * res * x % mod;
    x = 1ll * x * x % mod; t >>= 1;
  }
  return res;
}
 
int mxlen = (1 << 16), w[2][N], rev[N];
 
inline int mo(int x) {
  return x >= mod ? x - mod : x;
}
 
inline void init() {
  int per = fpow(3, (mod - 1) / mxlen);
  int invper = fpow(per);
  w[0][0] = w[1][0] = 1;
  for (int i = 1; i < mxlen; ++i) {
    w[0][i] = 1ll * w[0][i - 1] * per % mod;
    w[1][i] = 1ll * w[1][i - 1] * invper % mod;
  }
}
 
inline int Rev(int n) {
  int len = 1, bit = 0;
  while (len <= n) len <<= 1, ++bit;
  for (int i = 0; i < len; ++i)
    rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (bit - 1)));
  return len;
}
 
inline void NTT(int *f, int len, int o) {
  for (int i = 0; i < len; ++i)
    if (i > rev[i]) swap(f[i], f[rev[i]]);
  for (int i = 1; i < len; i <<= 1) {
    int wn = mxlen / (i << 1);
    for (int j = 0; j < len; j += (i << 1)) {
      int nw = 0, x, y;
      for (int k = 0; k < i; ++k, nw += wn) {
        x = f[j + k];
        y = 1ll * w[o][nw] * f[i + j + k] % mod;
        f[j + k] = mo(x + y);
        f[i + j + k] = mo(x - y + mod);
      }
    }
  }
  if (o == 1) {
    int invl = fpow(len);
    for (int i = 0; i < len; ++i) f[i] = 1ll * f[i] * invl % mod;
  }
}
 
bool vis[N];
 
double ans = 0.0;
 
int n, m, tot, totn, mx, rt;
 
int bkt[N], sz[N], hd[N];
 
struct edge{int to, nxt;} e[N << 1];
 
inline void add(int u, int v) {
  e[++tot].to = v; e[tot].nxt = hd[u]; hd[u] = tot;
  e[++tot].to = u; e[tot].nxt = hd[v]; hd[v] = tot;
}
 
void getrt(int u, int fa) {
  sz[u] = 1;
  int mxs = 0;
  for (int i = hd[u], v; i; i = e[i].nxt)
    if ((v = e[i].to) != fa && !vis[v]) {
      getrt(v, u);
      sz[u] += sz[v];
      mxs = max(mxs, sz[v]);
    }
  mxs = max(mxs, totn - sz[u]);
  if (mxs < mx) {mx = mxs; rt = u;}
}
 
void getsz(int u, int fa) {
  sz[u] =  1;
  for (int i = hd[u], v; i; i = e[i].nxt)
    if ((v = e[i].to) != fa && !vis[v]) {
      getsz(v, u); sz[u] += sz[v];
    }
}
 
int res[N], tmp[N];
 
inline int mul(int *a, int *b, int lena, int lenb) {
  int len = Rev(lenb << 1);
  for (int i = 0; i < lena; ++i) res[i] = a[i];
  for (int i = lena; i < len; ++i) res[i] = 0;
  for (int i = 0; i < lenb; ++i) tmp[i] = b[i];
  for (int i = lenb; i < len; ++i) tmp[i] = 0;
  NTT(res, len, 0); NTT(tmp, len, 0);
  for (int i = 0; i < len; ++i) res[i] = 1ll * res[i] * tmp[i] % mod;
  NTT(res, len, 1);
  for (int i = 0; i < len; ++i) ans += 2.0 * res[i] / (i + 1);
  return len;
}
 
int mxd[N], s[N], bkts[N];
 
inline bool cmp(int x, int y) {return mxd[x] < mxd[y];}
 
int dfs(int u, int fa, int dep) {
  int resd = dep;
  for (int i = hd[u], v; i; i = e[i].nxt)
    if ((v = e[i].to) != fa && !vis[v]) resd = max(resd, dfs(v, u, dep + 1));
  return resd;
}
 
void dfs2(int u, int fa, int dep) {
  ++bkts[dep];
  for (int i = hd[u], v; i; i = e[i].nxt)
    if ((v = e[i].to) != fa && !vis[v]) dfs2(v, u, dep + 1);
}
 
void divide(int u) {
  vis[u] = 1;
  s[0] = 0;
  for (int i = hd[u], v; i; i = e[i].nxt)
    if (!vis[v = e[i].to]) {
      s[++s[0]] = v;
      mxd[v] = dfs(v, u, 1);
    }
  sort(s + 1, s + 1 + s[0], cmp);
  bkt[0] = 1;
  int nowlen = 1;
  for (int i = 1, v; i <= s[0]; ++i) {
    dfs2(v = s[i], 0, 1);
    nowlen = mul(bkt, bkts, nowlen, mxd[v] + 1);
    for (int i = 0; i <= mxd[v]; ++i) {
      bkt[i] += bkts[i]; bkts[i] = 0;
    }
  }
  for (int i = 0; i <= nowlen; ++i) bkt[i] = 0;
  for (int i = hd[u], v; i; i = e[i].nxt)
    if (!vis[v = e[i].to]) {
      getsz(v, u);
      totn = mx = sz[v]; rt = v;
      getrt(v, 0); divide(rt);
    }
}
 
int main() {
  init();
  n = rd();
  for (int i = 1; i < n; ++i) add(rd() + 1, rd() + 1);
  mx = totn = n;
  getrt(1, 0); divide(rt);
  printf("%.4lf", ans + n);
  return 0;
}
posted @ 2019-03-26 07:44  SGCollin  阅读(253)  评论(0编辑  收藏  举报