[做题记录-计数] [ARC087D] Squirrel Migration
题目描述
给你一个\(N\)个节点的树,求一个\(1\cdots N\)的排列\((p_1,p_2,\cdots p_N)\) ,使得\(\sum dist(i,p_i)\)最大。
求这样的排列的个数。答案对\(10^9+7\)取模。
\(N \leq 5000\)
Solution
笑死了差点又抄题解去了。
真就只会抄题解哈哈哈哈哈哈哈哈哈哈哈哈。
可以发现一条边的贡献最大的时候是\(2\times min(sz_u, sz_v)\)。
然后一步猜测可以知道贡献最大当且仅当每条边都会满足小的那个子树里面的点都配对出去了。如果没有配对出去, 从外面拉一条进来一定可以更优。
这个判定条件是什么?就是把重心当做根以后所有配对是跨子树的。
那么有两个重心的时候答案就是\((\frac{n}{2}!)^2\)。有一个重心的时候就直接对重心所有子树做背包表示有\(i\)个匹配在自己的子树里面, 直接容斥算答案即可。
/*
QiuQiu /qq
____ _ _ __
/ __ \ (_) | | / /
| | | | _ _ _ | | _ _ / / __ _ __ _
| | | | | | | | | | | | | | | | / / / _` | / _` |
| |__| | | | | |_| | | | | |_| | / / | (_| | | (_| |
\___\_\ |_| \__,_| |_| \__, | /_/ \__, | \__, |
__/ | | | | |
|___/ |_| |_|
*/
#include <bits/stdc++.h>
using namespace std;
const int N = 5000 + 10;
const int P = 1e9 + 7;
inline int power(int x, int k) {
int res = 1;
while(k) { if(k & 1) res = 1ll * x * res % P; x = 1ll * x * x % P; k >>= 1; } return res;
}
inline void pls(int &x, int y) { x += y; if(x >= P) x -= P; }
inline void dec(int &x, int y) { x -= y; if(x < 0) x += P; }
inline int mod(int x, int y) { x += y; if(x >= P) x -= P; return x; }
int n;
vector<int> e[N];
int sz[N], cnt, rt, fa[N];
void dfs(int x, int fx) {
sz[x] = 1; int mx = 0; fa[x] = fx;
for(int y : e[x]) if(y != fx) { dfs(y, x); sz[x] += sz[y]; mx = max(mx, sz[y]); }
mx = max(mx, n - sz[x]);
if(mx * 2 <= n) cnt ++, rt = x;
}
int fac[N], ifac[N];
inline int C(int x, int y) {
if(x < 0 || y < 0 || x < y) return 0;
return 1ll * fac[x] * ifac[y] % P * ifac[x - y] % P;
}
int dp[N][N];
int main() {
ios :: sync_with_stdio(0);
cin >> n;
for(int i = 2, x, y; i <= n; i ++) {
cin >> x >> y;
e[x].push_back(y); e[y].push_back(x);
}
dfs(1, 0);
fac[0] = 1;
for(int i = 1; i <= n; i ++) fac[i] = 1ll * fac[i - 1] * i % P;
ifac[n] = power(fac[n], P - 2);
for(int i = n - 1; i >= 0; i --) ifac[i] = 1ll * ifac[i + 1] * (i + 1) % P;
if(cnt == 2) { cout << 1ll * fac[n / 2] * fac[n / 2] % P << endl; return 0; }
static int a[N]; int top = 0;
for(int y : e[rt]) {
if(y == fa[rt]) a[++ top] = n - sz[rt];
else a[++ top] = sz[y];
}
dp[0][0] = 1; int sum = 0;
for(int i = 1; i <= top; i ++) {
sum += a[i];
for(int j = 0; j <= sum; j ++)
for(int k = 0; k <= min(a[i], j); k ++) {
pls(dp[i][j], 1ll * dp[i - 1][j - k] * C(a[i], k) % P * C(a[i], k) % P * fac[k] % P);
}
}
int ans = 0;
for(int i = 0; i <= n; i ++) {
int res = 1ll * dp[top][i] * fac[n - i] % P;
if(i & 1) dec(ans, res);
else pls(ans, res);
}
cout << ans << endl;
return 0;
}