PKUWC2018 minimax

PKUWC2018 minimax

题面描述

一个大小为\(n\)的二叉树,每个叶子结点都有一个互不相同的权值。

每个非叶子结点\(x\)都有一个概率\(p_x\),表示它有\(p_x\)的概率选择它所有儿子权值的最大值,\(1-p_x\)的概率选择它所有儿子权值的最小值。

求出最后根节点取每个权值的概率。

最后把答案以某种方式压缩输出。

答案对\(998244353\)取模。

思路

线段树合并。

如果当前节点为的权值为\(x\),则含\(x\)的子树必须选择\(x\)

要么总体选择最大值,其他子树权值小于\(x\)

要么总体选择最小值,其他子树权值大于\(x\)

维护一个区间和和区间乘法的标记即可。

代码

#include<bits/stdc++.h>
using namespace std;
const int sz=3e5+7;
const int mod=998244353;
int n,m;
int cnt,ans;
int rt[sz];
int f[sz];
int p[sz];
int a[sz];
int inv[sz];
int tr[sz*40],tag[sz*40];
int ls[sz*40],rs[sz*40];
int c[sz][2],t[sz];
void init(){
	inv[1]=1;
	for(int i=2;i<sz;i++)
		inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
}
void update(int &o,int l,int r,int pos,int v){
	if(!o) o=++cnt,tag[o]=1;
	if(l==r) return (void)(tr[o]=v);
	int mid=(l+r)>>1;
	if(pos<=mid) update(ls[o],l,mid,pos,v);
	else update(rs[o],mid+1,r,pos,v);
	tr[o]=(tr[ls[o]]+tr[rs[o]])%mod;
}
void pd(int o){
	if(ls[o]){
		tag[ls[o]]=1ll*tag[ls[o]]*tag[o]%mod;
		tr[ls[o]]=1ll*tr[ls[o]]*tag[o]%mod;
	}
	if(rs[o]){
		tag[rs[o]]=1ll*tag[rs[o]]*tag[o]%mod;
		tr[rs[o]]=1ll*tr[rs[o]]*tag[o]%mod;
	}
	tag[o]=1;
}
int merge(int o1,int o2,int l,int r,int lx1,int rx1,int lx2,int rx2,int x){
	if(!o1&&!o2) return 0;
	int p1=1ll*p[x]*inv[10000]%mod;
	int p2=(mod+1-p1)%mod;
	if(!o1){
		o1=o1^o2;
		int sum=(1ll*lx1*p1%mod+1ll*rx1*p2%mod)%mod;
		tr[o1]=1ll*tr[o1]*sum%mod;
		tag[o1]=1ll*tag[o1]*sum%mod;
		return o1;
	}
	if(!o2){
		o1=o1^o2;
		int sum=(1ll*lx2*p1%mod+1ll*rx2*p2%mod)%mod;
		tr[o1]=1ll*tr[o1]*sum%mod;
		tag[o1]=1ll*tag[o1]*sum%mod;
		return o1;
	}
	if(tag[o1]>1) pd(o1);
	if(tag[o2]>1) pd(o2);
	int mid=(l+r)>>1;
	int suml1=tr[ls[o1]],sumr1=tr[rs[o1]];
	int suml2=tr[ls[o2]],sumr2=tr[rs[o2]];
	ls[o1]=merge(ls[o1],ls[o2],l,mid,lx1,(rx1+sumr1)%mod,lx2,(rx2+sumr2)%mod,x);
	rs[o1]=merge(rs[o1],rs[o2],mid+1,r,(lx1+suml1)%mod,rx1,(lx2+suml2)%mod,rx2,x);
	tr[o1]=(tr[ls[o1]]+tr[rs[o1]])%mod;
	return o1;
}
void dfs(int x){
	if(!t[x]) return (void)(update(rt[x],1,m,p[x],1));
	if(c[x][0]) dfs(c[x][0]);
	if(c[x][1]) dfs(c[x][1]);
	rt[x]=rt[c[x][0]];
	if(t[x]==2) rt[x]=merge(rt[x],rt[c[x][1]],1,m,0,0,0,0,x);
}
void getans(int o,int l,int r){
	if(l==r) return (void)(ans=(ans+1ll*a[l]*l%mod*tr[o]%mod*tr[o]%mod)%mod);
	if(tag[o]>1) pd(o);
	int mid=(l+r)>>1;
	getans(ls[o],l,mid);
	getans(rs[o],mid+1,r);
}
int main(){
	init();
	scanf("%d",&n);
	for(int i=1;i<=n;i++){
		scanf("%d",&f[i]);
		if(i==1) continue;
		c[f[i]][t[f[i]]++]=i;
	}
	for(int i=1;i<=n;i++){
		scanf("%d",&p[i]);
		if(!t[i]) a[++m]=p[i];
	}
	sort(a+1,a+m+1);
	for(int i=1;i<=n;i++){
		if(t[i]) continue;
		p[i]=lower_bound(a+1,a+m+1,p[i])-a;
	}
	dfs(1);
	getans(rt[1],1,m);
	printf("%d\n",ans);
}
posted @ 2019-12-04 17:17  霞光  阅读(253)  评论(0编辑  收藏  举报