ICPC2022济南站C. DFS Order 2 题解 回滚背包

题目链接:https://www.luogu.com.cn/problem/P9669

题目大意:

给你一棵包含 \(n\) 个节点的有根树。节点编号从 \(1\)\(n\),节点 \(1\) 是根节点。

从节点 \(1\) 出发对整棵树进行深度优先遍历,会得到很多不同的 DFS 序。

解题思路:

基本上和 9981day大佬的题解 一模一样 差不多。

首先考虑某一个节点 \(u\),我这里以 \(way_u\) 表示以 \(u\) 为根的子树一共有多少个不同的 DFS序。

那么很容易得到

\[way_u = son\_sz_u ! \cdot \prod_{i \in son_u} way_v \]

这里 \(son\_sz_u\) 表示 \(u\) 有多少个子节点,这些子节点之间有 \(son\_sz_u !\) 个不同的排列。

其次我们考虑把以 \(u\) 为根的子树缩成一个点(相当于把 \(u\) 的子孙节点全部删掉,把 \(u\) 变成一个叶子节点),在这种情况下计算 \(ans_{u,i}\),它表示把以 \(u\) 为根的子树看一个点的情况下 \(u\) 的 DFS序排在第 \(i\) 位有多少种不同的情况,最终的答案就是 \(ans_{u,i} \times way_u\)

然后就是 \(f_{i,j}\),它表示(父节点是 \(u\),子节点是 \(v\)),在 dfs 时,从 \(u\)\(v\) 隔着 \(i\)\(u\) 的兄弟,这 \(i\) 个兄弟对应的子树大小之和为 \(j\) 时的方案数。这个需要回滚背包。

示例程序:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 505;
const long long mod = 998244353;

long long fpow(long long a, int b) {
    long long t = a % mod, res = 1;
    while (b) {
        if (b & 1)
            res = res * t % mod;
        b >>= 1;
        t = t * t % mod;
    }
    return res;
}

long long inv(long long a) {
    return fpow(a, mod - 2);
}

long long fac[maxn],    // fac[i]: i!(i的阶乘)
        way[maxn],  // way[u]:以u为根进行dfs有多少不同的dfs序
        f[maxn][maxn],  // 当前节点v(父节点u)在u的子树中dfs序排在v前面有i个兄弟,这i个兄弟对应的的子树节点个数之和为j的情况有多少种
        h[maxn],    // h[j]:所有 f[i][j] * j! * (m-1-j)! * way[u] / way[v] / m! 之和(m表示节点u的儿子节点个数)
        ans[maxn][maxn];    // ans[u][i]: 把以u为根的整棵子树看成一个点时节点u是第i个访问到的方案数
                                // 最终输出的是 ans[u][i] * way[u]
int n, sz[maxn],    // sz[u]:以u为根的子树大小
    son_sz[maxn];   // son_sz[u]:节点u的儿子节点个数
vector<int> g[maxn];

void init() {
    fac[0] = 1;
    for (int i = 1; i < maxn; i++)
        fac[i] = fac[i-1] * i % mod;
}

void dfs1(int u, int p) {
    sz[u] = 1;
    son_sz[u] = (u == 1) ? g[u].size() : (g[u].size() - 1);
    way[u] = fac[ son_sz[u] ];
    for (auto v : g[u]) {
        if (v != p) {
            dfs1(v, u);
            sz[u] += sz[v];
            way[u] = way[u] * way[v] % mod;
        }
    }
}

void dfs2(int u, int p) {
    if (u == 1) ans[u][1] = 1;
    int m = son_sz[u];
    for (int i = 0; i <= m; i++)
        for (int j = 0; j <= sz[u]; j++)
            f[i][j] = 0;
    f[0][0] = 1;
    for (auto v : g[u]) {
        if (v != p) {
            for (int i = m; i >= 1; i--)
                for (int j = sz[u]; j >= sz[v]; j--)
                    f[i][j] = (f[i][j] + f[i-1][j-sz[v]]) % mod;
        }
    }
    for (auto v : g[u]) {
        if (v != p) {
            long long ex = way[u] * inv(fac[m]) % mod * inv(way[v]) % mod;

            // 删除v的这一部分
            for (int i = 1; i <= m; i++)
                for (int j = sz[v]; j <= sz[u]; j++)
                    f[i][j] = (f[i][j] - f[i-1][j-sz[v]] + mod) % mod;
            // 处理
            fill(h, h+sz[u]+1, 0);
            for (int i = 0; i <= m; i++)
                for (int j = 0; j <= sz[u]; j++)
                    h[j] = (h[j] + f[i][j] * fac[i] % mod * fac[ m-i-1 ] % mod * ex % mod) % mod;
            for (int i = 1; i <= n; i++) {
                for (int j = 0; j <= sz[u] && i+j+1 <= n; j++) {
                    ans[v][i+j+1] = (ans[v][i+j+1] + ans[u][i] * h[j]) % mod;
                }
            }
            // 撤销对v的这一部分的删除
            for (int i = m; i >= 1; i--)
                for (int j = sz[u]; j >= sz[v]; j--)
                    f[i][j] = (f[i][j] + f[i-1][j-sz[v]]) % mod;
        }
    }
    for (auto v : g[u]) // debug两天,出错原因就是因为把这段代码合到上一个循环里面了
        if (v != p)
            dfs2(v, u);
}

int main() {
    init();
    scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs1(1, -1);
    dfs2(1, -1);
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) {
            if (j > 1) putchar(' ');
            printf("%lld", ans[i][j] * way[i] % mod);
        }
        puts("");
    }
    return 0;
}
posted @ 2024-11-26 20:46  quanjun  阅读(8)  评论(0编辑  收藏  举报