P5298 [PKUWC2018]Minimax

P5298 [PKUWC2018]Minimax

首先考虑最简单的 \(\text{dp}\) 式子。

\(dp_{x,j}\) 表示当前在点 \(x\),且点 \(x\) 的权值为 \(j\) 的概率。由于 \(n \le 3\times 10^5\),考虑将题目给出的权值离散化。

由于一个点只有两个儿子,考虑点的转移方程。设 \(f_j\) 表示 \(x\) 左儿子的 dp 值,即 \(dp_{ls_x,j}\)\(g_j\) 表示 \(x\) 右儿子的 \(\text{dp}\) 值。(由于保证每个节点的权值互不相同,所以不需要考虑两个儿子 \(\text{dp}\) 值相同的情况)

  • 若当前点的权值从左儿子转移

    • 若当前权值是左儿子的最大值,则概率为 \(f_j \times p \times \sum_{i=1}^{j-1} g_i\)
    • 若当前权值是左儿子的最小值,则概率为 \(f_j\times (1-p)\times \sum_{i=j+1}^{Max} g_i\)
  • 若当前点的权值从右儿子转移

    • 若当前权值是右儿子的最大值,则概率为 \(g_j \times p \times \sum_{i=1}^{j-1} f_i\)

    • 若当前权值是右儿子的最小值,则概率为 \(g_j\times (1-p)\times \sum_{i=j+1}^{Max} f_i\)

综上,将上式全部相加即可得到当前点的转移方程

\[dp_{x,j}=f_j \times (p \times \sum_{i=1}^{j-1} g_i+(1-p)\times \sum_{i=j+1}^{Max} g_i)+g_j \times (p \times \sum_{i=1}^{j-1} f_i+(1-p)\times \sum_{i=j+1}^{Max} f_i) \]

直接转移可以得到 \(\mathcal{O}(n^2)\) 的做法。


考虑优化,注意到式子中有多个前缀后缀和的形式,考虑使用权值线段树维护 \(dp\) 的第二维 \(j\),向上转移时将线段树合并。

当将以 \(x,y\) 为根的两颗线段树合并时,设当前区间为 \([l,r]\),在 \(\text{merge}\) 过程中记录

\[lx=\sum_{i=1}^{l-1} dp_{x,i}\\ rx=\sum_{i=r+1}^{Max} dp_{x,i}\\ ly=\sum_{i=1}^{l-1} dp_{y,i}\\ ry=\sum_{i=r+1}^{Max} dp_{y,i}\\ \]

分以下情况讨论:

  • \(x,y\) 均为空时,直接返回即可。
  • \(x\) 为空,\(y\) 不为空时,则 \(\forall i(l\le i\le n,i \in N)\),均有 \(dp_{x,i}=0\)。此时,上面的转移方程可以化为 \(dp_{x,j}=g_j \times (p \times \sum_{i=1}^{j-1} f_i+(1-p)\times \sum_{i=j+1}^{Max} f_i)\),则对于区间 \([l,r]\) 中的任何一个 \(j\),均满足 \(f_{l},f_{l+1},\dots,f_j,\dots,f_{r-1},f_r=0\),则上式又可以化为 \(dp_{x,j}=g_j \times (p \times \sum_{i=1}^{l-1} f_i+(1-p)\times \sum_{i=r+1}^{Max} f_i)\)。使用 \(lx,rx\) 表示则可以得到 \(dp_{x,j}=g_j \times (p \times lx+(1-p)\times rx)\),即右侧与 \(j\) 无关。则对于所有的 \(j\) 满足 \(j \in [l,r]\),都相当于在原来 \(y\) 树的基础上乘了 \((p \times lx+(1-p)\times rx)\)。直接转移打标记即可。
  • \(y\) 为空,\(x\) 不为空时,与上面一种情况类似。
  • \(x,y\) 均不为空时,维护接下来的 \(lx^{\prime},rx^{\prime},ly^{\prime},ry^{\prime}\) 并向两边同时递归即可。

综上,可以得到时间复杂度为 \(\mathcal{O}(n \log n)\),空间复杂度为 \(\mathcal{O}(n \log n)\) 的做法。

code
#include<bits/stdc++.h>
using namespace std;
namespace IO{
	template<typename T>inline bool read(T &x){
		x=0;
		char ch=getchar();
		bool flag=0,ret=0;
		while(ch<'0'||ch>'9') flag=flag||(ch=='-'),ch=getchar();
		while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar(),ret=1;
		x=flag?-x:x;
        return ret;
	}
	template<typename T,typename ...Args>inline bool read(T& a,Args& ...args){
	    return read(a)&&read(args...);
	}
	template<typename T>void prt(T x){
		if(x>9) prt(x/10);
		putchar(x%10+'0');
	}
	template<typename T>inline void put(T x){
		if(x<0) putchar('-'),x=-x;
		prt(x);
	}
	template<typename T>inline void put(char ch,T x){
		if(x<0) putchar('-'),x=-x;
		prt(x);
		putchar(ch);
	}
	template<typename T,typename ...Args>inline void put(T a,Args ...args){
	    put(a);
		put(args...);
	}
	template<typename T,typename ...Args>inline void put(const char ch,T a,Args ...args){
	    put(ch,a);
		put(ch,args...);
	}
	inline void put(string s){
		for(int i=0,sz=s.length();i<sz;i++) putchar(s[i]);
	}
	inline void put(const char* s){
		for(int i=0,sz=strlen(s);i<sz;i++) putchar(s[i]);
	}
}
using namespace IO;
#define N 300005
#define mod 998244353
#define ll long long
inline int power(int x,int y){
	int res=1;
	while(y){
		if(y&1) res=(ll)res*x%mod;
		x=(ll)x*x%mod;
		y>>=1;
	}
	return res;
}
int n,son[N][2],num[N],value[N],p[N],b[N],Idx,rt[N],res[N],ans,idx;
struct node{
	int ls,rs,sum,tag;
}t[N*25];
#define lc(x) t[x].ls
#define rc(x) t[x].rs
inline void push_down(int x){
	if(t[x].tag==1) return;
	t[lc(x)].sum=(ll)t[lc(x)].sum*t[x].tag%mod;
	t[rc(x)].sum=(ll)t[rc(x)].sum*t[x].tag%mod;
	t[lc(x)].tag=(ll)t[lc(x)].tag*t[x].tag%mod;
	t[rc(x)].tag=(ll)t[rc(x)].tag*t[x].tag%mod;
	t[x].tag=1;
}
inline void push_up(int x){
	t[x].sum=(t[lc(x)].sum+t[rc(x)].sum)%mod;
}
inline void update(int &x,int l,int r,int pos,int pro){
	if(!x) t[x=++idx].tag=1;
	if(l==r) return t[x].sum=pro,void();
	int mid=l+r>>1;
	if(pos<=mid) update(lc(x),l,mid,pos,pro);
	else update(rc(x),mid+1,r,pos,pro);
	push_up(x);
}
inline int merge(int x,int y,int lx,int rx,int ly,int ry,int pro){
	if(!x&&!y) return 0;
	push_down(x),push_down(y);
	int xmul=((ll)pro*ly%mod+(ll)(1-pro+mod)*ry%mod)%mod;
	int ymul=((ll)pro*lx%mod+(ll)(1-pro+mod)*rx%mod)%mod;
	if(!x){
		t[y].sum=(ll)t[y].sum*ymul%mod;
		t[y].tag=(ll)t[y].tag*ymul%mod;
		return y;
	}
	if(!y){
		t[x].sum=(ll)t[x].sum*xmul%mod;
		t[x].tag=(ll)t[x].tag*xmul%mod;
		return x;
	}
	int ax=t[lc(x)].sum,bx=t[rc(x)].sum,ay=t[lc(y)].sum,by=t[rc(y)].sum;
	lc(x)=merge(lc(x),lc(y),lx,(rx+bx)%mod,ly,(ry+by)%mod,pro);
	rc(x)=merge(rc(x),rc(y),(lx+ax)%mod,rx,(ly+ay)%mod,ry,pro);
	push_up(x);
	return x;
}
void dfs(int x){
	if(!num[x]) update(rt[x],1,Idx,value[x],1);
	else if(num[x]==1) dfs(son[x][0]),rt[x]=rt[son[x][0]];
	else{
		dfs(son[x][0]),dfs(son[x][1]);
		rt[x]=merge(rt[son[x][0]],rt[son[x][1]],0,0,0,0,p[x]);
	}
}
inline void getans(int x,int l,int r){
	if(!x) return;
	if(l==r) return res[l]=t[x].sum,void();
	int mid=l+r>>1;push_down(x);
	getans(lc(x),l,mid),getans(rc(x),mid+1,r);
}
int main(){
	read(n);
	for(int i=1,x;i<=n;i++){
		read(x);
		if(i==1) continue;
		son[x][son[x][0]!=0]=i;
		num[x]++;
	}
	for(int i=1,x;i<=n;i++){
		read(x);
		if(!num[i]) value[i]=x,b[++Idx]=x;
		else p[i]=(ll)x*power(10000,mod-2)%mod; 
	}
	sort(b+1,b+Idx+1);
	Idx=unique(b+1,b+Idx+1)-b-1;
	for(int i=1;i<=n;i++)
		if(value[i]) value[i]=lower_bound(b+1,b+Idx+1,value[i])-b;
	dfs(1);
	getans(rt[1],1,Idx);
	for(int i=1;i<=Idx;i++)
		ans=(ans+(ll)i*b[i]%mod*res[i]%mod*res[i]%mod)%mod;
	put('\n',ans);
	return 0;
}

posted @ 2022-10-06 15:15  fzj2007  阅读(19)  评论(0编辑  收藏  举报