P5298 [PKUWC2018] Minimax - 线段树合并
题意
给定一棵 \(n\) 个点的有根树,根为 \(1\),且每个点的儿子个数不超过 \(2\)。每个点都有一个权值,对于点 \(u\),它有 \(p_u\) 的概率使得权值为它的儿子的权值 \(\max\);有 \(1-p_u\) 的概率使得权值为它的儿子的权值 \(\min\)。若点 \(u\) 没有儿子,那么会给定一个权值 \(w_u\)。
设 \(v_i\) 为根节点能取到的第 \(i\) 小的权值,\(p_i\) 为取到它的概率。你要求出 \(\sum_{i=1} i\cdot v_i\cdot p_i^2\bmod 998244353\)。
\(1\le n\le 3\times 10^5,0<p_u<1,1\le w_u\le 10^9\)。
题解
因为 \(0<p_u<1\),所以根节点能取到所有出现过的权值。将所有叶子节点的权值离散化,设 \(f_{u,i}\) 为点 \(u\) 取到离散化后第 \(i\) 小的权值的概率。
设 \(u\) 的左儿子为 \(x\),有转移:
右儿子同理。
考虑线段树合并的过程是怎么保证复杂度的:在 \(\operatorname{merge}(p,q)\) 时,如果 \(p,q\) 都不为空它才会继续往下递归。
所以,对于 \(\operatorname{merge}(p,q)\):
- 假如 \(p\) 为空:此时相当于 \(q\) 这个节点对应的区间内每一个位置都乘上一个数,打区间乘的标记即可。
- 假如 \(q\) 为空:同理。
- 都不为空:递归下去,然后上传新的值。
但转移方程内层的 \(\sum\) 怎么处理?
定义 \([l_x,r_x]\) 为线段树上点 \(x\) 对应的区间。因为 \(\sum\) 是一个前(后)缀的形式,所以在合并过程中记录 \([1,l_p],[1,l_q]\) 这两个区间的值之和即可。
待补:有懒标记的线段树合并方法。
时间复杂度 \(\mathcal{O}(n\log n)\),常数巨大。
代码
看起来我的 \(\operatorname{merge}\) 比别人写得优美一些。
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
#define For(Ti,Ta,Tb) for(int Ti=(Ta);Ti<=(Tb);++Ti)
#define Dec(Ti,Ta,Tb) for(int Ti=(Ta);Ti>=(Tb);--Ti)
#define Debug(...) fprintf(stderr,__VA_ARGS__)
typedef long long ll;
const int N=3e5+5,LogN=20,Mod=998244353;
long long Pow(long long _base,long long _pow,const long long& _mod){
long long _res=1;
while(_pow){
if(_pow&1) _res=_res*_base%_mod;
_pow>>=1,_base=_base*_base%_mod;
}
return _res%_mod;
}
struct SegmentTree{
struct Node{
int l,r,ls,rs;
ll s,Mul=1;
void PushMul(ll k){s=s*k%Mod,Mul=Mul*k%Mod;}
}t[N*LogN];
int tot=0;
int New(int l,int r){
t[++tot].l=l,t[tot].r=r;
return tot;
}
void Pushup(int p){t[p].s=(t[t[p].ls].s+t[t[p].rs].s)%Mod;}
void Pushdown(int p){
if(!p) return;
if(t[p].Mul!=1) t[t[p].ls].PushMul(t[p].Mul),t[t[p].rs].PushMul(t[p].Mul);
t[p].Mul=1;
}
void Modify(int p,int pos,ll k){
if(t[p].l==t[p].r){
t[p].s=k;return;
}
Pushdown(p);
int mid=(t[p].l+t[p].r)/2;
if(pos<=mid) Modify(t[p].ls?t[p].ls:(t[p].ls=New(t[p].l,mid)),pos,k);
else Modify(t[p].rs?t[p].rs:(t[p].rs=New(mid+1,t[p].r)),pos,k);
Pushup(p);
}
int Merge(int p,int q,ll ps,ll qs,ll prob){
if(!p&&!q) return 0;
if(!p||!q){
if(!p) swap(p,q),swap(ps,qs);
t[p].PushMul((qs*prob+(Mod+1-qs)*(Mod+1-prob))%Mod);
return p;
}
Pushdown(p),Pushdown(q);
ll pls=t[t[p].ls].s,qls=t[t[q].ls].s;
t[p].ls=Merge(t[p].ls,t[q].ls,ps,qs,prob);
t[p].rs=Merge(t[p].rs,t[q].rs,(ps+pls)%Mod,(qs+qls)%Mod,prob);
Pushup(p);
return p;
}
void Get(int p,vector<int> &res){
if(!p) return;
if(t[p].l==t[p].r){
res.push_back(t[p].s);return;
}
Pushdown(p);
Get(t[p].ls,res),Get(t[p].rs,res);
}
}seg;
int n,root[N],ch[N][2],discv[N];ll p[N];
vector<ll> disc;
void Dfs(int u){
if(!ch[u][0]){
root[u]=seg.New(1,disc.size()),seg.Modify(root[u],discv[u],1);
}else if(!ch[u][1]){
Dfs(ch[u][0]);
root[u]=root[ch[u][0]];
}else{
Dfs(ch[u][0]),Dfs(ch[u][1]);
root[u]=seg.Merge(root[ch[u][0]],root[ch[u][1]],0,0,p[u]);
}
}
int main(){
#ifndef zyz
ios::sync_with_stdio(false),cin.tie(nullptr);
#endif
cin>>n;
For(i,1,n){
int fa;cin>>fa;
if(fa) ch[fa][ch[fa][0]!=0]=i;
}
ll inv1w=Pow(10000,Mod-2,Mod);
For(i,1,n){
ll x;cin>>x;
if(!ch[i][0]) p[i]=x,disc.push_back(x);
else p[i]=1LL*x*inv1w%Mod;
}
sort(disc.begin(),disc.end());
For(i,1,n) if(!ch[i][0]) discv[i]=lower_bound(disc.begin(),disc.end(),p[i])-disc.begin()+1;
Dfs(1);
vector<int> ans;
ans.push_back(0);
seg.Get(root[1],ans);
ll Ans=0;
for(int i=1;i<int(ans.size());++i){
Ans=(Ans+1LL*i*disc[i-1]%Mod*ans[i]%Mod*ans[i]%Mod)%Mod;
}
cout<<Ans;
return 0;
}