loj#2537. 「PKUWC2018」Minimax
题目链接
题解
设 \(f_{u,i}\) 表示节点 u 的权值恰为所有叶子权值中第 i 小者的概率；l 为 u 的左子节点，r 为 u 的右子节点，p 为节点 u 取两个子节点权值中较大者的概率（即输入的 \(p_u/10000\)）。
$f_{u,i} = f_{l,i}\left(p\sum_{j < i} f_{r,j} + (1 - p)\sum_{j > i} f_{r,j}\right) + f_{r,i}\left(p\sum_{j < i} f_{l,j} + (1 - p)\sum_{j > i} f_{l,j}\right)$
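对每个节点维护一棵以叶子权值排名为下标的线段树：叶子 i 存 \(f_{u,i}\)，线段树节点上的 s 存区间内的概率之和。自底向上做线段树合并，合并时顺带累计另一棵树在当前区间之外的前缀/后缀概率和，并用乘法懒标记作用到只剩一侧的子树上。最终答案为 \(\sum_{i=1}^{m} i \cdot V_i \cdot D_i^2\)，其中 \(V_i\) 为第 i 小的叶子权值，\(D_i = f_{1,i}\)。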
代码
#include<bits/stdc++.h>
// Reads the next integer from stdin, skipping any non-digit characters before it.
// A '-' counts as a sign only when it immediately precedes the digits.
// Returns 0 if end of input is reached before any digit.
inline int read() {
    int x = 0, f = 1;
    int c = getchar();  // int, not char: it must be able to hold EOF
    while (c != EOF && (c < '0' || c > '9')) {
        f = (c == '-') ? -1 : 1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * f;
}
// 64-bit integer type used for all modular arithmetic.
using LL = long long;

// Array bound for tree nodes (node ids are 1..n).
constexpr int maxn = 300007;
// Modulus of the answer.
constexpr int mod = 998244353;
// Modular inverse of 10000 modulo mod (10000 * inv == 1 mod mod); converts
// the input percentages p_i into residues p_i / 10000.
constexpr int inv = 796898467;

int a[maxn];         // not referenced by the code shown
int son[maxn][2];    // son[u][0], son[u][1]: children of u, 0 when absent
int fa[maxn];        // not referenced by the code shown
int rt[maxn];        // rt[u]: root of u's segment tree in the node pool
int n = 0;           // number of tree nodes
int m = 0;           // number of leaves, i.e. entries in b[1..m]

// Segment-tree node pool, sized for roughly 20 nodes (one per level) per leaf.
LL s[maxn * 20];     // s[x]: sum of probabilities stored under pool node x
LL tag[maxn * 20];   // tag[x]: pending multiplicative lazy tag of node x
LL w[maxn];          // w[u]: leaf weight, or p_u / 10000 (mod mod) for internal u
LL b[maxn];          // b[1..m]: leaf weights (sorted in main)
LL p;                // probability used by the merge currently running
int lc[maxn * 20];   // left child of a pool node (0 = none)
int rc[maxn * 20];   // right child of a pool node (0 = none)
int tot = 0;         // number of pool nodes allocated so far
// Multiplies every probability under pool node x by t: updates the node's
// sum now and folds t into its lazy tag for the children.
inline void mul(int x, LL t) {
    s[x] = s[x] * t % mod;
    tag[x] = tag[x] * t % mod;
}
void push_down(int x) {
if(tag[x] == 1) return;
mul(lc[x],tag[x]); mul(rc[x],tag[x]);
tag[x] = 1;
}
// Creates (or reuses) the path from root x down to the leaf for rank rk over
// the range [l, r]. Every node on the path gets probability sum 1 and an
// identity tag. Intended for a single point per tree (one leaf of the input).
void insert(int &x, int l, int r, int rk) {
    int *slot = &x;
    for (;;) {
        if (*slot == 0) *slot = ++tot;  // allocate from the pool on demand
        const int node = *slot;
        tag[node] = 1;
        s[node] = 1;
        if (l == r) break;
        const int mid = (l + r) >> 1;
        if (rk <= mid) {
            slot = &lc[node];
            r = mid;
        } else {
            slot = &rc[node];
            l = mid + 1;
        }
    }
}
// Merges two probability segment trees at an internal tree node. The node
// takes the larger child value with probability p (a global that solve sets
// to w[node] just before calling).
//   x    : tree of the left child  (leaf at rank i holds f_{l,i})
//   y    : tree of the right child (leaf at rank i holds f_{r,i})
//   sumx : factor owed to y's values in the current range, built from x's
//          mass outside it: p * (x-mass at smaller ranks)
//                         + (1 - p) * (x-mass at larger ranks)
//   sumy : the same factor for x's values, built from y's mass
// Returns the merged root; node x is reused when both trees are non-empty.
// NOTE(review): assumes leaf weights are distinct, so x and y never both
// reach the same leaf position - confirm against the problem statement.
int merge(int x,int y,LL sumx = 0,LL sumy = 0) {
    // Only one tree has values in this range: scale them lazily by the accumulated factor.
    if(!x) {mul(y,sumx);return y;}
    if(!y) {mul(x,sumy);return x;}
    push_down(x);push_down(y);
    // Read all four child sums before recursing; the recursive calls rescale them.
    LL x0 = s[lc[x]],x1 = s[rc[x]],y0 = s[lc[y]],y1 = s[rc[y]];
    // Left half: the other tree's right half holds larger values, which
    // matter when the node takes the smaller value (weight 1 - p).
    lc[x] = merge(lc[x],lc[y],(sumx + (1 + mod - p) * x1) % mod,(sumy + (1 + mod - p) * y1) % mod);
    // Right half: the other tree's left half holds smaller values, which
    // matter when the node takes the larger value (weight p).
    rc[x] = merge(rc[x],rc[y],(sumx + p * x0) % mod,(sumy + p * y0) % mod);
    s[x] = (s[lc[x]] + s[rc[x]]) % mod;
    return x;
}
int solve(int x) {
if(!son[x][0]) {
insert(rt[x],1,m,std::lower_bound(b + 1,b + m + 1,w[x]) - b);
return rt[x];
}
int rtl = solve(son[x][0]);
if(!son[x][1]) return rtl;
int rtr = solve(son[x][1]);
p = w[x];
return merge(rtl,rtr);
}
// Sums, over every leaf rank l in [l, r] of the tree rooted at x,
// the term l * b[l] * s^2 (mod mod), where s is the probability stored at
// that leaf. Pending lazy tags are pushed down on the way.
LL calc(int x, int l, int r) {
    if (l != r) {
        push_down(x);
        const int mid = (l + r) >> 1;
        const LL left_part = calc(lc[x], l, mid);
        const LL right_part = calc(rc[x], mid + 1, r);
        return (left_part + right_part) % mod;
    }
    const LL prob = s[x];
    return 1ll * l * b[l] % mod * prob % mod * prob % mod;
}
int main() {
n = read();
for(int x,i = 1;i <= n;++ i) {
x = read();
son[x][0] ? son[x][1] = i : son[x][0] = i;
}
for(int i = 1;i <= n;++ i) {
LL x = read();
son[i][0] ? w[i] = x * inv % mod : w[i] = b[++ m] = x;
}
std::sort(b + 1,b + m + 1);
printf("%lld\n",calc(solve(1),1,m)) ;
return 0;
}