题解 ARC171C【Swap on Tree】

每棵子树内只可能有至多一个外来的数,且外来的数是多少并不影响方案数,因此考虑设 \(f_{u,i,0/1}\) 表示考虑以 \(u\) 为根的子树,与 \(u\) 相连的所有边中断了 \(i\) 条,且 \(u\) 与其父亲之间的边没有断(第三维为 \(0\))或已经断(第三维为 \(1\))时的方案数。设 \(g_{u,c}=\sum_{i} f_{u,i,c}\)。

每个节点的初始状态是 \(f_{u,0,0}=1,\ f_{u,1,1}=[u\ne 1]\)(根节点 \(1\) 没有父边,因此后者为 \(0\))。

枚举 \(u\) 的每个儿子 \(v\),进行如下转移:

\[\begin{aligned} f'_{u,i,0}&\gets f_{u,i,0}\times g_{v,0}+[i>0]\times f_{u,i-1,0}\times i\times g_{v,1}\\ f'_{u,i,1}&\gets f_{u,i,1}\times g_{v,0}+[i>0]\times f_{u,i-1,1}\times i\times g_{v,1} \end{aligned} \]

其中两个转移式的前一项代表不断 \((u,v)\) 这条边,后一项代表断 \((u,v)\) 这条边;乘以 \(i\) 是因为需要确定 \((u,v)\) 在与 \(u\) 相连的 \(i\) 条断边的操作顺序中排在第几位,共有 \(i\) 种选择。

时间复杂度 \(O(n^2)\)。上文 \(f,g,f'\) 分别对应代码中的 \(\textrm{dp},\textrm{sum},\textrm{tmp}\)

// Problem: C - Swap on Tree
// Contest: AtCoder - AtCoder Regular Contest 171
// URL: https://atcoder.jp/contests/arc171/tasks/arc171_c
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

//By: OIer rui_er
#include <bits/stdc++.h>
#define rep(x, y, z) for(int x = (y); x <= (z); ++x)
#define per(x, y, z) for(int x = (y); x >= (z); --x)
#define debug(format...) fprintf(stderr, format)
#define fileIO(s) do {freopen(s".in", "r", stdin); freopen(s".out", "w", stdout);} while(false)
#define endl '\n'
using namespace std;
typedef long long ll;

// Shared Mersenne Twister engine, seeded once at start-up from the system
// clock (nanoseconds since the epoch), so sequences differ between runs.
std::mt19937 rnd(
    std::chrono::duration_cast<std::chrono::nanoseconds>(
        std::chrono::system_clock::now().time_since_epoch()
    ).count()
);

// Returns an integer drawn uniformly from the closed interval [L, R],
// advancing the shared engine `rnd`.
int randint(int L, int R) {
    return std::uniform_int_distribution<int>(L, R)(rnd);
}

// In-place minimum: overwrites x with y only when x is strictly greater than y.
template<typename T> void chkmin(T& x, T y) {
	if(!(x > y)) return;
	x = y;
}
// In-place maximum: overwrites x with y only when x is strictly less than y.
template<typename T> void chkmax(T& x, T y) {
	if(!(x < y)) return;
	x = y;
}

// Folds x back into [0, mod) with a single conditional subtraction.
// Only correct for x < 2 * mod; larger inputs are not fully reduced.
template<int mod>
inline unsigned int down(unsigned int x) {
	return x >= mod ? x - mod : x;
}

// Integer modulo the compile-time constant `mod`.
//
// Invariant: 0 <= x < mod. Every operator relies on it (addition and
// subtraction reduce with one conditional subtraction only), so every way of
// creating a value -- construction from an integer and stream extraction --
// normalises its input, including negative numbers.
//
// Division, operator~ and negative exponents compute inverses as a^(mod-2)
// (Fermat), which is only a true inverse when `mod` is prime and the value is
// non-zero.
template<int mod>
struct Modint {
	static_assert(mod > 1, "Modint requires a modulus of at least 2");

	unsigned int x; // canonical residue in [0, mod)

	// Left trivial on purpose: objects with static storage are zero-initialised,
	// while automatic objects hold an indeterminate value until assigned.
	Modint() = default;
	// Converts any integer, negative or >= mod, to its canonical residue.
	Modint(long long v) : x(static_cast<unsigned int>((v % mod + mod) % mod)) {}
	// Wraps a value the caller already knows to be in [0, mod); skips the reduction.
	static Modint raw(unsigned int v) {Modint r; r.x = v; return r;}

	// Reads a (possibly negative or oversized) integer and stores its residue.
	// If the extraction fails, `a` is left untouched and the stream reports the error.
	friend std::istream& operator>>(std::istream& in, Modint& a) {
		long long v;
		if(in >> v) a = Modint(v);
		return in;
	}
	friend std::ostream& operator<<(std::ostream& out, Modint a) {return out << a.x;}
	friend Modint operator+(Modint a, Modint b) {return raw(down<mod>(a.x + b.x));}
	// Unsigned wrap-around plus `mod` lands in [0, 2 * mod) modulo 2^32, which down() accepts.
	friend Modint operator-(Modint a, Modint b) {return raw(down<mod>(a.x - b.x + mod));}
	// 64-bit intermediate product so that x * x cannot overflow.
	friend Modint operator*(Modint a, Modint b) {return raw(static_cast<unsigned int>(1ULL * a.x * b.x % mod));}
	friend Modint operator/(Modint a, Modint b) {return a * ~b;}
	// Binary exponentiation. A negative exponent is interpreted as a power of the
	// inverse; the exponent is widened first so that negating INT_MIN is safe.
	friend Modint operator^(Modint a, int b) {
		long long e = b;
		if(e < 0) {a = ~a; e = -e;}
		Modint ans = 1;
		for(; e; e >>= 1, a *= a) if(e & 1) ans *= a;
		return ans;
	}
	// Multiplicative inverse via Fermat's little theorem (prime `mod` only).
	friend Modint operator~(Modint a) {return a ^ (mod - 2);}
	// down() maps the a.x == 0 case (mod - 0 == mod) back to 0.
	friend Modint operator-(Modint a) {return raw(down<mod>(mod - a.x));}
	friend Modint& operator+=(Modint& a, Modint b) {return a = a + b;}
	friend Modint& operator-=(Modint& a, Modint b) {return a = a - b;}
	friend Modint& operator*=(Modint& a, Modint b) {return a = a * b;}
	friend Modint& operator/=(Modint& a, Modint b) {return a = a / b;}
	friend Modint& operator^=(Modint& a, int b) {return a = a ^ b;}
	friend Modint& operator++(Modint& a) {return a += 1;}
	friend Modint operator++(Modint& a, int) {Modint x = a; a += 1; return x;}
	friend Modint& operator--(Modint& a) {return a -= 1;}
	friend Modint operator--(Modint& a, int) {Modint x = a; a -= 1; return x;}
	friend bool operator==(Modint a, Modint b) {return a.x == b.x;}
	friend bool operator!=(Modint a, Modint b) {return !(a == b);}
};

// Capacity of every per-vertex array (3005); vertices are indexed from 1.
constexpr int N = 3000 + 5;

// All counting is done modulo 998244353.
using mint = Modint<998244353>;

// Number of vertices, read in main.
int n;
// dp[u][i][c]: ways for the subtree of u with i edges incident to u cut,
// where c == 1 means the edge from u to its parent is among the cut ones.
mint dp[N][N][2];
// Scratch copy of dp[u] taken while one child is merged into u.
mint tmp[N][2];
// sum[u][c]: dp[u][i][c] summed over every i.
mint sum[N][2];
// Adjacency lists of the tree.
std::vector<int> e[N];

void dfs(int u, int fa) {
	dp[u][0][0] = 1;
	int deg = 0;
	if(fa) {
		dp[u][1][1] = 1;
		++deg;
	}
	for(int v : e[u]) {
		if(v != fa) {
			dfs(v, u);
			++deg;
			rep(i, 0, deg) {
				tmp[i][0] = dp[u][i][0];
				tmp[i][1] = dp[u][i][1];
				dp[u][i][0] = dp[u][i][1] = 0;
			}
			rep(i, 0, deg) {
				dp[u][i][0] += tmp[i][0] * sum[v][0];
				dp[u][i][1] += tmp[i][1] * sum[v][0];
				if(i > 0) {
					dp[u][i][0] += tmp[i - 1][0] * i * sum[v][1];
					dp[u][i][1] += tmp[i - 1][1] * i * sum[v][1];
				}
			}
		}
	}
	rep(i, 0, deg) {
		sum[u][0] += dp[u][i][0];
		sum[u][1] += dp[u][i][1];
	}
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    cin >> n;
    rep(i, 1, n - 1) {
    	int u, v;
    	cin >> u >> v;
    	e[u].push_back(v);
    	e[v].push_back(u);
    }
    dfs(1, 0);
    cout << sum[1][0] << endl;
    return 0;
}
posted @ 2024-02-05 16:17  rui_er  阅读(56)  评论(0编辑  收藏  举报