Loading

[ARC087D] Squirrel Migration 补题记录

题目链接

简要题意:
给你一个\(N\)个节点的树,求一个\(1\cdots N\)的排列\((p_1,p_2,\cdots p_N)\) ,使得\(\sum dist(i,p_i)\)最大。

求这样的排列的个数。答案对\(10^9+7\)取模。

分析

先考虑怎么构造出 \(\sum dist(i,p_i)\) 最大的 \(p\)

先取出一条边,把它断开,使得原树分成两个部分 \(S_1\)\(S_2\)
在最多的情况下,每一个都会走到另一个集合,所以路过切断边的次数是 \(2\times\min\{|S_1|,|S_2|\}\)

现在取出当前树的重心,至于为什么是重心:
考虑对于每一条边的贡献,及是之前讲的 \(2\times\min\{|S_1|,|S_2|\}\) ,我们要让贡献最大,显然要让 \(|S_1|\)\(|S_2|\) 尽可能平衡。
自然而然就可以想到重心。

我们设当前子树为 \(S\) ,对于子树内的节点 \(u\) 可以得到最大的构造方案是:

\[\forall u\in S,p_u\notin S \]

接下来考虑容斥求方案数。
我们设 \(f_i\) 表示有 \(i\) 个数不满足条件,其他随便选的方案数。
那么:

\[Ans = \sum_{i=0}^n(-1)^if_i(n-i)! \]

至于为什么有 \((n-i)!\) 是因为剩下的随便排列都是可以的。

我们考虑每一个与 \(\text{root}\) 连边的子树,设其大小为 \(\text{size}\)
可以得到这个子树中

\[f_i=C_{\text{size}}^i\prod_{j=size-i+1}^jj \]

也是比较好理解的,小学乘法原理即可。

最后背包一下,具体可看代码。

#include <cstdio>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>

#define file(a) freopen(a".in", "r", stdin), freopen(a".out", "w", stdout)

#define Enter putchar('\n')
#define quad putchar(' ')

#define int long long 
const int N = 5005;
const int mod = 1e9 + 7;

int n, siz[N], w[N], root, fac[N], f[N], ans, inv[N];
std::vector <int> dis[N];

inline int power(int a, int n);
inline int C(int n, int m);
inline void get_root(int now, int father);
inline void dfs(int now, int father);

signed main(void) {
  // file("AT3728");
  f[0] = 1; fac[0] = 1; inv[0] = 1;
  std::cin >> n;
  for (int i = 1; i <= n; i++)
    fac[i] = fac[i - 1] * i % mod;
  inv[n] = power(fac[n], mod - 2);
  for (int i = n - 1; i >= 1; i--)
    inv[i] = inv[i + 1] * (i + 1) % mod;
  for (int i = 1, x, y; i < n; i++) {
    scanf("%lld %lld", &x, &y);
    dis[x].emplace_back(y);
    dis[y].emplace_back(x);
  }
  get_root(1, 0);
  memset(siz, 0, sizeof(siz));
  dfs(root, 0);
  for (int t : dis[root]) {
    int x = siz[t];
    for (int j = n; j >= 0; j--) {
      for (int k = 1; k <= std::min(j, x); k++) {
        int mul = C(x, k) * fac[x] % mod * inv[x - k] % mod;
        f[j] = (f[j] + f[j - k] * mul % mod) % mod;
      }
    }
  }
  for (int i = 0; i <= n; i++) {
    int flag = 1, num;
    if (i % 2 == 1) flag = -1;
    num = f[i] * fac[n - i] % mod;
    ans = ans + flag * num;
    ans = (ans % mod + mod) % mod;
  }
  std::cout << ans << std::endl;
  return 0;
}

inline int power(int a, int n) {
  int ret = 1;
  while (n) {
    if (n & 1) ret = ret * a % mod;
    a = a * a % mod;
    n /= 2;
  }
  return ret;
}
inline int C(int n, int m) {
  if (n < m) return 0;
  int ret = fac[n];
  ret = ret * inv[m] % mod;
  ret = ret * inv[n - m] % mod;
  return ret;
}
inline void get_root(int now, int father) {
  siz[now] = 1; w[now] = 0;
  for (int t : dis[now]) {
    if (t == father) continue;
    get_root(t, now);
    siz[now] += siz[t];
    w[now] = std::max(w[now], siz[t]);
  }
  w[now] = std::max(w[now], n - siz[now]);
  if (w[now] <= n / 2) root = now;
}
inline void dfs(int now, int father) {
  siz[now] = 1;
  for (int t : dis[now]) {
    if (t == father) continue;
    dfs(t, now);
    siz[now] += siz[t];
  }
}
posted @ 2022-06-06 16:22  Aonynation  阅读(37)  评论(0编辑  收藏  举报