UOJ#299. 【CTSC2017】游戏 线段树 概率期望 矩阵
原文链接www.cnblogs.com/zhouzhendong/p/UOJ299.html
前言
不会概率题的菜鸡博主做了一道概率题。
写完发现运行效率榜上的人都没有用心卡常数——矩阵怎么可以用数组呢?矩乘怎么可以用循环呢?
截止2019-05-15暂居运行效率榜一。
题解
首先,根据期望的线性性,容易得知,总期望等于以已知点为界的各个未知段的期望之和加上已知点的和。易知每段区间的期望只和自身转移系数和这段区间两端的已知点信息有关。
考虑到每次加入和删除信息时,只会影响 $O(1)$ 段区间的两端节点。
形式化地,我们设 $R_i$ 表示事件 “$R$ 在第 $i$ 局中胜出”, $B_i$ 表示事件 “$B$ 在第 $i$ 局中胜出”。
题意中提到的获胜概率可以表示为
$$P(R_i|R_{i-1}) = p_i,P(B_i|R_{i-1}) = 1-p_i\\P(R_i|B_{i-1}) = q_i,P(B_i|B_{i-1}) = 1-q_i$$
设行向量 $L_i = [P(R_i),P(B_i),E[R_i],E[B_i]]$,其中 $E[R_i],E[B_i]$ 到第 $i$ 局 $R$ 获胜和 $B$ 获胜时,$R$ 获胜局数的期望。
建立概率期望转移矩阵 $M_i$,使得 $L_i M_i = L_{i+1}$。容易得到:
$$M_i = \begin{bmatrix}p_i& 1-p_i& p_i & 0\\q_i& 1-q_i& q_i& 0\\0& 0& p_i&1-p_i\\0 &0 &q_i &1-q_i\end{bmatrix}$$
假设我们已经推得了某个区间的最后一个位置的概率行向量。接下来我们还要加上右侧已知信息对概率期望的影响。
我们直接求得 $L_{i+1}$,根据条件概率的计算公式,可以直接计算答案。
为了方便,我们可以设 $P(R_0) = 0, P(B_0) = 1$。
由于本题涉及 double 类型的精度问题,所以对矩阵求逆会导致过大的精度误差,所以只能使用线段树来得到区间矩阵积。
每次在修改操作的时候重算 $O(1)$ 个区间对答案的贡献即可。
时间复杂度 $O(m\log n)$ 。
代码
#include <bits/stdc++.h> #define clr(x) memset(x,0,sizeof x) #define For(i,a,b) for (int i=a;i<=b;i++) #define Fod(i,b,a) for (int i=b;i>=a;i--) #define fi first #define se second #define pb(x) push_back(x) #define mp(x,y) make_pair(x,y) #define outval(x) printf(#x" = %d\n",x) #define outtag(x) puts("---------------"#x"---------------") #define outarr(a,L,R) printf(#a"[%d..%d] = ",L,R);\ For(_x,L,R)printf("%d ",a[_x]);puts("") using namespace std; typedef long long LL; namespace IO{ const int S=1<<20; char I[S+1],*Is=I,*It=I,O[S+1],*Ot=O; char gc(){return Is==It?((It=(Is=I)+fread(I,1,S,stdin))==I?EOF:*Is++):*Is++;} void flush(){fwrite(O,1,Ot-O,stdout),Ot=O;} void pc(char ch){Ot==O+S?flush(),*Ot++=ch:*Ot++=ch;} struct flusher{ ~flusher(){flush();}}Flusher; #define getchar gc #define putchar pc } using IO::gc; using IO::pc; LL read(){ LL x=0,f=0; char ch=getchar(); while (!isdigit(ch)) f|=ch=='-',ch=getchar(); while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); return f?-x:x; } const int N=200005; struct Mat{ double v00,v01,v02,v03; double v10,v11,v12,v13; #define v22 v00 #define v23 v01 #define v32 v10 #define v33 v11 Mat(){} Mat(double x){ v00=v01=v02=v03=v10=v11=v12=v13=0; v00=v11=x; } Mat(double p,double q){ v00=v01=v02=v03=v10=v11=v12=v13=0; v00=p,v01=1-p; v10=q,v11=1-q; v02=p,v03=0; v12=q,v13=0; } friend Mat operator * (Mat A,Mat B){ Mat C(0); C.v00=A.v00*B.v00+A.v01*B.v10; C.v01=A.v00*B.v01+A.v01*B.v11; C.v10=A.v10*B.v00+A.v11*B.v10; C.v11=A.v10*B.v01+A.v11*B.v11; C.v02=A.v00*B.v02+A.v01*B.v12+A.v02*B.v22+A.v03*B.v32; C.v03=A.v00*B.v03+A.v01*B.v13+A.v02*B.v23+A.v03*B.v33; C.v12=A.v10*B.v02+A.v11*B.v12+A.v12*B.v22+A.v13*B.v32; C.v13=A.v10*B.v03+A.v11*B.v13+A.v12*B.v23+A.v13*B.v33; return C; } }M[N],prod[N<<2]; int n,m; char type[233]; double p[N],q[N],rec[N]; int s[N];// 0 -> Unknown, 1 -> R, 2 -> B void Build(int rt,int L,int R){ if (L==R){ prod[rt]=M[L]; return; } int mid=(L+R)>>1,ls=rt<<1,rs=ls|1; Build(ls,L,mid); Build(rs,mid+1,R); prod[rt]=prod[ls]*prod[rs]; } Mat mres; void Query(int rt,int L,int R,int xL,int xR){ if (xL>xR||R<xL||L>xR) return; if (xL<=L&&R<=xR) return (void)(mres=mres*prod[rt]); int mid=(L+R)>>1,ls=rt<<1,rs=ls|1; Query(ls,L,mid,xL,xR); Query(rs,mid+1,R,xL,xR); } set <int> S; double getE(int L,int R){ Mat Li(0); mres=Mat(1); if (s[L-1]==1) Li.v00=1; else Li.v01=1; Query(1,1,n,L,R); Li=Li*mres; if (R==n) return rec[L]=Li.v02+Li.v03; Li=Li*M[R+1]; if (s[R+1]==1) return rec[L]=Li.v02/Li.v00-1; else return rec[L]=Li.v03/Li.v01; } double readfloat(){ double x=0,w=1; char ch=getchar(); while (!isdigit(ch)) ch=getchar(); while (isdigit(ch)) x=x*10+ch-48,ch=getchar(); if (ch=='.'){ ch=getchar(); while (isdigit(ch)) w/=10,x+=w*(ch-48),ch=getchar(); } return x; } void outint(int x){ if (x>9) outint(x/10); putchar('0'+x%10); } void outfloat(double x){ outint((int)x); x-=(int)x; putchar('.'); For(i,1,5) x*=10,putchar('0'+(int)x),x-=(int)x; } void readstr(char *s){ char ch=getchar(); while (isspace(ch)) ch=getchar(); while (!isspace(ch)) *s++=ch,ch=getchar(); } int main(){ n=read(),m=read(); readstr(type); p[1]=readfloat(),q[1]=0; clr(s),s[0]=1; For(i,2,n) p[i]=readfloat(),q[i]=readfloat(); For(i,1,n) M[i]=Mat(p[i],q[i]); Build(1,1,n); S.clear(),S.insert(0),S.insert(n+1); double now=getE(1,n); while (m--){ readstr(type); int x=read(); if (type[0]=='a'){ int c=read(); c=(c^1)+1; if (c==1) now+=1; s[x]=c; set <int> :: iterator it=S.lower_bound(x); int rp=*it,lp=*--it; S.insert(x); now-=rec[lp+1]; now+=getE(lp+1,x-1); now+=getE(x+1,rp-1); } else { if (s[x]==1) now-=1; S.erase(x); set <int> :: iterator it=S.lower_bound(x); int rp=*it,lp=*--it; now-=rec[lp+1]; now-=rec[x+1]; now+=getE(lp+1,rp-1); s[x]=0; } outfloat(now),putchar('\n'); } return 0; }