[CTS2019]氪金手游
看到本题以后,想到了之前做过的一道题 [HEOI2013]SAO,用类似的方法思考了一下,发现不能这样搞,因为在那道题中是可以将一段元素随意插到一段中间的,但在本题中还需要乘上相应的概率,这个概率非常不好算。
于是只能换一个角度思考,直接解决这个问题貌似很难可以考虑先从特殊的问题出发。先从部分分的一条链下手,但是直接做这条链还是好像不好做,可以考虑进一步特殊化这个问题,先思考所有边同向的情况,即所有边都是 \(i \rightarrow i + 1\) 的情况。你会发现这个问题等价于 \(i\) 要在 \(i + 1 \sim n\) 的所有数之前取到,先不考虑 \(w_i\) 不确定的情况,令 \(S_i = \sum\limits_{j = 1} ^ i w_j\),则对于每个 \(i\) 它在 \(i + 1 \sim n\) 中所有数之前取到的概率为:
你会发现每个 \(i\) 的概率只依赖于 \(w_i\) 以及 \(w_i, w_{i + 1}, \cdots, w_n\) 的和,可以考虑反过来 \(dp\),令 \(dp_{i, j}\) 考虑到第 \(i\) 个数,\(\sum\limits_{k = i} ^ n w_k = j\) 的概率,转移时考虑这个位置填什么即可。
接下来继续考虑,如果存在一条反向边怎么办?这意味着我们要破坏这个抽奖得到的连续性,这样就非常不好计算了,有什么办法能将这个问题继续转化为刚刚的连续性问题吗?一个直接的想法就是将这条链在反向边处分为两块,将两块合法的概率乘起来再乘满足反向边的概率。但这样是不行的,因为你在算上面的时候并不知道后面能取谁,也就是说上下两块概率是会互相影响不能分开计算的。那么我们就要想一个办法将两块概率放在一起连续算,你会发现这样的代价相当于将这条反向边变为正向边,这给了我们极大启发,可以考虑钦定一些反向边变为正向边,其他反向边随意。你会发现这相当于将整个问题划分成一些块的连续型的子问题,最后容斥一下即可。
接下来你会发现我们并不需要直接暴力枚举哪些反向边反向,这个问题可以直接使用一个 \(dp\) 解决。令 \(dp_{i, j, k}\) 表示考虑到第 \(i\) 个数,当前已经钦定了 \(j\) 条反向边反向,当前块的大小为 \(k\) 的概率,转移考虑当前位置填 \(1 / 2 / 3\) 或者是反向边反向即可。那么 \(w_i\) 不确定怎么办呢?你会发现我们只需要在转移填 \(1 / 2 / 3\) 时乘上对应的概率即可。有以下转移(\(p_{i, j}\) 表示 \(i\) 选 \(j\) 的概率):
反向边反向:
反向边不反向:
最后的答案:
但是这个 \(dp\) 仍然是 \(O(n ^ 3)\) 的,显然第一维和第三维已经不能再压了,只能考虑在第二维上下手脚。观察一下这个转移的过程,你会发现我们只会做乘法,并且我们并不关乎到底有多少条反向边反向而只需要考虑其对答案的贡献即可。那么为什么我们不直接把第二维压掉,在每次反向边反向是将答案乘上 \(-1\) 呢?你会发现我们这样就很好地利用转移性质解决了这个问题。
最后我们将目光放到树上,你会发现这和序列上是同样一个问题,只是 \(i\) 要在整个子树之前抽到的概率为 \(\dfrac{w_i}{S_i}\)(\(S_i\) 为以 \(i\) 为根的子树内 \(w_j\) 的和)。于是我们令 \(dp_{i, j}\) 表示以 \(i\) 为根的子树内当前选择的块大小为 \(j\) 的概率,转移利用树形背包技巧即可做到 \(O(n ^ 2)\)。
#include<bits/stdc++.h>
using namespace std;
#define rep(i, l, r) for(int i = l; i <= r; ++i)
#define Next(i, u) for(int i = h[u]; i; i = e[i].next)
const int N = 1000 + 5;
const int Mod = 998244353;
struct edge{
int v, dir, next;
}e[N << 1];
int n, u, v, tot, ans, h[N], s[N], f[3 * N], inv[3 * N], a[N][5], dp[N][3 * N];
int read(){
char c; int x = 0, f = 1;
c = getchar();
while(c > '9' || c < '0'){ if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x * f;
}
int Inc(int a, int b){ return (a += b) >= Mod ? a - Mod : a;}
int Dec(int a, int b){ return (a -= b) < 0 ? a + Mod : a;}
int Mul(int a, int b){ return 1ll * a * b % Mod;}
int Qpow(int a, int b){ int ans = 1; for(; b; a = Mul(a, a), b >>= 1) if(b & 1) ans = Mul(ans, a); return ans;}
void add(int u, int v){
e[++tot].v = v, e[tot].dir = 0, e[tot].next = h[u], h[u] = tot;
e[++tot].v = u, e[tot].dir = 1, e[tot].next = h[v], h[v] = tot;
}
void dfs(int u, int fa){
int S = Qpow((a[u][1] + a[u][2] + a[u][3]) % Mod, Mod - 2); s[u] = 3;
dp[u][1] = Mul(a[u][1], S), dp[u][2] = Mul(2, Mul(a[u][2], S)), dp[u][3] = Mul(3, Mul(a[u][3], S));
Next(i, u){
int v = e[i].v; if(v == fa) continue;
dfs(v, u); rep(j, 1, s[u] + s[v]) f[j] = 0;
rep(j, 1, s[u]) if(dp[u][j]){
rep(k, 1, s[v]) if(dp[v][k]){
if(e[i].dir){
f[j + k] = Dec(f[j + k], Mul(dp[u][j], dp[v][k]));
f[j] = Inc(f[j], Mul(dp[u][j], dp[v][k]));
}
else f[j + k] = Inc(f[j + k], Mul(dp[u][j], dp[v][k]));
}
}
s[u] += s[v]; rep(j, 1, s[u]) dp[u][j] = f[j];
}
rep(i, 1, s[u]) dp[u][i] = Mul(dp[u][i], inv[i]);
}
int main(){
n = read();
rep(i, 1, n * 3) inv[i] = Qpow(i, Mod - 2);
rep(i, 1, n) rep(j, 1, 3) a[i][j] = read();
rep(i, 1, n - 1) u = read(), v = read(), add(u, v);
dfs(1, 0);
rep(i, 1, 3 * n) ans = Inc(ans, dp[1][i]);
printf("%d", ans);
return 0;
}