PKUWC2018 minimax
PKUWC2018 minimax
题面描述
一个大小为\(n\)的二叉树,每个叶子结点都有一个互不相同的权值。
每个非叶子结点\(x\)都有一个概率\(p_x\),表示它有\(p_x\)的概率选择它所有儿子权值的最大值,\(1-p_x\)的概率选择它所有儿子权值的最小值。
求出最后根节点取每个权值的概率。
最后把答案以某种方式压缩输出。
答案对\(998244353\)取模。
思路
线段树合并。
如果当前节点为的权值为\(x\),则含\(x\)的子树必须选择\(x\)
要么总体选择最大值,其他子树权值小于\(x\)
要么总体选择最小值,其他子树权值大于\(x\)
维护一个区间和和区间乘法的标记即可。
代码
#include<bits/stdc++.h>
using namespace std;
const int sz=3e5+7;
const int mod=998244353;
int n,m;
int cnt,ans;
int rt[sz];
int f[sz];
int p[sz];
int a[sz];
int inv[sz];
int tr[sz*40],tag[sz*40];
int ls[sz*40],rs[sz*40];
int c[sz][2],t[sz];
void init(){
inv[1]=1;
for(int i=2;i<sz;i++)
inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
}
void update(int &o,int l,int r,int pos,int v){
if(!o) o=++cnt,tag[o]=1;
if(l==r) return (void)(tr[o]=v);
int mid=(l+r)>>1;
if(pos<=mid) update(ls[o],l,mid,pos,v);
else update(rs[o],mid+1,r,pos,v);
tr[o]=(tr[ls[o]]+tr[rs[o]])%mod;
}
void pd(int o){
if(ls[o]){
tag[ls[o]]=1ll*tag[ls[o]]*tag[o]%mod;
tr[ls[o]]=1ll*tr[ls[o]]*tag[o]%mod;
}
if(rs[o]){
tag[rs[o]]=1ll*tag[rs[o]]*tag[o]%mod;
tr[rs[o]]=1ll*tr[rs[o]]*tag[o]%mod;
}
tag[o]=1;
}
int merge(int o1,int o2,int l,int r,int lx1,int rx1,int lx2,int rx2,int x){
if(!o1&&!o2) return 0;
int p1=1ll*p[x]*inv[10000]%mod;
int p2=(mod+1-p1)%mod;
if(!o1){
o1=o1^o2;
int sum=(1ll*lx1*p1%mod+1ll*rx1*p2%mod)%mod;
tr[o1]=1ll*tr[o1]*sum%mod;
tag[o1]=1ll*tag[o1]*sum%mod;
return o1;
}
if(!o2){
o1=o1^o2;
int sum=(1ll*lx2*p1%mod+1ll*rx2*p2%mod)%mod;
tr[o1]=1ll*tr[o1]*sum%mod;
tag[o1]=1ll*tag[o1]*sum%mod;
return o1;
}
if(tag[o1]>1) pd(o1);
if(tag[o2]>1) pd(o2);
int mid=(l+r)>>1;
int suml1=tr[ls[o1]],sumr1=tr[rs[o1]];
int suml2=tr[ls[o2]],sumr2=tr[rs[o2]];
ls[o1]=merge(ls[o1],ls[o2],l,mid,lx1,(rx1+sumr1)%mod,lx2,(rx2+sumr2)%mod,x);
rs[o1]=merge(rs[o1],rs[o2],mid+1,r,(lx1+suml1)%mod,rx1,(lx2+suml2)%mod,rx2,x);
tr[o1]=(tr[ls[o1]]+tr[rs[o1]])%mod;
return o1;
}
void dfs(int x){
if(!t[x]) return (void)(update(rt[x],1,m,p[x],1));
if(c[x][0]) dfs(c[x][0]);
if(c[x][1]) dfs(c[x][1]);
rt[x]=rt[c[x][0]];
if(t[x]==2) rt[x]=merge(rt[x],rt[c[x][1]],1,m,0,0,0,0,x);
}
void getans(int o,int l,int r){
if(l==r) return (void)(ans=(ans+1ll*a[l]*l%mod*tr[o]%mod*tr[o]%mod)%mod);
if(tag[o]>1) pd(o);
int mid=(l+r)>>1;
getans(ls[o],l,mid);
getans(rs[o],mid+1,r);
}
int main(){
init();
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d",&f[i]);
if(i==1) continue;
c[f[i]][t[f[i]]++]=i;
}
for(int i=1;i<=n;i++){
scanf("%d",&p[i]);
if(!t[i]) a[++m]=p[i];
}
sort(a+1,a+m+1);
for(int i=1;i<=n;i++){
if(t[i]) continue;
p[i]=lower_bound(a+1,a+m+1,p[i])-a;
}
dfs(1);
getans(rt[1],1,m);
printf("%d\n",ans);
}
以仁心说,以学心听,以公心辩