AT3728 [ARC087D] Squirrel Migration

https://www.luogu.com.cn/problem/AT3728

想偏了一些

首先对于每条边,假设断开后两边的大小分别为\(s1,s2\),那么答案的上界就是

\[\sum_{e\in E}2\times \min(s1,s2) \]

考虑把重心设为根,容易发现,上面那个\(\min\)只会取到子树大小,因为子树一定小于等于另外一部分

考虑关于重心的儿子来容斥就行了,设\(f[i]\)表示钦定有\(i\)个点不合法的方案数

不合法定义为\(x,p_x\)在同一棵子树里,得到一颗子树钦定\(j\)个不合法的方案数为\(\binom{size[v]}{j}^2\times j!\)

然后用背包合并即可

最后答案为

\[ANS=\sum_{i=0}^n(-1)^i f[i] \times(n-i)! \]

#include<bits/stdc++.h>
#define N 200050
#define ll long long
#define mod 1000000007
using namespace std;
ll qpow(ll x, ll y) {
    ll ret = 1;
    for(; y; y >>= 1, x = x * x % mod) if(y & 1) ret = ret * x % mod;
    return ret;
}
ll fac[N], ifac[N];
void init(int n) {
    fac[0] = 1;
    for(int i = 1; i <= n; i ++) fac[i] = fac[i - 1] * i % mod;
    ifac[n] = qpow(fac[n], mod - 2);
    for(int i = n - 1; i >= 0; i --) ifac[i] = ifac[i + 1] * (i + 1) % mod;
    //for(int i = 1; i <= n; i ++) printf("%lld ", ifac[i]); printf("\n");
}
ll C(int n, int m) {
    return fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}
int siz[N], msiz[N], n;
vector<int> g[N];
void dfs(int u, int fa) {
    siz[u] = 1; msiz[u] = 0;
    for(int v : g[u]) {
        if(v == fa) continue;
        dfs(v, u); siz[u] += siz[v];
        if(siz[v] > msiz[u]) msiz[u] = siz[v];
    }
    msiz[u] = max(msiz[u], n - siz[u]);
}
ll f[N];
int main() {
    scanf("%d", &n);
    init(n);
    for(int i = 1; i < n; i ++) {
        int u, v;
        scanf("%d%d", &u, &v);
        g[u].push_back(v), g[v].push_back(u);
    }
    dfs(1, 1);
    int rt = 1;
    for(int i = 1; i <= n; i ++) if(msiz[i] < msiz[rt]) rt = i;
    dfs(rt, rt);

    f[0] = 1;
    for(int u : g[rt]) {
        int x = siz[u];
        //printf("%d ", x);
        for(int i = n; i >= 0; i --)
            for(int j = 1; j <= min(i, x); j ++) {
                ll o = C(x, j) * C(x, j) % mod * fac[j] % mod;
                //printf("%d %d  %lld\n", x, j, C(x, j));
                f[i] = (f[i] + f[i - j] * o % mod) % mod;
            }
    }
    //for(int i = 1; i <= n; i ++) printf("%d ", f[i]); printf("\n");
    ll ans = 0;
    for(int i = 0; i <= n; i ++) {
        ll o = f[i] * fac[n - i] % mod;
        if(i & 1) ans = (ans - o + mod) % mod;
        else ans = (ans + o) % mod;
    }
    printf("%lld", ans);
    return 0;
}
posted @ 2022-02-22 20:01  lahlah  阅读(49)  评论(0编辑  收藏  举报