XMUPC2022 H Little M with Glass Balls 【DP,虚树】

传送门

思路

一个简单的观察是,在原树中深度不同的球是独立的,也就是说我们可以将原树按照深度拆分成 \(O(n)\) 颗树,其中第 \(i\) 颗树所有叶子深度均为 \(i\)。根据期望的线性性我们可以知道,对于这 \(O(n)\) 颗树分别求答案后相加就是原问题的答案。

接下来我们考虑对于转化后的问题我们如何求答案。有两种思考的角度:计算每个球对答案的贡献/计算每条边对答案的贡献,这里选用的是第二种,不知道第一种有没有可行的办法。考虑每条边对答案的贡献,对于一条边 \((u,fa_u)\),如果我们能求出 \(f_u\) 表示经过若干秒后 \(u\) 有球的概率,\(g_u\) 表示某个球从 \(u\) 开始往上走若干步后被回收的概率,那么 \((u,fa_u)\) 对答案的贡献为 \(f_u \times (1 - p_u) \times g_{fa_u}\)。现在问题就变成了如何求出 \(f_u\)\(g_u\)

考虑 DP,设 \(f_{u,0/1}\) 表示经过若干秒后,\(u\) 子树的所有球都寄了(碰撞或是被回收),此时点 \(u\) 有/没有球的概率,有转移:

\[f_{u,1} = \sum_{v \in son_u} f_{v,1} \times (1 - p_v) \times \prod_{k \in son_u} (f_{k,0} + f_{k,1} \times p_k) \]

\[f_{u,0} = 1 - f_{u,1} \]

对第一个转移的解释:在 \(f_{u,1}\) 的转移中,我们枚举了这个球从哪个子树中来,这要求其他子树的球都不能到达 \(u\)。这有两种情况:球根本没到点 \(k\),或者球到了点 \(k\),但是被回收了。预处理前后缀积可以 \(O(1)\) 转移。

对于 \(g_u\) 的转移,我们先求出 \(up_u\) 表示在点 \(u\) 的球成功到达 \(fa_u\) 的概率,那么有转移:

\[up_u = (1 - p_u) \times \prod_{v \in son_u} (f_{v,0} + f_{v,1} \times p_v) \]

然后你会发现这个东西我们在求 \(f\) 的时候已经求过一遍了,只需要在求 \(f\) 的时候把值顺便传给儿子的 \(up\) 即可。对于 \(g\) 我们采用从上到下的转移方式:

\[g_u = p_u + up_u \times g_{fa_u} \]

对这个转移的解释:有两种情况,要么在点 \(u\) 就被回收,要么成功到达点 \(fa_u\),之后就变成了 \(fa_u\) 的子问题。于是我们成功在 \(O(n)\) 的时间复杂度内求出了 \(f\)\(g\),但由于我们需要对 \(O(n)\) 颗树都求一遍答案,于是时间复杂度就变成了 \(O(n^2)\)

但观察 DP 式子,容易发现每次 DP 时有很多地方转移的系数是完全一致的,更具体地说,这些球只可能在很少的点发生碰撞。于是考虑建出虚树,容易发现只有在虚树上的点可能发生碰撞,而对于被缩起来的链我们需要查询 \(coef_{0/1}\) 表示在这条链上被/不被回收的概率,这可以 DP 出前缀答案,然后通过类似树上差分的方式求出。

然后是计算贡献的部分,这部分比较繁琐一些。考虑一条被缩起来的链 \(v \to u\),它对答案的贡献可以分为两部分:球从 \(v\) 成功到达 \(u\) 和球在 \(v \to u\) 的过程中被回收的贡献。第一部分的贡献显然是 \(len_{v \to u} \times f_{v,1} \times up_{v \to u, 1} \times g_u\),第二部分的贡献同样可以 DP 出前缀答案然后用类似差分的方式求出,具体做法留做习题(也可以参照代码理解)。这样每次 DP 的复杂度只跟虚树的点数有关。

分析一下复杂度:由于这 \(O(n)\) 颗树中的球数之和为 \(n\),那么虚树的总点数就是 \(O(n)\) 的,于是 DP 的时间复杂度也是 \(O(n)\) 的。于是总时间复杂度为 \(O(n \log n)\),瓶颈仅在于建虚树。

一点恶心的细节:上述做法需要查询链积,由于 \(p_i\)\(1 - p_i\) 都可能是 \(0\),于是我们不能直接求逆元。代码里用的处理方法是:将所有值为 \(0\) 的点标记为关键点,这些关键点将原树分成了若干连通块。我们不再处理前缀积,而是改为处理所在连通块的前缀积,查询的时候先差分求出链上是否存在关键点,如果存在那么积为 \(0\),否则可以直接计算。不过这还需要对 \(O(n)\) 个数求逆元,但大家应该都会 \(O(n + \log p)\),所以就不说了。

一点想法:实际上有一个性质,对于原树上的每个点,其至多进出虚树一次,不知道根据这个性质能不能得到 \(O(n)\) 的做法。我目前还没有什么发现,先不管了。

代码

/*
也许所有的执念 就像四季的更迭
没有因缘 不需致歉
是否拥抱着告别 就更能读懂人间
还是感慨 更多一点
*/
#include <bits/stdc++.h>
#define pii pair<int, int>
#define mp(x, y) make_pair(x, y)
#define pb push_back
#define fi first
#define se second
#define int long long
#define mem(x, v, n) memset(x, v, sizeof(int) * (n))
#define mcpy(x, y, n) memcpy(x, y, sizeof(int) * (n))
#define lob lower_bound
#define upb upper_bound
using namespace std;

inline int read() {
	int x = 0, w = 1;char ch = getchar();
	while (ch > '9' || ch < '0') { if (ch == '-')w = -1;ch = getchar(); }
	while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar();
	return x * w;
}

inline int min(int x, int y) { return x < y ? x : y; }
inline int max(int x, int y) { return x > y ? x : y; }

const int MN = 1e6 + 5;
const int Mod = 998244353;
const int Inf = 1e9;

inline void Add(int &x, int y) { x += y; if (x >= Mod) x -= Mod; }
inline void Dec(int &x, int y) { x -= y; if (x < 0) x += Mod; }

inline int qPow(int a, int b = Mod - 2, int ret = 1) {
    while (b) {
        if (b & 1) ret = ret * a % Mod;
        a = a * a % Mod, b >>= 1;
    }
    return ret;
}

// #define dbg

int N, Ans, p[MN], ip[MN], fa[MN];
vector <int> G[MN];

int dep[MN], pa[MN][25], preP[MN], preiP[MN], cP[MN], prP[MN], priP[MN], ciP[MN], sigP[MN], valP[MN], mxd[MN];
int dfn[MN], dfc;
vector <int> vr[MN];
inline void DFS0(int u, int pr) {
    cP[u] = cP[pr];
    if (p[u]) preP[u] = preP[pr] * p[u] % Mod;
    else  preP[u] = 1, cP[u]++;

    ciP[u] = ciP[pr];
    if (ip[u]) prP[u] = prP[pr] * ip[u] % Mod;
    else prP[u] = 1, ciP[u]++;

    sigP[u] = (sigP[pr] * ip[u] % Mod + p[u]) % Mod;
    valP[u] = ((valP[pr] + sigP[pr]) * ip[u] % Mod + p[u]) % Mod;
    
    dfn[u] = ++dfc;
    dep[u] = dep[pa[u][0] = pr] + 1, mxd[u] = dep[u];
    vr[dep[u]].pb(u);
    for (int i = 1; i <= 20; i++)
        pa[u][i] = pa[pa[u][i - 1]][i - 1];
    for (int v : G[u])
        DFS0(v, u), mxd[u] = max(mxd[u], mxd[v]);
}
inline int LCA(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y);
    for (int i = 20; i >= 0; i--)
        if (dep[pa[x][i]] >= dep[y]) x = pa[x][i];
    if (x == y) return x;
    for (int i = 20; i >= 0; i--)
        if (pa[x][i] != pa[y][i]) {
            x = pa[x][i];
            y = pa[y][i];
        }
    return pa[x][0];
}

int stk[MN], tp;
inline bool cmp(int x, int y) {
    return dfn[x] < dfn[y];
}
vector <int> T[MN];
int f[MN][2], g[MN], up[MN], Coef[MN][2], suf[MN], pre[MN], par[MN];
inline int prdiP(int x, int y) {
    if (ciP[y] - ciP[x] > 0) return 0;
    if (ciP[x] > ciP[fa[x]]) return prP[y];
    return prP[y] * priP[x] % Mod;
}
inline int prdP(int x, int y) {
    if (cP[y] - cP[x] > 0) return 0;
    if (cP[x] > cP[fa[x]]) return preP[y];
    return preP[y] * preiP[x] % Mod;
}
inline int QryP(int x, int y) {
    int l = (dep[y] - dep[x]);
    return (valP[y] - (valP[x] + l * sigP[x] % Mod) * prdiP(x, y) % Mod + Mod) % Mod;
}
inline void Build(int d) {
    tp = 0;
    for (int u : vr[d]) stk[++tp] = u;
    stk[++tp] = 1;
    sort(stk + 1, stk + tp + 1, cmp);
    for (int i = 2; i <= tp; i++) {
        int r = LCA(stk[i], stk[i - 1]);
        if (r != stk[i - 1] && r != stk[i]) stk[++tp] = r;
    }
    sort(stk + 1, stk + tp + 1);
    tp = unique(stk + 1, stk + tp + 1) - stk - 1;
    sort(stk + 1, stk + tp + 1, cmp);
    for (int i = 2; i <= tp; i++) {
        int r = LCA(stk[i], stk[i - 1]);
        T[r].pb(stk[i]);

        par[stk[i]] = r;
        Coef[stk[i]][1] = prdiP(r, stk[i]);
        Coef[stk[i]][0] = (1 - Coef[stk[i]][1] + Mod) % Mod;
#ifdef dbg
        printf("%lld : Coef0 = %lld, Coef1 = %lld\n", stk[i], Coef[stk[i]][0], Coef[stk[i]][1]);
#endif
    }
}
inline void DFS1(int u) {
    int c = T[u].size();
    if (!c) {
        f[u][1] = 1, f[u][0] = 0;
        return;
    }
    for (int v : T[u]) DFS1(v);
    pre[0] = 1, suf[c + 1] = 1;
    for (int i = 0; i < c; i++) {
        int v = T[u][i];
        pre[i + 1] = pre[i] * (f[v][0] + f[v][1] * Coef[v][0] % Mod) % Mod;
    }
    for (int i = c - 1; i >= 0; i--) {
        int v = T[u][i];
        suf[i + 1] = suf[i + 2] * (f[v][0] + f[v][1] * Coef[v][0] % Mod) % Mod;
    }
    for (int i = 0; i < c; i++) {
        int v = T[u][i];
        up[v] = pre[i] * Coef[v][1] % Mod * suf[i + 2] % Mod;
        f[u][1] = (f[u][1] + up[v] * f[v][1] % Mod) % Mod;
    }
    f[u][0] = (1 - f[u][1] + Mod) % Mod;
}
inline void DFS2(int u) {
    for (int v : T[u]) 
        g[v] = (g[u] * up[v] % Mod + Coef[v][0]) % Mod, DFS2(v);
}
inline void Work(int d) {   
    Build(d);
    DFS1(1), g[1] = 1, DFS2(1);
    for (int i = 2; i <= tp; i++) {
        int u = stk[i];
        Ans = (Ans + f[u][1] * up[u] % Mod * g[par[u]] % Mod * (dep[u] - dep[par[u]]) % Mod) % Mod;
        if (dep[u] - dep[par[u]] > 1)
            Ans = (Ans + f[u][1] * ip[u] % Mod * QryP(par[u], fa[u]) % Mod) % Mod;
    }
    for (int i = 1; i <= tp; i++)
        T[stk[i]].clear(),
        f[stk[i]][0] = f[stk[i]][1] = g[stk[i]] = up[stk[i]] = 0,
        Coef[stk[i]][0] = Coef[stk[i]][1] = 0;
}
int prod;
inline void getInv() {
    preiP[0] = 1;
    pre[0] = suf[N + 1] = 1;
    for (int i = 1; i <= N; i++) pre[i] = pre[i - 1] * preP[i] % Mod;
    for (int i = N; i >= 1; i--) suf[i] = suf[i + 1] * preP[i] % Mod;
    prod = qPow(pre[N], Mod - 2);
    for (int i = 1; i <= N; i++) 
        preiP[i] = prod * pre[i - 1] % Mod * suf[i + 1] % Mod;
    
    priP[0] = 1;
    pre[0] = suf[N + 1] = 1;
    for (int i = 1; i <= N; i++) pre[i] = pre[i - 1] * prP[i] % Mod;
    for (int i = N; i >= 1; i--) suf[i] = suf[i + 1] * prP[i] % Mod;
    prod = qPow(pre[N], Mod - 2);
    for (int i = 1; i <= N; i++) 
        priP[i] = prod * pre[i - 1] % Mod * suf[i + 1] % Mod;
}

signed main(void) {
    N = read();
    for (int i = 1; i <= N; i++) p[i] = read(), ip[i] = (1 - p[i] + Mod) % Mod;
    for (int i = 2; i <= N; i++) fa[i] = read(), G[fa[i]].pb(i);
    DFS0(1, 0), getInv();
    for (int i = 2; i <= mxd[1]; i++) Work(i);
    printf("%lld\n", Ans); 
    return 0; 
}
posted @ 2022-07-07 15:37  came11ia  阅读(40)  评论(0编辑  收藏  举报