ICPC2022济南站C. DFS Order 2 题解 回滚背包
题目链接:https://www.luogu.com.cn/problem/P9669
题目大意:
给你一棵包含 \(n\) 个节点的有根树。节点编号从 \(1\) 到 \(n\),节点 \(1\) 是根节点。
从节点 \(1\) 出发对整棵树进行深度优先遍历,会得到很多不同的 DFS 序。
解题思路:
基本上和 9981day大佬的题解 一模一样 差不多。
首先考虑某一个节点 \(u\),我这里以 \(way_u\) 表示以 \(u\) 为根的子树一共有多少个不同的 DFS序。
那么很容易得到
\[way_u = son\_sz_u ! \cdot \prod_{i \in son_u} way_v
\]
这里 \(son\_sz_u\) 表示 \(u\) 有多少个子节点,这些子节点之间有 \(son\_sz_u !\) 个不同的排列。
其次我们考虑把以 \(u\) 为根的子树缩成一个点(相当于把 \(u\) 的子孙节点全部删掉,把 \(u\) 变成一个叶子节点),在这种情况下计算 \(ans_{u,i}\),它表示把以 \(u\) 为根的子树看一个点的情况下 \(u\) 的 DFS序排在第 \(i\) 位有多少种不同的情况,最终的答案就是 \(ans_{u,i} \times way_u\)。
然后就是 \(f_{i,j}\),它表示(父节点是 \(u\),子节点是 \(v\)),在 dfs 时,从 \(u\) 到 \(v\) 隔着 \(i\) 个 \(u\) 的兄弟,这 \(i\) 个兄弟对应的子树大小之和为 \(j\) 时的方案数。这个需要回滚背包。
示例程序:
#include <bits/stdc++.h>
using namespace std;
const int maxn = 505;
const long long mod = 998244353;
long long fpow(long long a, int b) {
long long t = a % mod, res = 1;
while (b) {
if (b & 1)
res = res * t % mod;
b >>= 1;
t = t * t % mod;
}
return res;
}
long long inv(long long a) {
return fpow(a, mod - 2);
}
long long fac[maxn], // fac[i]: i!(i的阶乘)
way[maxn], // way[u]:以u为根进行dfs有多少不同的dfs序
f[maxn][maxn], // 当前节点v(父节点u)在u的子树中dfs序排在v前面有i个兄弟,这i个兄弟对应的的子树节点个数之和为j的情况有多少种
h[maxn], // h[j]:所有 f[i][j] * j! * (m-1-j)! * way[u] / way[v] / m! 之和(m表示节点u的儿子节点个数)
ans[maxn][maxn]; // ans[u][i]: 把以u为根的整棵子树看成一个点时节点u是第i个访问到的方案数
// 最终输出的是 ans[u][i] * way[u]
int n, sz[maxn], // sz[u]:以u为根的子树大小
son_sz[maxn]; // son_sz[u]:节点u的儿子节点个数
vector<int> g[maxn];
void init() {
fac[0] = 1;
for (int i = 1; i < maxn; i++)
fac[i] = fac[i-1] * i % mod;
}
void dfs1(int u, int p) {
sz[u] = 1;
son_sz[u] = (u == 1) ? g[u].size() : (g[u].size() - 1);
way[u] = fac[ son_sz[u] ];
for (auto v : g[u]) {
if (v != p) {
dfs1(v, u);
sz[u] += sz[v];
way[u] = way[u] * way[v] % mod;
}
}
}
void dfs2(int u, int p) {
if (u == 1) ans[u][1] = 1;
int m = son_sz[u];
for (int i = 0; i <= m; i++)
for (int j = 0; j <= sz[u]; j++)
f[i][j] = 0;
f[0][0] = 1;
for (auto v : g[u]) {
if (v != p) {
for (int i = m; i >= 1; i--)
for (int j = sz[u]; j >= sz[v]; j--)
f[i][j] = (f[i][j] + f[i-1][j-sz[v]]) % mod;
}
}
for (auto v : g[u]) {
if (v != p) {
long long ex = way[u] * inv(fac[m]) % mod * inv(way[v]) % mod;
// 删除v的这一部分
for (int i = 1; i <= m; i++)
for (int j = sz[v]; j <= sz[u]; j++)
f[i][j] = (f[i][j] - f[i-1][j-sz[v]] + mod) % mod;
// 处理
fill(h, h+sz[u]+1, 0);
for (int i = 0; i <= m; i++)
for (int j = 0; j <= sz[u]; j++)
h[j] = (h[j] + f[i][j] * fac[i] % mod * fac[ m-i-1 ] % mod * ex % mod) % mod;
for (int i = 1; i <= n; i++) {
for (int j = 0; j <= sz[u] && i+j+1 <= n; j++) {
ans[v][i+j+1] = (ans[v][i+j+1] + ans[u][i] * h[j]) % mod;
}
}
// 撤销对v的这一部分的删除
for (int i = m; i >= 1; i--)
for (int j = sz[u]; j >= sz[v]; j--)
f[i][j] = (f[i][j] + f[i-1][j-sz[v]]) % mod;
}
}
for (auto v : g[u]) // debug两天,出错原因就是因为把这段代码合到上一个循环里面了
if (v != p)
dfs2(v, u);
}
int main() {
init();
scanf("%d", &n);
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
dfs1(1, -1);
dfs2(1, -1);
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
if (j > 1) putchar(' ');
printf("%lld", ans[i][j] * way[i] % mod);
}
puts("");
}
return 0;
}