[BZOJ5461][LOJ#2537[PKUWC2018]Minimax(概率DP+线段树合并)
还是没有弄清楚线段树合并的时间复杂度是怎么保证的,就当是$O(m\log n)$吧。
这题有一个显然的DP,dp[i][j]表示节点i的值为j的概率,转移时维护前缀后缀和,将4项加起来就好了。
这个感觉已经很难做到比$O(n^2)$更优的复杂度了,但我们要看到题目里有什么条件没用上:每个节点最多有2个儿子。
这个提醒我们可以用启发式合并,据说splay可以做,但我们可以考虑一下线段树合并做法。
仍然采用上面的转移方程,这里线段树上的一个节点T[x]表示x表示的区间[L,R]最终成为当前子树的根的值的概率,那么答案显然可以通过最终线段树上每个叶子节点统计。现在难点就在于如何统计。
考虑某个左子树的动态开店权值线段树上某个节点x表示的区间[L,R]最后成为根rt的值的概率,它是由p[rt]*(右子树的值小于R的概率)+(1-p[rt])*(右子树的值小于L的概率),这个其实是可以在merge的过程中递归下去累加的。
具体做法是:merge(x,y,sx,sy)表示合并线段树节点x和y(显然这里x和y分属左右子树,表示的是同一个权值区间),那么我们每次将这一层的信息累加进去,接着递归下去即可。当我们发现只存在一棵子树了(设为x),那么这棵子树的概率要乘上sy。
1 #include<cstdio> 2 #include<cstring> 3 #include<algorithm> 4 #define rep(i,l,r) for (int i=l; i<=r; i++) 5 typedef long long ll; 6 using namespace std; 7 8 const int N=300010,M=6000010,g=796898467,mod=998244353; 9 int n,nd,p,tot,w[N],rt[N],fa[N],son[N][2],b[N],s[M],tag[M],ls[M],rs[M]; 10 11 void put(int x,int k){ s[x]=1ll*s[x]*k%mod; tag[x]=1ll*tag[x]*k%mod; } 12 void push(int x){ if (tag[x]!=1) put(ls[x],tag[x]),put(rs[x],tag[x]),tag[x]=1; } 13 14 void init(int &x,int L,int R,int pos){ 15 x=++nd; s[x]=tag[x]=1; 16 if (L==R) return; 17 int mid=(L+R)>>1; 18 if (pos<=mid) init(ls[x],L,mid,pos); 19 else init(rs[x],mid+1,R,pos); 20 } 21 22 int merge(int x,int y,int sx,int sy){ 23 if (!y){ put(x,sy); return x; } 24 if (!x){ put(y,sx); return y; } 25 push(x); push(y); 26 int x0=s[ls[x]],y0=s[ls[y]],x1=s[rs[x]],y1=s[rs[y]]; 27 ls[x]=merge(ls[x],ls[y],(sx+1ll*(1-p)*x1)%mod,(sy+1ll*(1-p)*y1)%mod); 28 rs[x]=merge(rs[x],rs[y],(sx+1ll*p*x0)%mod,(sy+1ll*p*y0)%mod); 29 s[x]=(s[ls[x]]+s[rs[x]])%mod; return x; 30 } 31 32 int solve(int x){ 33 if (!son[x][0]) { init(rt[x],1,tot,lower_bound(b+1,b+tot+1,w[x])-b); return rt[x]; } 34 int l=solve(son[x][0]); if (!son[x][1]) return l; 35 int r=solve(son[x][1]); p=1ll*g*w[x]%mod; return merge(l,r,0,0); 36 } 37 38 int dfs(int x,int L,int R){ 39 if (L==R) return 1ll*L*b[L]%mod*s[x]%mod*s[x]%mod; 40 int mid=(L+R)>>1; push(x); 41 return (dfs(ls[x],L,mid)+dfs(rs[x],mid+1,R))%mod; 42 } 43 44 int main(){ 45 freopen("a.in","r",stdin); 46 freopen("a.out","w",stdout); 47 scanf("%d",&n); 48 rep(i,1,n) scanf("%d",&fa[i]),son[fa[i]][son[fa[i]][0] ? 1 : 0]=i; 49 rep(i,1,n){ 50 scanf("%d",&w[i]); if (!son[i][0]) b[++tot]=w[i]; 51 } 52 sort(b+1,b+tot+1); printf("%d\n",(dfs(solve(1),1,tot)+mod)%mod); 53 return 0; 54 }