P8935 [JRKSJ R7] 茎 解题报告

Description

你有一棵 $n$ 个点的根节点为 $1$ 的有根树,现在你要对这棵树进行剪枝,每次你可以选择一个还未被剪掉的节点 $u$ 进行操作,然后剪掉 $u$ 的子树所有点(包括 $u$)。当且仅当你剪掉 $1$ 时,操作停止。

你知道 $1$ 到 $x$ 这条路径是这棵树的茎,需要特殊处理。所以你需要在第 kk 次剪枝时对 $x$ 进行操作,而非仅仅将其剪掉,即你不能在第 $k$ 次及以前对其祖先进行操作使其被连带剪掉。

求有多少种不同的操作序列,两个操作序列不同当且仅当长度不同或存在一次操作 $i$ 使得两操作序列在第 $i$ 次时选择的 $u$ 不同。输出答案模 $10^9+7$。

Solution

题解区里的 $O(n^2)$ 做法,先膜拜一下。

先看部分分, 发现有 $x=1$ 和 $k=1$ 的subtask,因此我们先对每颗子树做一遍树形背包,用$f_{i,j}$ 表示以 $i$ 为根节点的子树进行 $j$ 次操作的方案树,贴个代码。

inline void dfs(int u, int F) {
    fa[u] = F; f[u][0] = 1;
    for (int i = head[u]; i; i = Next[i]) {
        int v = ver[i]; if (v == F) continue;
        dfs(v, u);
        for (int j = siz[u]; j >= 0; -- j)
            for (int k = siz[v]; k; -- k) f[u][j+k] = (f[u][j+k] + f[u][j] * f[v][k] % mod * C(j+k, k)) % mod;
        siz[u] += siz[v];
    } 
    for (int i = siz[u]; i >= 0; -- i) f[u][i+1] = (f[u][i+1] + f[u][i]) % mod;
    siz[u] ++;
}
View Code

 

我们将茎上的点和它的子树区分开来, 然后在做一遍背包 $w_i$。这个可以每次做到茎上的一个点再调用计算。

inline int get(int u) {
    int nc = 0; memset(w, 0, sizeof w); w[0] = 1;
    for (int i = head[u]; i; i = Next[i]) {
        int v = ver[i]; if (flag[v]) continue;
        for (int j = nc; j >= 0; -- j)
            for (int k = siz[v]; k; -- k) w[j+k] = (w[j+k] + w[j] * f[v][k] % mod * C(j+k, k)) % mod;
        nc += siz[v];
    }
    return nc;
}
View Code

 


我们从跟 $1$ 开始往 $x$ 开始 dp。

用 $g_{i,j}$ 表示选到茎上的节点 $i$, $j$ 次操作留着的方案数, 最终答案就是 $g_{x, k-1}$。

考虑 $g_{i,j}$ 的转移。

从 $fa$ 转移到 $i$ 时,有两种情况,对 $i$进行操作或者跳过,如果跳过,直接拷贝即可,如果操作,$g_{i,j}$ 可以从 $g_{fa, t}, t >= j$ 转移过来,用后缀和优化。

然后 $i$ 自身的转移,就是子树背包。

这道题就做完了。

Code

#include<bits/stdc++.h>
using namespace std;
#define int long long
inline int read() {
    int x = 0, f = 1; char c = getchar();
    while (c < '0' || c > '9') {if (c == '-') f = -f; c = getchar();}
    while (c >= '0' && c <= '9') {x = (x << 3) + (x << 1) + (c ^ 48); c = getchar();}
    return x * f;
}
const int N = 505, mod = 1e9 + 7;
int n, k, P;
int fac[N], inv[N];
inline int qpow(int a, int b) {int res = 1; for (; b; b >>= 1, a = a * a % mod) if (b & 1) res = res * a % mod; return res;}
inline int C(int x, int y) {if (x < y || y < 0) return 0; return fac[x] * inv[x-y] % mod * inv[y] % mod;}
int tot, head[N], Next[N<<1], ver[N<<1];
inline void add_edge(int u, int v) {ver[++tot] = v; Next[tot] = head[u]; head[u] = tot;}
int fa[N], f[N][N], siz[N];
inline void dfs(int u, int F) {
    fa[u] = F; f[u][0] = 1;
    for (int i = head[u]; i; i = Next[i]) {
        int v = ver[i]; if (v == F) continue;
        dfs(v, u);
        for (int j = siz[u]; j >= 0; -- j)
            for (int k = siz[v]; k; -- k) f[u][j+k] = (f[u][j+k] + f[u][j] * f[v][k] % mod * C(j+k, k)) % mod;
        siz[u] += siz[v];
    } 
    for (int i = siz[u]; i >= 0; -- i) f[u][i+1] = (f[u][i+1] + f[u][i]) % mod;
    siz[u] ++;
}
int w[N], flag[N];
inline int get(int u) {
    int nc = 0; memset(w, 0, sizeof w); w[0] = 1;
    for (int i = head[u]; i; i = Next[i]) {
        int v = ver[i]; if (flag[v]) continue;
        for (int j = nc; j >= 0; -- j)
            for (int k = siz[v]; k; -- k) w[j+k] = (w[j+k] + w[j] * f[v][k] % mod * C(j+k, k)) % mod;
        nc += siz[v];
    }
    return nc;
}
int nc, p[N], g[2][N];
signed main () {
    n = read(); k = read(); P = read(); fac[0] = 1;
    for (int i = 1; i <= n; ++ i) fac[i] = fac[i-1] * i % mod;
    inv[n] = qpow(fac[n], mod - 2);
    for (int i = n; i; -- i) inv[i-1] = inv[i] * i % mod;
    for (int i = 1; i < n; ++ i) {
        int u = read(), v = read();
        add_edge(u, v); add_edge(v, u);
    }
    dfs(1, 0); int x = P, tc = 0; 
    while (x) flag[x] = 1, p[++tc] = x, x = fa[x];
    reverse(p + 1, p + 1 + tc); get(1);
    for (int i = 0; i < n; ++ i) g[1][i] = w[i];
    for (int i = 2; i <= tc; ++ i) {
        int now = i & 1, pre = now ^ 1, u = p[i], nc = get(u);
        memcpy(g[now], g[pre], sizeof g[now]); 
        int sum = 0;
        for (int j = n - 1; j >= 0; -- j) {
            sum = (sum + g[now][j]) % mod;
            if (i == tc) g[now][j] = 0;
            g[now][j] = (g[now][j] + sum) % mod;
        } 
        for (int j = n - 1; j >= 0; -- j)
            for (int k = nc; k; -- k) g[now][j+k] = (g[now][j+k] + g[now][j] * w[k] % mod * C(j+k, k)) % mod;
    }
    printf("%lld\n", g[tc&1][k-1]);
    return 0;
}
View Code

 

posted @ 2023-01-10 11:39  LikC1606  阅读(29)  评论(0编辑  收藏  举报