[ARC087D] Squirrel Migration 补题记录
题目链接
简要题意:
给你一个\(N\)个节点的树,求一个\(1\cdots N\)的排列\((p_1,p_2,\cdots p_N)\) ,使得\(\sum dist(i,p_i)\)最大。
求这样的排列的个数。答案对\(10^9+7\)取模。
分析
先考虑怎么构造出 \(\sum dist(i,p_i)\) 最大的 \(p\) 。
先取出一条边,把它断开,使得原树分成两个部分 \(S_1\) 和 \(S_2\) 。
在最多的情况下,每一个都会走到另一个集合,所以路过切断边的次数是 \(2\times\min\{|S_1|,|S_2|\}\) 。
现在取出当前树的重心,至于为什么是重心:
考虑对于每一条边的贡献,及是之前讲的 \(2\times\min\{|S_1|,|S_2|\}\) ,我们要让贡献最大,显然要让 \(|S_1|\) 和 \(|S_2|\) 尽可能平衡。
自然而然就可以想到重心。
我们设当前子树为 \(S\) ,对于子树内的节点 \(u\) 可以得到最大的构造方案是:
\[\forall u\in S,p_u\notin S
\]
接下来考虑容斥求方案数。
我们设 \(f_i\) 表示有 \(i\) 个数不满足条件,其他随便选的方案数。
那么:
\[Ans = \sum_{i=0}^n(-1)^if_i(n-i)!
\]
至于为什么有 \((n-i)!\) 是因为剩下的随便排列都是可以的。
我们考虑每一个与 \(\text{root}\) 连边的子树,设其大小为 \(\text{size}\) 。
可以得到这个子树中
\[f_i=C_{\text{size}}^i\prod_{j=size-i+1}^jj
\]
也是比较好理解的,小学乘法原理即可。
最后背包一下,具体可看代码。
#include <cstdio>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
#define file(a) freopen(a".in", "r", stdin), freopen(a".out", "w", stdout)
#define Enter putchar('\n')
#define quad putchar(' ')
#define int long long
const int N = 5005;
const int mod = 1e9 + 7;
int n, siz[N], w[N], root, fac[N], f[N], ans, inv[N];
std::vector <int> dis[N];
inline int power(int a, int n);
inline int C(int n, int m);
inline void get_root(int now, int father);
inline void dfs(int now, int father);
signed main(void) {
// file("AT3728");
f[0] = 1; fac[0] = 1; inv[0] = 1;
std::cin >> n;
for (int i = 1; i <= n; i++)
fac[i] = fac[i - 1] * i % mod;
inv[n] = power(fac[n], mod - 2);
for (int i = n - 1; i >= 1; i--)
inv[i] = inv[i + 1] * (i + 1) % mod;
for (int i = 1, x, y; i < n; i++) {
scanf("%lld %lld", &x, &y);
dis[x].emplace_back(y);
dis[y].emplace_back(x);
}
get_root(1, 0);
memset(siz, 0, sizeof(siz));
dfs(root, 0);
for (int t : dis[root]) {
int x = siz[t];
for (int j = n; j >= 0; j--) {
for (int k = 1; k <= std::min(j, x); k++) {
int mul = C(x, k) * fac[x] % mod * inv[x - k] % mod;
f[j] = (f[j] + f[j - k] * mul % mod) % mod;
}
}
}
for (int i = 0; i <= n; i++) {
int flag = 1, num;
if (i % 2 == 1) flag = -1;
num = f[i] * fac[n - i] % mod;
ans = ans + flag * num;
ans = (ans % mod + mod) % mod;
}
std::cout << ans << std::endl;
return 0;
}
inline int power(int a, int n) {
int ret = 1;
while (n) {
if (n & 1) ret = ret * a % mod;
a = a * a % mod;
n /= 2;
}
return ret;
}
inline int C(int n, int m) {
if (n < m) return 0;
int ret = fac[n];
ret = ret * inv[m] % mod;
ret = ret * inv[n - m] % mod;
return ret;
}
inline void get_root(int now, int father) {
siz[now] = 1; w[now] = 0;
for (int t : dis[now]) {
if (t == father) continue;
get_root(t, now);
siz[now] += siz[t];
w[now] = std::max(w[now], siz[t]);
}
w[now] = std::max(w[now], n - siz[now]);
if (w[now] <= n / 2) root = now;
}
inline void dfs(int now, int father) {
siz[now] = 1;
for (int t : dis[now]) {
if (t == father) continue;
dfs(t, now);
siz[now] += siz[t];
}
}