Gym-103438C Werewolves 题解
这题我自己还没搞懂,这篇题解是废的。
Gym-103438C Werewolves
题面
有 \(n (1 \le n \le 3000)\) 个节点的树,每个节点的颜色为 \(c_i (1 \le c_i \le n)\)。
请计算这个树存在多少不同的连通子图,满足这个连通子图中,存在某种颜色,其出现次数 严格大于 连通子图中节点数量的一半。
简化题意
first
- 对于任意一个联通子图,如果该联通子图对答案有贡献,则一定只有一种颜色在其中出现次数严格大于连通子图中节点数量的一半。
- 所以可以枚举颜色,将树分为黑白两种颜色。
second
- 对于任意一个联通子图
- 假设是颜色 \(C\) 的节点在其中出现次数严格大于连通子图中节点数量的一半。
- 设 \(x\) 为颜色 \(C\) 的节点数量,\(sum\) 为联通子图的节点数量。
- 该联通子图对答案有贡献,当且仅当(这里的除法都是向下取整):
- \(x > sum \div 2\)
- \(x \times 2 > sum\)
- \(x \times 2 - sum > 0\)
- 设 \(y = x \times 2 - sum\)。
- 依据最后推导出的公式,发现颜色 \(C\) 的节点对 \(y\) 的贡献是 \(1\),其余节点对 \(y\) 的贡献是 \(-1\)。
实现 \(O(n^3)\)
- 化简了题意,接下来就是实现了。
- 这显然是树上背包。(树上背包 dp 的状态设计在最下面)
- 枚举颜色 \(O(n)\),做树上背包初始是 \(O(n^3)\),但是优化后是 \(O(n^2)\),总共 \(O(n^3)\)
- 需要注意的是一下几点(这些雷我都踩了一遍):
滚动数组
- 注意 \(y\),在背包的途中可能会为负数,所以树上背包不可以在其中实现自我滚动,需要手动进行滚动。
- 在手动滚动的时候,记住树上背包是将一棵树不断转化为二叉树进行合并,所以在更新辅助数组的是在枚举儿子的 for 循环中执行。
子树根的答案
- 你可能发现根部的答案好像需要单独处理,但是依据树上背包的转移:将多叉树转化为二叉树,然后一颗二叉树不断合并。
- 那么初始可以将 根 单独看做一颗子树,不断合并。
- 所以只需要将根的值初始化就可以了,不需要在最后将根的答案加入 \(dp\) 数组,不然容易出错,且更方便理解。
优化 \(O(n^3)\)
- 发现我们每个状态都遍历了 \(n\) 次。
- 每个点只有一种颜色,且每个状态对答案的贡献都不大于......
见 https://blog.csdn.net/weixin_45313881/article/details/104158397
简化的题解
- 枚举每个颜色,将所有节点分为黑白两色。(与枚举的颜色相同便为黑,反则白)
- 一个联通子图对答案有贡献当且仅当:\(黑色节点数量 - 白色节点数量 > 0\)。
- 枚举了每个颜色后,做树上背包 dp。
- 设 \(sum = 黑色节点数量 - 白色节点数量\)
- \(dp_{i, j}\) 表示节点 \(i\) 往下选择节点,能得到多少个 \(sum\) 为 \(j\) 的联通子图
- 最后一个 \(ans\) 统计答案
Code
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;
using LL = long long;
using PII = pair<int, int>;
const int MAXN = 3010 + 3;
const int P = 3005; // 偏移量,因为有负数
const LL mod = 998244353;
int n, L, C, ANS = 0, a[MAXN], mp[MAXN], sz[2][MAXN];
LL tmp[MAXN][2 * MAXN], dp[MAXN][2 * MAXN];
vector<int> eg[MAXN];
void dfs(int x, int dad){
for(int nxt : eg[x]){
if(nxt == dad) continue;
dfs(nxt, x);
for(int j = min(L, sz[0][x]); j >= max(-L, sz[1][x]); j--){
tmp[x][j + P] = dp[x][j + P];
}
for(int i = min(L, sz[0][x]); i >= max(-L, sz[1][x]); i--){
for(int j = min(L, sz[0][nxt]); j >= max(-L, sz[1][nxt]); j--){
dp[x][i + j + P] = (dp[x][i + j + P] + tmp[x][i + P] * dp[nxt][j + P] % mod) % mod;
}
}
sz[0][x] += sz[0][nxt], sz[1][x] += sz[1][nxt];
}
}
int main(){
ios::sync_with_stdio(0), cin.tie(0);
cin >> n;
for(int i = 1; i <= n; i++) cin >> a[i], mp[a[i]]++;
for(int i = 1, U, V; i < n; i++) cin >> U >> V, eg[U].push_back(V), eg[V].push_back(U);
for(C = 1; C <= n; C++){
if(mp[C] <= 1){
ANS += mp[C];
continue;
}
L = mp[C];
for(int x = 1; x <= n; x++){
if(a[x] == C){
sz[0][x] = 1, sz[1][x] = 0, dp[x][P + 1] = (dp[x][P + 1] + 1) % mod;
}else{
sz[0][x] = 0, sz[1][x] = -1, dp[x][P - 1] = (dp[x][P - 1] + 1) % mod;
}
}
dfs(1, 0);
for(int i = 1; i <= n; i++){
for(int j = -mp[C]; j <= mp[C]; j++){
if(j > 0) ANS = (ANS + dp[i][j + P]) % mod;
dp[i][j + P] = 0;
}
}
}
cout << ANS;
return 0;
}