/* 返回顶部 */

Luogu P5298 [PKUWC2018]Minimax

gate

每个非叶节点 \(i\) 以 \(p_i\) 的概率取其子节点权值中的最大值,以 \(1-p_i\) 的概率取最小值。

首先把叶子的权值离散化(题目保证权值互不相同)。
每个节点开一棵线段树,记录每个权值被取到的概率。
对于节点 \(i\),设其两个子节点为 \(ls,rs\),记节点 \(i\) 取到第 \(j\) 小权值的概率为 \(f[i][j]\)(\(m\) 为不同权值的个数),则:
\(f[i][j] = f[ls][j]\left(p_i\sum\limits_{k=1}^{j-1}f[rs][k] + (1-p_i)\sum\limits_{k=j+1}^{m}f[rs][k]\right) + f[rs][j]\left(p_i\sum\limits_{k=1}^{j-1}f[ls][k] + (1-p_i)\sum\limits_{k=j+1}^{m}f[ls][k]\right)\)。若只有一个子节点 \(c\),则 \(f[i][j]=f[c][j]\)。

用线段树合并来优化转移:设要合并的两棵线段树为 \(A,B\),合并时自顶向下维护每棵树要乘上的系数。

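递归进入 \(A\) 的左子树时,它的系数 \(=\) 当前累计的系数 \(+\) \(B\) 的右子树的概率和 \(\times(1-p)\)(\(B\) 右子树中的值都更大,取最小值时选中 \(A\) 的值);递归进入右子树时,则加上 \(B\) 的左子树的概率和 \(\times p\)。\(B\) 的系数同理。当某一方为空时,把另一方整棵子树乘上其累计系数即可。

\(mul_A(ls) = mul_A + sum[B_{rs}]\cdot(1-p)\),\(mul_A(rs) = mul_A + sum[B_{ls}]\cdot p\)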

注意线段树节点池的数组大小,以及取模乘法需要使用 \(long\ long\) 以防溢出。

\(code\)

#include<cstdio>
#include<iostream>
#include<cmath>
#include<cstring>
#include<algorithm>
#define MogeKo qwq
using namespace std;

// Every `int` below is really a 64-bit `long long`, so products of two
// residues below `mod` do not overflow; this is why `main` is declared `signed`.
#define int long long
// Midpoint of the current segment; expects locals named `l` and `r`.
#define Mid (l+r>>1)

const int maxn = 3e5+10;    // capacity of node-indexed arrays
const int mod = 998244353;

// n: number of tree nodes; N: number of leaves (distinct weight ranks)
// ans: final answer; cnt: edge counter; tot: segment-tree node counter
// p[i]: for a leaf, its weight; for an internal node, p_i * inv(10000) mod `mod`
// w[i]: rank of leaf i's weight; rt[i]: root of node i's segment tree
int n,N,ans,cnt,tot,p[maxn],w[maxn],rt[maxn];
// Segment-tree node pool; index 0 acts as the null node.
// sum: probability mass of the subtree, lazy: pending multiplier, ls/rs: children.
int sum[maxn*40],lazy[maxn*40],ls[maxn*40],rs[maxn*40];
// Adjacency list; edges are stored in both directions, hence maxn<<1.
int head[maxn],to[maxn<<1],nxt[maxn<<1];
// son[i]: true when node i has at least one child.
bool son[maxn];

// A leaf paired with its weight. Ordering looks at the weight only, so
// sorting P ranks the leaves by weight (used for discretization).
struct node {
	int id;   // tree-node index of the leaf
	int val;  // weight of the leaf
	bool operator < (const node &other) const {
		return other.val > val;
	}
} P[maxn];

// Binary exponentiation: returns a^b modulo `mod`.
int qpow(int a,int b) {
	int res = 1;
	for(int cur = a; b; b >>= 1) {
		if(b & 1) res = res * cur % mod;
		cur = cur * cur % mod;
	}
	return res;
}

// Modular inverse of x via Fermat's little theorem (998244353 is prime).
int inv(int x) {
	const int exponent = mod - 2;
	return qpow(x,exponent);
}

// Append the directed edge x -> y to the forward-star adjacency list.
void add(int x,int y) {
	const int e = ++cnt;
	nxt[e] = head[x];
	to[e] = y;
	head[x] = e;
}

// Lazily multiply every probability in the subtree rooted at `now` by x:
// only this node's sum and pending tag change. The null node 0 is ignored.
void mul(int now,int x) {
	if(now) {
		sum[now] = sum[now] * x % mod;
		lazy[now] = lazy[now] * x % mod;
	}
}

// Push a pending multiplier (if not the identity 1) onto both children.
void pushdown(int now) {
	const int tag = lazy[now];
	if(tag != 1) {
		mul(ls[now],tag);
		mul(rs[now],tag);
		lazy[now] = 1;
	}
}

// Recompute a node's probability mass from its two children.
void pushup(int now) {
	const int total = sum[ls[now]] + sum[rs[now]];
	sum[now] = total % mod;
}

// Put probability 1 on rank x in the tree rooted at `now` over [l, r].
// Nodes are allocated from the pool when missing, and every node on the path
// gets the neutral lazy tag 1. Returns the (possibly new) root.
int update(int now,int l,int r,int x) {
	now = now ? now : ++tot;
	lazy[now] = 1;
	if(l != r) {
		const int mid = Mid;
		if(x > mid) rs[now] = update(rs[now],mid+1,r,x);
		else ls[now] = update(ls[now],l,mid,x);
		pushup(now);
	} else {
		sum[now] = 1;
	}
	return now;
}

// Merge the segment trees of the two children of an internal node that takes
// the maximum with probability p and the minimum with probability 1-p.
// Sum_a: multiplier accumulated on the way down for values stored in `a`:
//   p * (mass of b's values left of the current range)
//   + (1-p) * (mass of b's values right of the current range).
// Sum_b: the symmetric multiplier for values stored in `b`.
// The top-level call (dfs) passes 0 for both. Returns the merged root;
// the nodes of `a` are reused in place.
int merge(int a,int b,int Sum_a,int Sum_b,int p) {
	// b holds nothing in this range: scale a's whole subtree by its multiplier.
	if(!b) {
		mul(a,Sum_a);
		return a;
	}
	// Symmetric case: a holds nothing in this range.
	if(!a) {
		mul(b,Sum_b);
		return b;
	}
	pushdown(a), pushdown(b);
	// Weighted masses of each child subtree: left halves are weighted by p
	// (max taken), right halves by 1-p (min taken; +mod keeps it non-negative).
	int La = (sum[ls[a]] * p) %mod;
	int Lb = (sum[ls[b]] * p) %mod;
	int Ra = (sum[rs[a]] * (1-p+mod)) %mod;
	int Rb = (sum[rs[b]] * (1-p+mod)) %mod;
	// Going left: the other tree's right child holds only larger values, so a
	// value here is chosen over them exactly when the min is taken.
	ls[a] = merge(ls[a], ls[b], (Sum_a+Rb)%mod, (Sum_b+Ra)%mod, p);
	// Going right: the other tree's left child holds only smaller values, so a
	// value here is chosen over them exactly when the max is taken.
	rs[a] = merge(rs[a], rs[b], (Sum_a+Lb)%mod, (Sum_b+La)%mod, p);
	pushup(a);
	return a;
}

// Build rt[u], the segment tree of rank probabilities at node u.
// A leaf gets a point mass on its own rank. An internal node reuses its only
// child's tree, or merges the two most recently visited children with p[u]
// (assumes at most two children).
void dfs(int u,int fa) {
	if(son[u]) {
		int last = 0, prev = 0;   // the two most recently visited children
		for(int e = head[u]; e; e = nxt[e]) {
			const int v = to[e];
			if(v == fa) continue;
			dfs(v,u);
			prev = last;
			last = v;
		}
		if(prev) rt[u] = merge(rt[last],rt[prev],0,0,p[u]);
		else if(last) rt[u] = rt[last];
		return;
	}
	rt[u] = update(rt[u],1,N,w[u]);
}

// Walk the root's segment tree over ranks [l, r] and add, for every rank l
// reached at a leaf, l * P[l].val * sum[leaf]^2 (mod `mod`) into ans.
// sum[leaf] is the probability of that rank after all merges.
void calc(int now,int l,int r) {
	// Missing subtree: all its probabilities are 0, so it contributes nothing.
	// Returning early avoids descending to every absent leaf and avoids
	// running pushdown on (and writing to) the null sentinel node 0.
	if(!now) return;
	if(l == r) {
		(ans += l * P[l].val %mod * sum[now] %mod * sum[now] %mod)%=mod;
		return;
	}
	int mid = Mid;
	pushdown(now);
	calc(ls[now],l,mid);
	calc(rs[now],mid+1,r);
}

// Read the tree (parent list, then one value per node) and rank the leaf
// weights 1..N. Then run the DP from root 1 and print the answer.
signed main() {
	scanf("%lld",&n);
	int x;
	// Parent list: 0 marks the root; otherwise link i with its parent x.
	for(int i = 1; i <= n; i++) {
		scanf("%lld",&x);
		if(x) {
			add(i,x), add(x,i);
			son[x] = true;
		}
	}
	const int INV = inv(10000);
	// A leaf's value is its weight; an internal node's value is p_i in 1/10000 units.
	for(int i = 1; i <= n; i++) {
		scanf("%lld",&p[i]);
		if(!son[i]) {
			P[++N].id = i;
			P[N].val = p[i];
		} else {
			p[i] = p[i] * INV % mod;
		}
	}
	// Rank leaves by weight (weights presumably distinct, per the problem).
	sort(P+1,P+N+1);
	for(int rank = 1; rank <= N; rank++) w[P[rank].id] = rank;
	dfs(1,0);
	calc(rt[1],1,N);
	printf("%lld",ans);
	return 0;
}

posted @ 2020-07-30 22:44  Mogeko  阅读(149)  评论(0编辑  收藏  举报