BZOJ5461: [PKUWC2018]Minimax

BZOJ5461: [PKUWC2018]Minimax

https://lydsy.com/JudgeOnline/problem.php?id=5461

分析:

  • 写出\(dp\)式子:$ f[x][i] = sum f[ls][i]\times p\times sum1[rs]j + f[ls][i]\times (1-p)\times sum2[rs]j$
  • 这玩意能用线段树合并优化。
  • 具体地,我们考虑线段树上维护答案,那么对于合并过程中\(x,y\)两课子树,如果出现某一棵为空的情况,对于另一棵需要乘的值是相同的,此时打标记即可。
  • 然后分析\(ls[x],ls[y],rs[x],rs[y]\)互相的贡献即可。

代码:

//f[x][i] = sum f[ls][i]*p*sum1[rs][j](i>j) + f[ls][i]*(1-p)*sum2[rs][j](i<j)
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cstdlib>
using namespace std;
#define N 300050
#define mod 998244353
typedef long long ll;
int n,ch[N][2],a[N],cnt,V[N],koishi,root[N];
ll sum[N*20],tag[N*20],ans;
int ls[N*20],rs[N*20];
ll qp(ll x,ll y) {
    ll re=1;
    for(;y;y>>=1,x=x*x%mod) if(y&1) re=re*x%mod; return re;
}
const ll inv10000=qp(10000,mod-2);
inline void pushup(int p) {sum[p]=(sum[ls[p]]+sum[rs[p]])%mod;}
inline void giv(int p,ll d) {
    tag[p]=tag[p]*d%mod; sum[p]=sum[p]*d%mod;
}
inline void pushdown(int p) {
    if(tag[p]!=1) {
        if(ls[p]) giv(ls[p],tag[p]);
        if(rs[p]) giv(rs[p],tag[p]);
        tag[p]=1;
    }
}
void update(int l,int r,int x,int &p) {
    p=++koishi; tag[p]=sum[p]=1;
    if(l==r) return ;
    int mid=(l+r)>>1;
    if(x<=mid) update(l,mid,x,ls[p]);
    else update(mid+1,r,x,rs[p]);
}
int merge(int x,int y,ll gy,ll gx,ll pw) {
    // if(!x&&!y) return 0;
    if(!x) {giv(y,gy); return y;}
    if(!y) {giv(x,gx); return x;}
    pushdown(x),pushdown(y);
    ll rsx=sum[rs[x]],rsy=sum[rs[y]],lsx=sum[ls[x]],lsy=sum[ls[y]];
    ls[x]=merge(ls[x],ls[y],(gy+(1-pw)*rsx)%mod,(gx+(1-pw)*rsy)%mod,pw);
    rs[x]=merge(rs[x],rs[y],(gy+pw*lsx)%mod,(gx+pw*lsy)%mod,pw);
    pushup(x);
    return x;
}
void dfs(int x) {
    if(!ch[x][0]&&!ch[x][1]) {
        update(1,cnt,a[x],root[x]);
    }else if(!ch[x][1]) {
        dfs(ch[x][0]);
        root[x]=root[ch[x][0]];
    }else {
        dfs(ch[x][0]), dfs(ch[x][1]);
        root[x]=merge(root[ch[x][0]],root[ch[x][1]],0ll,0ll,a[x]*inv10000%mod);
    }
}
void solve(int l,int r,int p) {
    if(l==r) {
        ans=(ans+ll(l)*V[l]%mod*sum[p]%mod*sum[p])%mod;
        return ;
    }
    pushdown(p);
    int mid=(l+r)>>1;
    if(ls[p]) solve(l,mid,ls[p]);
    if(rs[p]) solve(mid+1,r,rs[p]);
}
int main() {
    scanf("%d",&n);
    int i,x;
    for(i=1;i<=n;i++) {
        scanf("%d",&x);
        if(i==1) continue;
        if(!ch[x][0]) ch[x][0]=i;
        else ch[x][1]=i;
    }
    for(i=1;i<=n;i++) {
        scanf("%d",&a[i]);
        if(!ch[i][0]&&!ch[i][1]) {
            V[++cnt]=a[i];
        }
    }
    sort(V+1,V+cnt+1);
    for(i=1;i<=n;i++) {
        if(!ch[i][0]&&!ch[i][1]) {
            a[i]=lower_bound(V+1,V+cnt+1,a[i])-V;
        }
    }
    dfs(1);
    solve(1,cnt,root[1]);
    printf("%lld\n",(ans+mod)%mod);
}
posted @ 2018-12-09 20:54  fcwww  阅读(323)  评论(0编辑  收藏  举报