[PKUWC2018]Minimax

I.III.[PKUWC2018]Minimax

看错题+理解错题,成功自闭一整晚

首先,一上来我们就能想到,如果用一个数组来表示每个节点所有可能出现的值及其概率,就会比较轻松。而因为树上父节点的数组是由两个子节点的数组合在一起转移而来的,所以考虑用线段树合并来维护该数组。

显然,没有儿子时转移很轻松;仅有一个儿子时,因为儿子取值唯一,就直接继承儿子的即可;有两个儿子时,观察如果是左儿子上的值 \(u\) 成功转移,则要么是所有右儿子的取值都比它大,且在取 \(\min\);要么是所有右儿子的取值都比它小,且在取 \(\max\);而这是前后缀和的形式。于是我们在线段树合并时维护一大坨前缀和后缀和一类然后就能轻松转移了。

时间复杂度 \(O(n\log n)\)

代码:

#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
const int inv10k=796898467;
int n,rt[300100],p[300100],m,cnt,res;
vector<int>v[300100],u;
int ksm(int x,int y=mod-2){
	int z=1;
	for(;y;y>>=1,x=1ll*x*x%mod)if(y&1)z=1ll*z*x%mod;
	return z;
}
#define mid ((l+r)>>1)
struct SegTree{int lson,rson,sum,tag;}seg[10001000];
void modify(int &x,int l,int r,int P){
	if(l>P||r<P)return;
	if(!x)x=++cnt,seg[x].tag=1;(++seg[x].sum)%=mod;
//	printf("%d %d %d %d\n",x,l,r,P);
	if(l!=r)modify(seg[x].lson,l,mid,P),modify(seg[x].rson,mid+1,r,P);
}
void MUL(int x,int y){seg[x].sum=1ll*seg[x].sum*y%mod,seg[x].tag=1ll*seg[x].tag*y%mod;}
void pushdown(int x){MUL(seg[x].lson,seg[x].tag),MUL(seg[x].rson,seg[x].tag),seg[x].tag=1;}
int merge(int x,int y,int lamxl,int lamxr,int lamyl,int lamyr,int p){
//	printf("(%d %d %d %d %d %d %d)\n",x,y,lamxl,lamxr,lamyl,lamyr,p);
	if(x)pushdown(x);if(y)pushdown(y);
	if(!x){MUL(y,(1ll*p*lamxl%mod+1ll*(mod+1-p)*lamxr%mod)%mod);return y;}
	if(!y){MUL(x,(1ll*p*lamyl%mod+1ll*(mod+1-p)*lamyr%mod)%mod);return x;}
	int z=++cnt;seg[z].tag=1;
	int XL=seg[seg[x].lson].sum,XR=seg[seg[x].rson].sum,YL=seg[seg[y].lson].sum,YR=seg[seg[y].rson].sum;
	seg[z].lson=merge(seg[x].lson,seg[y].lson,lamxl,(lamxr+XR)%mod,lamyl,(lamyr+YR)%mod,p);
	seg[z].rson=merge(seg[x].rson,seg[y].rson,(lamxl+XL)%mod,lamxr,(lamyl+YL)%mod,lamyr,p);
	seg[z].sum=(seg[seg[z].lson].sum+seg[seg[z].rson].sum)%mod;
	return z;
}
void dfs(int x){
	if(v[x].empty())modify(rt[x],1,m,p[x]);
	else if(v[x].size()==1)dfs(v[x][0]),rt[x]=rt[v[x][0]];
	else dfs(v[x][0]),dfs(v[x][1]),rt[x]=merge(rt[v[x][0]],rt[v[x][1]],0,0,0,0,p[x]);
}
void iterate(int x,int l,int r){
	if(l==r)(res+=1ll*l*u[l-1]%mod*seg[x].sum%mod*seg[x].sum%mod)%=mod;
	else pushdown(x),iterate(seg[x].lson,l,mid),iterate(seg[x].rson,mid+1,r);
}
int main(){
	scanf("%d",&n);
	for(int i=1,x;i<=n;i++)scanf("%d",&x),v[x].push_back(i);
	for(int x=1;x<=n;x++){
		scanf("%d",&p[x]);
		if(v[x].empty())u.push_back(p[x]);
		else p[x]=1ll*p[x]*inv10k%mod;
	}
	sort(u.begin(),u.end()),u.resize(m=unique(u.begin(),u.end())-u.begin());
	for(int x=1;x<=n;x++)if(v[x].empty())p[x]=lower_bound(u.begin(),u.end(),p[x])-u.begin()+1;
	dfs(1),iterate(rt[1],1,m);
	printf("%d\n",res);
	return 0;
}

posted @ 2021-04-06 10:22  Troverld  阅读(57)  评论(0编辑  收藏  举报