Loading

QOJ4211 Alice and Bob (树形 dp + 01背包)

QOJ4211 Alice and Bob

树形 dp + 01背包

因为 \(\text{Alice}\) 要赢,说明他走的步数比 \(\text{Bob}\) 更多,所以考虑一个点上的石头最多能让 \(\text{Alice}\)\(\text{Bob}\) 走几步,一定会有一个对两个人都最优的策略(\(\text{Alice}\) 尽量多走,\(\text{Bob}\) 也尽量多走),而这是可以 dp 的。

\(f_u\) 表示从起点 \(u\) 出发,\(\text{Alice}\) 最多能比 \(\text{Bob}\) 走几步,转移:

  1. 如果 \(u\) 是白点,那么 \(\text{Alice}\) 要从 \(\max(f_v+1)\) 转移过来,\(f_u=\max(0,\max(f_v+1))\)
  2. 如果 \(u\) 是黑点,那么 \(\text{Bob}\) 要从 \(\min(f_v-1)\) 转移过来,\(f_v=\min(0,\min(f_v-1))\)

题目现在任意一点都可以放置石头,一种放置方式 \(s\) 能否胜利当且仅当 \(\sum\limits_{i\in s} f_i>0\)。相当于在 \(n\) 个点中选若干个点,使得总和大于 \(0\),要求方案数。这就是经典的 \(01\) 背包问题,设 \(f_{i,j}\) 表示考虑完前 \(i\) 个点的选择,总和为 \(j\) 的方案数,转移易得。

答案就是 \(\sum\limits_{i>0}f_{n,i}\)

复杂度 \(O(n^3)\),因为背包总容量 \(\le n^2\)

#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define mk std::make_pair
#define fi first
#define se second
#define pb push_back

using i64 = long long;
using ull = unsigned long long;
const i64 iinf = 0x3f3f3f3f, linf = 0x3f3f3f3f3f3f3f3f;
const int N = 310, mod = 998244353;
int n, m, ans;
int a[N];
std::string s;
int in[N], f[N], vis[N];
std::vector<int> e[N];
void dfs(int u, int fa) {
	vis[u] = 1;
	int mn = iinf, mx = 0;
	for(auto v : e[u]) {
		if(v == fa) continue;
		if(!vis[v]) dfs(v, u);
		vis[v] = 1;
		if(!a[u]) mx = std::max(mx, f[v] + 1); 
		else mn = std::min(mn, f[v] - 1);
	}
	if(!a[u]) f[u] = std::max(0, mx);
	else f[u] = std::min(0, mn);
}
int dp[N][2 * N * N + N];
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    
	std::cin >> n >> m >> s;
	for(int i = 0; i < n; i++) a[i + 1] = (s[i] == 'W' ? 0 : 1);
	for(int i = 1; i <= m; i++) {
		int u, v;
		std::cin >> u >> v;
		e[u].pb(v);
		in[v]++;
	}
	for(int i = 1; i <= n; i++) if(!in[i]) dfs(i, 0);
	int sum = 0;
	for(int i = 1; i <= n; i++) sum += f[i];
 	int o = N * N;
 	dp[0][o] = 1;
	for(int i = 1; i <= n; i++) {
		for(int j = n * n; j >= -n * n; j--) {
			if(j - f[i] + o >= 0) dp[i][j + o] = (dp[i - 1][j + o] + dp[i - 1][j - f[i] + o]) % mod;
		}
	}
	for(int i = o + 1; i <= n * n + o; i++) (ans += dp[n][i]) %= mod;

	std::cout << ans << "\n";
	return 0;
}
posted @ 2024-07-03 16:21  Fire_Raku  阅读(86)  评论(0编辑  收藏  举报