[PKUWC2018] Minimax
前言
书接上文
题目
讲解
可以发现和上一道题十分相似。
这道题的转移虽然前缀、后缀都要用,但是因为每个点至多两个儿子,而且没有相同的值,代码会简单很多,也会少一些分类讨论,于是就是个很简单的线段树合并了。
注意转移的时候应该用未更新的前缀、后缀来更新,而非更新了左子树之后再用左子树的值来更新右子树。
时空复杂度 \(O(n\log_2W)\),因为太懒就没有离散化了。
代码
//12252024832524
#include <bits/stdc++.h>
#define TT template<typename T>
using namespace std;
typedef long long LL;
const int MAXN = 300005;
const int MOD = 998244353;
const int inv = 796898467;//10000
int n,p[MAXN];
LL Read()
{
LL x = 0,f = 1;char c = getchar();
while(c > '9' || c < '0'){if(c == '-')f = -1;c = getchar();}
while(c >= '0' && c <= '9'){x = (x*10) + (c^48);c = getchar();}
return x * f;
}
TT void Put1(T x)
{
if(x > 9) Put1(x/10);
putchar(x%10^48);
}
TT void Put(T x,char c = -1)
{
if(x < 0) putchar('-'),x = -x;
Put1(x); if(c >= 0) putchar(c);
}
TT T Max(T x,T y){return x > y ? x : y;}
TT T Min(T x,T y){return x < y ? x : y;}
TT T Abs(T x){return x < 0 ? -x : x;}
#define lc (t[x].ch[0])
#define rc (t[x].ch[1])
int rt[MAXN],tot;
struct node{
int ch[2],s,mul;
}t[MAXN*40];
void calc(int x,int val){
if(!x) return;
t[x].s = 1ll * t[x].s * val % MOD;
t[x].mul = 1ll * t[x].mul * val % MOD;
}
void down(int x){
if(t[x].mul == 1) return;
calc(lc,t[x].mul); calc(rc,t[x].mul);
t[x].mul = 1;
}
void Add(int &x,int l,int r,int pos){
x = ++tot; t[x].s = t[x].mul = 1;
if(l == r) return;
int mid = (l+r) >> 1;
if(pos <= mid) Add(lc,l,mid,pos);
else Add(rc,mid+1,r,pos);
}
int ls[MAXN],rs[MAXN];
int mge(int x,int y,int s1,int s2,int P){
if(!x && !y) return 0;
if(!x || !y){
if(!x){calc(y,s2);return y;}
else {calc(x,s1);return x;}
}
//distinct value!!! no more discussion!!!
down(x); down(y);
int lx = t[lc].s,rx = t[rc].s,ly = t[t[y].ch[0]].s,ry = t[t[y].ch[1]].s;
lc = mge(lc,t[y].ch[0],(s1 + ry*(MOD+1ll-P))%MOD,(s2 + rx*(MOD+1ll-P))%MOD,P);
rc = mge(rc,t[y].ch[1],(s1 + 1ll*ly*P)%MOD,(s2 + 1ll*lx*P)%MOD,P);
t[x].s = (t[lc].s + t[rc].s) % MOD;
return x;
}
void dfs(int x){
if(!ls[x] && !rs[x]) Add(rt[x],1,1e9,p[x]);
else{
dfs(ls[x]);
if(rs[x]) dfs(rs[x]);
if(ls[x] && rs[x]) rt[x] = mge(rt[ls[x]],rt[rs[x]],0,0,1ll*p[x]*inv%MOD);
else rt[x] = rt[ls[x]];
}
}
int cnt,ans;
void solve(int x,int l,int r){
if(!x) return;
if(l == r){
++cnt;
ans = (ans + 1ll * cnt * l % MOD * t[x].s % MOD * t[x].s) % MOD;
return;
}
int mid = (l+r) >> 1;
down(x);
if(lc) solve(lc,l,mid);
if(rc) solve(rc,mid+1,r);
}
int main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
n = Read();
for(int i = 1;i <= n;++ i){
int fa = Read();
if(fa){
if(!ls[fa]) ls[fa] = i;
else rs[fa] = i;
}
}
for(int i = 1;i <= n;++ i) p[i] = Read();
dfs(1);
solve(rt[1],1,1e9);
Put(ans,'\n');
return 0;
}