Luogu P5298 [PKUWC2018]Minimax
每个非叶节点\(i\)有\(p_i\)的概率取其子节点权值的最大值,有\(1-p_i\)的概率取最小值;叶子节点的权值互不相同。
首先把权值离散化
每个节点开一棵线段树,记录每个权值被取到的概率。
对于节点\(i\),设它的两个子节点为\(ls,rs\),记节点\(i\)取到(离散化后的)值\(j\)的概率为\(f[i][j]\),\(m\)为离散化后的值域大小,则
\[f[i][j] = f[ls][j]\left(p_i\sum_{k=1}^{j-1}f[rs][k] + (1-p_i)\sum_{k=j+1}^{m}f[rs][k]\right) + f[rs][j]\left(p_i\sum_{k=1}^{j-1}f[ls][k] + (1-p_i)\sum_{k=j+1}^{m}f[ls][k]\right)\]
用线段树合并优化,设要合并的两棵树为\(A,B\),
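递归时为两棵树分别维护系数\(S_A,S_B\):\(S_A\)表示\(B\)中落在当前区间之外、且能使\(A\)在当前区间内的值被选中的概率之和(\(S_B\)同理)。
进入左子区间时,\(B\)右子区间的值都更大,只有取最小值时\(A\)的值才会被选中,故\(S_A \gets S_A + sum[B_{rs}]\cdot(1-p)\);进入右子区间时\(S_A \gets S_A + sum[B_{ls}]\cdot p\)。\(S_B\)同理。若某一侧子树为空,就把另一侧整棵子树乘上它的系数。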
注意数组大小和\(long\ long\)
\(code\)
#include<cstdio>
#include<iostream>
#include<cmath>
#include<cstring>
#include<algorithm>
// Unused macro; not referenced anywhere in this file.
#define MogeKo qwq
using namespace std;
// Every `int` from here on is really `long long`, so the product of two residues
// modulo `mod` fits without overflow. This is also why main is declared `signed`.
#define int long long
// Midpoint of the current range [l,r]; `+` binds tighter than `>>`, so this is (l+r)/2.
#define Mid (l+r>>1)
// Bound for per-tree-node arrays (assumes n < maxn; confirm against the problem limits).
const int maxn = 3e5+10;
// Prime modulus used for all probability arithmetic.
const int mod = 998244353;
// n: number of tree nodes; N: number of leaves (= size of the discretized value range);
// ans: final answer; cnt: edges added by add(); tot: segment-tree nodes allocated so far.
// p[i]: p_i converted to a residue for internal nodes, the raw input value for leaves.
// w[leaf]: rank of the leaf's value after sorting; rt[i]: segment-tree root for tree node i.
int n,N,ans,cnt,tot,p[maxn],w[maxn],rt[maxn];
// Segment-tree node pool (index 0 is the null node): subtree probability sum,
// pending multiplicative tag, and left/right child indices.
// Each update() allocates at most one node per level, so 40 per leaf is generous headroom.
int sum[maxn*40],lazy[maxn*40],ls[maxn*40],rs[maxn*40];
// Adjacency lists built by add(); each tree edge is stored in both directions, hence maxn<<1.
int head[maxn],to[maxn<<1],nxt[maxn<<1];
// son[x] is true when x is listed as some node's parent, i.e. x is not a leaf.
bool son[maxn];
// A leaf (id) with its value; ordered by value so sorting yields the discretization.
struct node {
int id,val;
bool operator < (const node &A) const {
return val < A.val;
}
} P[maxn];
// Modular exponentiation: returns a^b mod `mod` by repeated squaring.
// b is expected to be non-negative.
int qpow(int a,int b) {
    int res = 1;
    for(int sq = a; b; b >>= 1) {
        // Multiply in the current power of a when this bit of b is set.
        if(b & 1) res = res * sq % mod;
        sq = sq * sq % mod;
    }
    return res;
}
// Modular inverse of x via Fermat's little theorem (mod is prime): x^(mod-2).
// x must not be a multiple of mod.
int inv(int x) {
    const int exponent = mod - 2;
    return qpow(x, exponent);
}
// Prepend a directed edge x -> y to x's adjacency list.
// Edge ids start at 1; an id of 0 terminates a list.
void add(int x,int y) {
    ++cnt;
    to[cnt] = y;
    nxt[cnt] = head[x];
    head[x] = cnt;
}
// Scale every probability in the subtree at `now` by x: apply the factor to the
// node's sum and fold it into the node's lazy tag. No-op for the null node 0.
void mul(int now,int x) {
    if(now) {
        sum[now] = sum[now] * x % mod;
        lazy[now] = lazy[now] * x % mod;
    }
}
// Forward a pending multiplicative tag from `now` to both children, then reset it to 1.
void pushdown(int now) {
    const int tag = lazy[now];
    if(tag != 1) {
        mul(ls[now],tag);
        mul(rs[now],tag);
        lazy[now] = 1;
    }
}
// Recompute a node's probability sum from its two children (mod `mod`).
void pushup(int now) {
    const int total = sum[ls[now]] + sum[rs[now]];
    sum[now] = total % mod;
}
// Set position x of the range [l,r] under `now` to probability 1, allocating
// nodes from the pool (via tot) along the path when `now` is 0.
// Returns the (possibly newly allocated) node index.
int update(int now,int l,int r,int x) {
    if(!now) now = ++tot;
    lazy[now] = 1;
    if(l != r) {
        const int mid = Mid;
        if(x > mid) rs[now] = update(rs[now],mid+1,r,x);
        else ls[now] = update(ls[now],l,mid,x);
        pushup(now);
    } else {
        sum[now] = 1;
    }
    return now;
}
// Merge the segment trees rooted at a and b (the two children of one tree node) into one.
// Sum_a: factor every value still in a's subtree under this node gets multiplied by,
//        accumulated from b's values outside the current range that let a's value be
//        chosen (smaller b-values weighted by p, larger ones by 1-p).
// Sum_b: the symmetric factor for b's values.
// p: probability (as a residue mod `mod`) that the parent takes the maximum.
// Returns the merged node; a is reused when both sides are non-empty.
int merge(int a,int b,int Sum_a,int Sum_b,int p) {
// Only a has values in this range: each is chosen with factor Sum_a.
if(!b) {
mul(a,Sum_a);
return a;
}
// Only b has values in this range: each is chosen with factor Sum_b.
if(!a) {
mul(b,Sum_b);
return b;
}
// Both non-empty: push pending tags so the child sums read below are current.
pushdown(a), pushdown(b);
// Snapshot child masses BEFORE recursing, since the recursion rewrites the children.
// La/Lb: left-half mass weighted by p (the opponent is smaller -> parent takes max).
int La = (sum[ls[a]] * p) %mod;
int Lb = (sum[ls[b]] * p) %mod;
// Ra/Rb: right-half mass weighted by 1-p (the opponent is larger -> parent takes min).
int Ra = (sum[rs[a]] * (1-p+mod)) %mod;
int Rb = (sum[rs[b]] * (1-p+mod)) %mod;
// Going left, the other tree's right half holds larger values: add Rb / Ra.
ls[a] = merge(ls[a], ls[b], (Sum_a+Rb)%mod, (Sum_b+Ra)%mod, p);
// Going right, the other tree's left half holds smaller values: add Lb / La.
rs[a] = merge(rs[a], rs[b], (Sum_a+Lb)%mod, (Sum_b+La)%mod, p);
pushup(a);
return a;
}
void dfs(int u,int fa) {
if(!son[u]) {
rt[u] = update(rt[u],1,N,w[u]);
return;
}
int ch[2] = {0,0};
for(int i = head[u]; i; i = nxt[i]) {
int v = to[i];
if(v == fa) continue;
dfs(v,u);
ch[1] = ch[0],ch[0] = v;
}
int O = 0;
if(ch[1])
rt[u] = merge(rt[ch[0]],rt[ch[1]],O,O,p[u]);
else if(ch[0])
rt[u] = rt[ch[0]];
}
// Accumulate the answer from the finished root tree: for each rank l in [l,r],
// add l * (value of the l-th sorted leaf) * sum[leaf]^2 into ans, modulo `mod`.
// now: node covering [l,r]; 0 denotes an empty subtree.
void calc(int now,int l,int r) {
    // An empty subtree holds only zero probabilities, so it contributes nothing.
    // Returning here avoids walking it and pushing down through the shared null node 0.
    if(!now) return;
    if(l == r) {
        (ans += l * P[l].val % mod * sum[now] % mod * sum[now] % mod) %= mod;
        return;
    }
    // Resolve pending tags so the leaf sums read below are final.
    pushdown(now);
    int mid = Mid;
    calc(ls[now],l,mid);
    calc(rs[now],mid+1,r);
}
signed main() {
    scanf("%lld",&n);
    // First list: the parent of each node i; a parent of 0 means node i has none.
    for(int i = 1; i <= n; i++) {
        int x;
        scanf("%lld",&x);
        if(x != 0) {
            add(i,x);
            add(x,i);
            son[x] = true;
        }
    }
    // Second list: internal nodes give p_i in units of 1/10000 (converted to a residue),
    // leaves give their value (collected for discretization).
    const int INV = inv(10000);
    for(int i = 1; i <= n; i++) {
        scanf("%lld",&p[i]);
        if(!son[i]) {
            ++N;
            P[N].id = i;
            P[N].val = p[i];
        } else {
            p[i] = p[i] * INV % mod;
        }
    }
    // Discretize: w[leaf] becomes the 1-based rank of the leaf's value.
    sort(P+1,P+N+1);
    for(int i = 1; i <= N; i++) w[P[i].id] = i;
    dfs(1,0);
    calc(rt[1],1,N);
    printf("%lld",ans);
    return 0;
}