loj#2537. 「PKUWC2018」Minimax

题目链接

loj#2537. 「PKUWC2018」Minimax

题解

$f_{u,i}$ 表示节点 $u$ 的权值等于第 $i$ 小叶子权值的概率，$l$ 为 $u$ 的左子节点，$r$ 为 $u$ 的右子节点，$p_u$ 为 $u$ 取子节点最大值的概率（取最小值的概率为 $1-p_u$）。
$$f_{u,i} = f_{l,i}\left(p_u\sum_{j<i} f_{r,j} + (1-p_u)\sum_{j>i} f_{r,j}\right) + f_{r,i}\left(p_u\sum_{j<i} f_{l,j} + (1-p_u)\sum_{j>i} f_{l,j}\right)$$
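对每个节点用一棵以权值排名为下标的线段树维护 $f_{u,\cdot}$（线段树节点上的 s 为区间内概率之和）。若 $u$ 只有一个子节点，则直接继承其线段树；否则进行线段树合并。递归时累加另一棵树在当前区间左侧、右侧的概率和；当某棵树在当前区间为空时，另一棵树在该区间内的所有 $f$ 都乘以同一个系数，打乘法懒标记即可。最终答案为 $\sum_i i \cdot V_i \cdot f_{1,i}^2$，其中 $V_i$ 为第 $i$ 小的叶子权值。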

代码

#include<bits/stdc++.h> 
 
// Read the next integer from stdin, skipping any non-digit characters before it.
// A '-' immediately preceding the digits makes the result negative.
// Returns 0 if end of input is reached before any digit. The character is held
// in an int so that EOF can be told apart from a real byte; the previous
// `char` + "skip while not a digit" loop spun forever at end of input.
inline int read() {
	int x = 0, f = 1;
	int c = getchar();
	while (c != EOF && (c < '0' || c > '9')) {
		// Only the character right before the number decides its sign.
		f = (c == '-') ? -1 : 1;
		c = getchar();
	}
	while (c >= '0' && c <= '9') {
		x = x * 10 + (c - '0');
		c = getchar();
	}
	return x * f;
}
using LL = long long;
// Upper bound used to size every per-node array (node ids are < maxn).
constexpr int maxn = 300007;
// Prime modulus for all probability arithmetic.
constexpr int mod = 998244353;
// Modular inverse of 10000: 10000 * inv ≡ 1 (mod mod); turns an integer
// percentage-like input p into p / 10000 in modular form.
constexpr int inv = 796898467;

int a[maxn];                // not referenced by the visible code
int son[maxn][2];           // son[u][0], son[u][1]: children of u (0 = none)
int fa[maxn];               // not referenced by the visible code
int rt[maxn];               // rt[u]: segment-tree root created by insert for leaf u
int n = 0;                  // number of tree nodes
int m = 0;                  // number of leaves (entries in b)
LL s[maxn * 20];            // s[v]: probability sum stored in segment-tree node v
LL tag[maxn * 20];          // tag[v]: pending multiplicative lazy tag of node v
LL w[maxn];                 // leaf: its weight; internal node: probability in modular form
LL b[maxn];                 // b[1..m]: leaf weights, sorted in main
LL p;                       // probability of the node currently being merged (read by merge)
int lc[maxn * 20];          // left child of segment-tree node
int rc[maxn * 20];          // right child of segment-tree node
int tot = 0;                // number of allocated segment-tree nodes

// Multiply every probability under segment-tree node x by t (mod `mod`):
// the node's own sum is scaled now, and the factor is folded into its lazy tag
// so the children are scaled later by push_down.
inline void mul(int x,LL t) {
	s[x] = s[x] * t % mod;
	tag[x] = tag[x] * t % mod;
}

// Propagate node x's pending multiplicative tag to both children and reset it
// to the neutral value 1. Nothing happens when no tag is pending.
void push_down(int x) {
	if (tag[x] != 1) {
		mul(lc[x], tag[x]);
		mul(rc[x], tag[x]);
		tag[x] = 1;
	}
}

// Insert rank `rk` into the segment tree rooted at x, which covers ranks [l, r].
// Every node on the root-to-leaf path gets probability sum 1 and a neutral lazy
// tag. Missing nodes are allocated from the pool (++tot). When the tree was
// empty, x is updated through the reference.
void insert(int &x,int l,int r,int rk) {
	int *slot = &x;
	for (;;) {
		if (*slot == 0) {
			*slot = ++tot;
		}
		const int node = *slot;
		tag[node] = 1;
		s[node] = 1;
		if (l == r) {
			break;
		}
		const int mid = (l + r) / 2;
		if (rk <= mid) {
			slot = &lc[node];
			r = mid;
		} else {
			slot = &rc[node];
			l = mid + 1;
		}
	}
}
// Merge segment tree x (the caller passes the left child's distribution) with
// segment tree y (the right child's) into the parent's distribution. Existing
// nodes are reused; x's node is kept wherever both trees have one.
// The global p is the parent's probability of taking the max of its children
// (1 - p for the min); solve sets it right before calling.
//   sumx: factor for every value in y's current range, accumulated from x's
//         mass outside the range:
//         p * (x's mass to the left) + (1 - p) * (x's mass to the right).
//   sumy: the symmetric factor for x's values, accumulated from y's mass.
// Both default to 0 at the root call, since no mass lies outside the full range.
// Returns the root of the merged tree.
// NOTE(review): relies on leaf weights being distinct, so x and y are never
// both non-empty at the same leaf; confirm against the problem constraints.
int merge(int x,int y,LL sumx = 0,LL sumy = 0) {
	// Only y has values in this range. x has no mass inside it, so all of y's
	// values share the factor sumx: apply it lazily.
	if(!x) {mul(y,sumx);return y;} 
	// Symmetric case: only x has values in this range.
	if(!y) {mul(x,sumy);return x;} 
	push_down(x);push_down(y); 
	// Child sums are cached because lc[x]/rc[x] are overwritten below.
	LL x0 = s[lc[x]],x1 = s[rc[x]],y0 = s[lc[y]],y1 = s[rc[y]]; 
	// Left half: the other tree's right-half mass is larger than every value
	// here, so those values win when the min is taken -> weight (1 - p).
	lc[x] = merge(lc[x],lc[y],(sumx + (1 + mod - p) * x1) % mod,(sumy + (1 + mod - p) * y1) % mod); 
	// Right half: the other tree's left-half mass is smaller, so those values
	// win when the max is taken -> weight p.
	rc[x] = merge(rc[x],rc[y],(sumx + p * x0) % mod,(sumy + p * y0) % mod); 
	s[x] = (s[lc[x]] + s[rc[x]]) % mod;  
	return x; 
} 
int solve(int x) { 
	if(!son[x][0]) { 
		insert(rt[x],1,m,std::lower_bound(b + 1,b + m + 1,w[x]) - b); 
		return rt[x]; 
	} 
	int rtl = solve(son[x][0]); 
	if(!son[x][1]) return rtl; 
	int rtr = solve(son[x][1]); 
	p = w[x]; 
	return merge(rtl,rtr); 
} 
// Sum, over every rank l in [l, r] of tree x, the term l * b[l] * P(l)^2 (mod
// `mod`), where P(l) is the probability stored at the leaf for rank l.
// Pending lazy tags are pushed down on the way to the leaves.
LL calc(int x,int l,int r) {
	if (l != r) {
		push_down(x);
		const int mid = (l + r) >> 1;
		const LL left_part = calc(lc[x], l, mid);
		const LL right_part = calc(rc[x], mid + 1, r);
		return (left_part + right_part) % mod;
	}
	const LL prob = s[x];
	return 1ll * l * b[l] % mod * prob % mod * prob % mod;
}
int main() { 
	n = read();  
	for(int x,i = 1;i <= n;++ i) { 
		x = read(); 
		son[x][0] ? son[x][1] = i : son[x][0] = i; 
	} 
	for(int i = 1;i <= n;++ i)  { 
		 LL x = read(); 
		 son[i][0] ? w[i] = x * inv % mod : w[i] = b[++ m] = x; 
	} 
	std::sort(b + 1,b + m + 1); 
	printf("%lld\n",calc(solve(1),1,m)) ;  
	return 0; 
} 
posted @ 2018-07-22 20:33  zzzzx  阅读(315)  评论(0编辑  收藏  举报