题解 P3772 【CTSC[2017] 游戏】(概率期望,线段树)

这题还有一个更加朴实无华的维护方法。


对于条件期望,我们有:

\[E(A|B)=\dfrac{E(AB)}{P(B)} \]

(可以这样理解:\(E(A|B)P(B)=E(AB)\),即事件 \(B\) 发生的概率乘在 \(B\) 发生前提下事件 \(A\) 的期望就是 \(A\)\(B\) 同时发生的期望。)


对于这道题,我们发现根据已确定输赢的位置将 \(n\) 局游戏划成若干区间,设,区间 \([l,r]\) 内的输赢期望只和 \(l\)\(r\) 的输赢情况有关,其它的修改影响不到区间 \([l,r]\)

那维护区间信息自然想到线段树,设 \(p(l,r,a,b)\) 为位置 \(l\) 的输赢情况为 \(a\) 且位置 \(r\) 的输赢情况为 \(b\) 的概率,\(e(l,r,a,b)\) 为在位置 \(l\) 的输赢情况为 \(a\) 且位置 \(r\) 的输赢情况为 \(b\) 时区间 \([l,r]\) 赢的局数的期望,\(tp[x][a][b]\) 表示在前一个位置的输赢情况为 \(a\) 时位置 \(x\) 输赢情况为 \(b\) 的概率(\(tp\) 可以在一开始根据 \(p_i\)\(q_i\) 预处理)。

下面我们尝试从 \([l,mid]\)\([mid+1,r]\) 转移到 \([l,r]\)

对于 \(p\) ,我们枚举 \(mid\)\(mid+1\) 的输赢情况:

\[p(l,r,a,b)=\sum_{0\leq c,d\leq 1}p(l,mid,a,c)\times tp(mid+1,c,d)\times p(mid+1,r,d,b) \]

对于 \(e\),根据上面条件期望的公式,记“赢”为事件 \(A\),“位置 \(l\) 的输赢情况为 \(a\) 且位置 \(r\) 的输赢情况为 \(b\)”为事件 \(B\),则有:

\[e(l,r,a,b)=\dfrac{\sum_{0\leq c,d\leq 1}p(l,mid,a,c)\times tp(mid+1,c,d)\times p(mid+1,r,d,b)\times(e(l,mid,a,c)+e(mid+1,r,d,b))}{p(l,r,a,b)} \]

于是就可以在常数时间内转移。

然后每次用 std::set 维护已确定的位置,插入和删除时加加减减即可。(有个细节,当钦定位置 \(i\) 赢的时候,我们令 \(ans\gets ans+e(pre,i,c_{pre},c_i)+e(i,nex,c_i,c_{nex})-e(pre,nex,c_pre,c_nex)\) 时将位置 \(i\) 的期望计算了两次,应减去 \(1\);删除的时候同理)

线段树部分代码(完整代码去提交记录找):

double p[200005],q[200005],tp[200005][2][2];
struct node{
	double p[2][2],e[2][2];
	int l,r;
}t[200005<<2];
inline void merge(node& cx,node cl,node cr){
	cx.l=cl.l,cx.r=cr.r;
	cx.p[0][0]=cx.p[0][1]=cx.p[1][0]=cx.p[1][1]=cx.e[0][0]=cx.e[0][1]=cx.e[1][0]=cx.e[1][1]=0;
	for(rg int a=0;a<2;++a) for(rg int b=0;b<2;++b){
		for(rg int c=0;c<2;++c) for(rg int d=0;d<2;++d) cx.p[a][b]+=cl.p[a][c]*cr.p[d][b]*tp[cr.l][c][d];
		for(rg int c=0;c<2;++c) for(rg int d=0;d<2;++d) cx.e[a][b]+=cl.p[a][c]*cr.p[d][b]*tp[cr.l][c][d]*(cl.e[a][c]+cr.e[d][b]);
        if(fabs(cx.p[a][b])<1e-6) cx.e[a][b]=0;
		else cx.e[a][b]/=cx.p[a][b];
	}
}
void build(int x,int l,int r){
	t[x].l=l,t[x].r=r;
	if(l==r) return t[x].p[0][0]=1,t[x].p[1][1]=1,t[x].e[0][0]=1.0,void();
	const int mid=l+r>>1;
	build(x<<1,l,mid),build(x<<1|1,mid+1,r);
	merge(t[x],t[x<<1],t[x<<1|1]);
}
node query(int x,const int sl,const int sr){
	if(sl<=t[x].l&&sr>=t[x].r) return t[x];
	const int mid=t[x].l+t[x].r>>1;
	if(sl>mid) return query(x<<1|1,sl,sr);
	if(sr<=mid) return query(x<<1,sl,sr);
	node ret; merge(ret,query(x<<1,sl,sr),query(x<<1|1,sl,sr)); return ret;
}
posted @ 2022-02-09 16:45  Legitimity  阅读(94)  评论(0编辑  收藏  举报