BZOJ5461: [PKUWC2018]Minimax
BZOJ5461: [PKUWC2018]Minimax
https://lydsy.com/JudgeOnline/problem.php?id=5461
分析:
- 写出\(dp\)式子:$ f[x][i] = sum f[ls][i]\times p\times sum1[rs]j + f[ls][i]\times (1-p)\times sum2[rs]j$
- 这玩意能用线段树合并优化。
- 具体地,我们考虑线段树上维护答案,那么对于合并过程中\(x,y\)两课子树,如果出现某一棵为空的情况,对于另一棵需要乘的值是相同的,此时打标记即可。
- 然后分析\(ls[x],ls[y],rs[x],rs[y]\)互相的贡献即可。
代码:
//f[x][i] = sum f[ls][i]*p*sum1[rs][j](i>j) + f[ls][i]*(1-p)*sum2[rs][j](i<j)
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cstdlib>
using namespace std;
#define N 300050
#define mod 998244353
typedef long long ll;
int n,ch[N][2],a[N],cnt,V[N],koishi,root[N];
ll sum[N*20],tag[N*20],ans;
int ls[N*20],rs[N*20];
ll qp(ll x,ll y) {
ll re=1;
for(;y;y>>=1,x=x*x%mod) if(y&1) re=re*x%mod; return re;
}
const ll inv10000=qp(10000,mod-2);
inline void pushup(int p) {sum[p]=(sum[ls[p]]+sum[rs[p]])%mod;}
inline void giv(int p,ll d) {
tag[p]=tag[p]*d%mod; sum[p]=sum[p]*d%mod;
}
inline void pushdown(int p) {
if(tag[p]!=1) {
if(ls[p]) giv(ls[p],tag[p]);
if(rs[p]) giv(rs[p],tag[p]);
tag[p]=1;
}
}
void update(int l,int r,int x,int &p) {
p=++koishi; tag[p]=sum[p]=1;
if(l==r) return ;
int mid=(l+r)>>1;
if(x<=mid) update(l,mid,x,ls[p]);
else update(mid+1,r,x,rs[p]);
}
int merge(int x,int y,ll gy,ll gx,ll pw) {
// if(!x&&!y) return 0;
if(!x) {giv(y,gy); return y;}
if(!y) {giv(x,gx); return x;}
pushdown(x),pushdown(y);
ll rsx=sum[rs[x]],rsy=sum[rs[y]],lsx=sum[ls[x]],lsy=sum[ls[y]];
ls[x]=merge(ls[x],ls[y],(gy+(1-pw)*rsx)%mod,(gx+(1-pw)*rsy)%mod,pw);
rs[x]=merge(rs[x],rs[y],(gy+pw*lsx)%mod,(gx+pw*lsy)%mod,pw);
pushup(x);
return x;
}
void dfs(int x) {
if(!ch[x][0]&&!ch[x][1]) {
update(1,cnt,a[x],root[x]);
}else if(!ch[x][1]) {
dfs(ch[x][0]);
root[x]=root[ch[x][0]];
}else {
dfs(ch[x][0]), dfs(ch[x][1]);
root[x]=merge(root[ch[x][0]],root[ch[x][1]],0ll,0ll,a[x]*inv10000%mod);
}
}
void solve(int l,int r,int p) {
if(l==r) {
ans=(ans+ll(l)*V[l]%mod*sum[p]%mod*sum[p])%mod;
return ;
}
pushdown(p);
int mid=(l+r)>>1;
if(ls[p]) solve(l,mid,ls[p]);
if(rs[p]) solve(mid+1,r,rs[p]);
}
int main() {
scanf("%d",&n);
int i,x;
for(i=1;i<=n;i++) {
scanf("%d",&x);
if(i==1) continue;
if(!ch[x][0]) ch[x][0]=i;
else ch[x][1]=i;
}
for(i=1;i<=n;i++) {
scanf("%d",&a[i]);
if(!ch[i][0]&&!ch[i][1]) {
V[++cnt]=a[i];
}
}
sort(V+1,V+cnt+1);
for(i=1;i<=n;i++) {
if(!ch[i][0]&&!ch[i][1]) {
a[i]=lower_bound(V+1,V+cnt+1,a[i])-V;
}
}
dfs(1);
solve(1,cnt,root[1]);
printf("%lld\n",(ans+mod)%mod);
}