题解 ARC171C【Swap on Tree】
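由于子树与外界只通过它的父边相连，而每条边至多被操作一次（交换后即被删除），每棵子树内至多出现一个来自子树外的数，且这个外来的数具体是多少并不影响方案数。因此设 \(f_{u,i,c}\)（\(c\in\{0,1\}\)）表示：考虑以 \(u\) 为根的子树，与 \(u\) 相连的所有边（包括 \(u\) 与其父亲之间的边）中断了 \(i\) 条，且 \(u\) 与其父亲之间的边未断（\(c=0\)）或已断（\(c=1\)）时的方案数，其中“断”指对这条边进行过交换并删除的操作。再设 \(g_{u,c}=\sum_i f_{u,i,c}\)。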
每个节点的初始状态是 \(f_{u,0,0}=1,\ f_{u,1,1}=[u\ne 1]\)：以 \(1\) 为根，根没有父边，因此只有非根节点存在“父边已断”的初始状态。
枚举 \(u\) 的每个儿子 \(v\),进行如下转移:
\[\begin{aligned}
f'_{u,i,0}&\gets f_{u,i,0}\times g_{v,0}+[i>0]\times f_{u,i-1,0}\times i\times g_{v,1}\\
f'_{u,i,1}&\gets f_{u,i,1}\times g_{v,0}+[i>0]\times f_{u,i-1,1}\times i\times g_{v,1}
\end{aligned}
\]
其中两个转移式的前一项对应不断 \((u,v)\) 这条边，后一项对应断开 \((u,v)\) 这条边。断开时乘以 \(i\)，是因为这条新断的边可以插入到与 \(u\) 相连的前 \(i-1\) 条断边的断边顺序中 \(i\) 个位置中的任意一个。这些系数累乘起来，恰好计入了 \(u\) 处 \(i\) 条断边的全部 \(i!\) 种断边顺序。
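最终答案为 \(g_{1,0}\)。每个节点 \(u\) 合并所有儿子的代价为 \(O(\deg(u)^2)\)，而 \(\sum_u \deg(u)^2\le 2(n-1)\cdot n\)，故总时间复杂度为 \(O(n^2)\)。代码中 \(\textrm{dp},\textrm{sum}\) 分别对应上文的 \(f,g\)。合并一个儿子时，代码先把旧的 \(f_u\) 复制到 \(\textrm{tmp}\) 中，再把新值 \(f'_u\) 写回 \(\textrm{dp}\)。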
// Problem: C - Swap on Tree
// Contest: AtCoder - AtCoder Regular Contest 171
// URL: https://atcoder.jp/contests/arc171/tasks/arc171_c
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
//By: OIer rui_er
#include <bits/stdc++.h>
// Inclusive ascending loop: i takes lo, lo + 1, ..., hi.
#define rep(i, lo, hi) for(int i = (lo); i <= (hi); ++i)
// Inclusive descending loop: i takes hi, hi - 1, ..., lo.
#define per(i, hi, lo) for(int i = (hi); i >= (lo); --i)
// printf-style diagnostics on stderr.
#define debug(...) fprintf(stderr, __VA_ARGS__)
// Redirects stdin/stdout to "<s>.in" / "<s>.out"; s must be a string literal.
#define fileIO(s) do { freopen(s ".in", "r", stdin); freopen(s ".out", "w", stdout); } while(false)
// Plain newline instead of std::endl, so output is not flushed on every line.
#define endl '\n'
using namespace std;
using ll = long long;
// Global Mersenne Twister engine, seeded with the current wall-clock time in nanoseconds.
mt19937 rnd(
    std::chrono::duration_cast<std::chrono::nanoseconds>(
        std::chrono::system_clock::now().time_since_epoch()
    ).count()
);
int randint(int L, int R) {
uniform_int_distribution<int> dist(L, R);
return dist(rnd);
}
// Lowers x to y when x compares strictly greater than y; otherwise leaves x unchanged.
template<typename T> void chkmin(T& x, T y) {
    if(!(x > y)) return;
    x = y;
}
// Raises x to y when x compares strictly less than y; otherwise leaves x unchanged.
template<typename T> void chkmax(T& x, T y) {
    if(!(x < y)) return;
    x = y;
}
// Maps x from [0, 2*mod) into [0, mod) with a single conditional subtraction.
template<int mod>
inline unsigned int down(unsigned int x) {
    return x >= static_cast<unsigned int>(mod) ? x - mod : x;
}
// Integer modulo `mod`, stored as a residue x in [0, mod).
// All arithmetic operators rely on that invariant (down() subtracts mod at
// most once), which is why construction and input normalize their argument.
// Division, operator~ (inverse) and negative exponents use Fermat's little
// theorem (a^(mod-2)), so they are only meaningful for prime mod and a
// non-zero base; the "inverse" of 0 comes out as 0.
// The default constructor is intentionally trivial: automatic variables start
// with an indeterminate x, while static/global storage is zero-initialized.
template<int mod>
struct Modint {
    unsigned int x;
    Modint() = default;
    // Accepts any integer (negative or >= mod) and normalizes it into [0, mod).
    Modint(long long v) {
        v %= mod;
        if(v < 0) v += mod;
        x = static_cast<unsigned int>(v);
    }
    // Reads a possibly negative or unreduced integer and normalizes it;
    // `a` is left untouched if extraction fails.
    friend std::istream& operator>>(std::istream& in, Modint& a) {
        long long v;
        if(in >> v) a = Modint(v);
        return in;
    }
    friend std::ostream& operator<<(std::ostream& out, Modint a) {return out << a.x;}
    friend Modint operator+(Modint a, Modint b) {return down<mod>(a.x + b.x);}
    // Adding mod keeps the unsigned difference within [0, 2*mod).
    friend Modint operator-(Modint a, Modint b) {return down<mod>(a.x - b.x + mod);}
    friend Modint operator*(Modint a, Modint b) {return 1ULL * a.x * b.x % mod;}
    friend Modint operator/(Modint a, Modint b) {return a * ~b;}
    // Binary exponentiation; a negative exponent raises the inverse of a.
    friend Modint operator^(Modint a, int b) {
        long long e = b; // widened so negating INT_MIN cannot overflow
        if(e < 0) {
            a = ~a;
            e = -e;
        }
        Modint ans = 1;
        for(; e; e >>= 1, a *= a) if(e & 1) ans *= a;
        return ans;
    }
    // Modular inverse via Fermat's little theorem.
    friend Modint operator~(Modint a) {return a ^ (mod - 2);}
    friend Modint operator-(Modint a) {return down<mod>(mod - a.x);}
    friend Modint& operator+=(Modint& a, Modint b) {return a = a + b;}
    friend Modint& operator-=(Modint& a, Modint b) {return a = a - b;}
    friend Modint& operator*=(Modint& a, Modint b) {return a = a * b;}
    friend Modint& operator/=(Modint& a, Modint b) {return a = a / b;}
    friend Modint& operator^=(Modint& a, int b) {return a = a ^ b;}
    friend Modint& operator++(Modint& a) {return a += 1;}
    friend Modint operator++(Modint& a, int) {Modint x = a; a += 1; return x;}
    friend Modint& operator--(Modint& a) {return a -= 1;}
    friend Modint operator--(Modint& a, int) {Modint x = a; a -= 1; return x;}
    friend bool operator==(Modint a, Modint b) {return a.x == b.x;}
    friend bool operator!=(Modint a, Modint b) {return !(a == b);}
};
// Array bound for vertices and per-vertex edge counts (3e3 + 5 == 3005).
constexpr int N = 3005;
using mint = Modint<998244353>;
// Number of vertices.
int n;
// dp[u][i][c]: ways for u's subtree with i cut edges at u, where c == 1 iff
// the edge to u's parent is cut.
mint dp[N][N][2];
// Scratch copy of dp[u] while merging one child into u.
mint tmp[N][2];
// sum[u][c] = sum over i of dp[u][i][c].
mint sum[N][2];
// Adjacency lists of the tree.
vector<int> e[N];
void dfs(int u, int fa) {
dp[u][0][0] = 1;
int deg = 0;
if(fa) {
dp[u][1][1] = 1;
++deg;
}
for(int v : e[u]) {
if(v != fa) {
dfs(v, u);
++deg;
rep(i, 0, deg) {
tmp[i][0] = dp[u][i][0];
tmp[i][1] = dp[u][i][1];
dp[u][i][0] = dp[u][i][1] = 0;
}
rep(i, 0, deg) {
dp[u][i][0] += tmp[i][0] * sum[v][0];
dp[u][i][1] += tmp[i][1] * sum[v][0];
if(i > 0) {
dp[u][i][0] += tmp[i - 1][0] * i * sum[v][1];
dp[u][i][1] += tmp[i - 1][1] * i * sum[v][1];
}
}
}
}
rep(i, 0, deg) {
sum[u][0] += dp[u][i][0];
sum[u][1] += dp[u][i][1];
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0); cout.tie(0);
cin >> n;
rep(i, 1, n - 1) {
int u, v;
cin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
}
dfs(1, 0);
cout << sum[1][0] << endl;
return 0;
}