[CTSC2017]游戏(Bayes定理,线段树)

传送门:http://uoj.ac/problem/299

题目良心给了Bayes定理,但对于我这种数学渣来说并没有什么用。

先大概讲下相关数学内容:

1.定义:$P(X)$ 表示事件$X$发生的概率,$E(X)$表示随机变量$X$的期望值,$P(A|B)$表示已知$B$发生,$A$发生的概率,$P(AB)$表示$A$和$B$同时发生的概率。

2.条件概率公式:

$\begin{aligned}P(A|B)=\frac{P(AB)}{P(B)}\end{aligned}$。

由$P(B)P(A|B)=P(AB)$移项可得。

3.全概率公式:

当存在$k$个互斥事件且事件并集为全集时:

$P(A)=\sum\limits_{i=1}^{k}P(A|X=X_i)P(X=Xi)$

由条件概率公式可得。

4.Bayes定理:

$\begin{aligned}P(A|B)=\frac{P(AB)}{P(B)}=\frac{P(B|A)P(A)}{P(B)}\end{aligned}$

由条件概率公式可得。

 

接下来是OI部分。

对于每个位置,找到夹着它的两个已知胜负条件,设为A和B。

将A作为必要条件(题设),可知$E(C)=\sum_{i=1}^{k}P(c_i=1)$。根据Bayes公式可得$\begin{aligned}P(c_i=1|B)=\frac{P(c_i=1)P(B|c_i=1)}{P(B)}=\frac{P(c_i=1,B)}{P(B)}\end{aligned}$

修改的时候使用线段树进行区间合并,用矩阵加速合并。

设节点$x$代表的区间是$[L,R]$,则val[x][0/1][0/1]表示$L-1$为胜/负时,R为胜/负的$P(c_L=1)*P(B|c_L=1)$和$P(B)$。

 

接下来是暴力部分(10分)。

指数级枚举所有状态,计算这个状态出现的概率以及是否合法,累加即可。

 

方法一:暴力

 1 #include<cstdio>
 2 #include<cstring>
 3 #include<algorithm>
 4 using namespace std;
 5 
 6 const int N=200100;
 7 char S[15];
 8 int n,m;
 9 double p[N],q[N],w[15];
10 
11 void calc(){
12     double ans=0,tp=0;
13     int up=(1<<n);
14     for(int i=0;i<up;i++){
15         bool f=1; int c=0,las=1; double pc=1;
16         for(int j=0,k=1;j<n;j++,k++)
17             if(i&(1<<j)){
18                 if(w[k]==0){ f=false; break; }
19                 c++;
20                 if (las) pc*=p[k]; else pc*=q[k];
21                 las=1;
22             }else{
23                 if(w[k]==1){ f=false; break; }
24                 if(las)pc*=(1-p[k]); else pc*=(1-q[k]);
25                 las=0;
26             }
27         if(f)ans+=pc*c,tp+=pc;
28     }
29     ans/=tp; printf("%.6f\n",ans);
30 }
31 
32 void solve(){
33     for(int i=1;i<=n;i++)w[i]=-1;
34     for(int i=1,x,y;i<=m;i++){
35         scanf("%s",S);
36         if(S[0]=='a') scanf("%d%d",&x,&y),w[x]=y; else scanf("%d",&x),w[x]=-1;
37         calc();
38     }
39 }
40 
41 int main(){
42     freopen("game.in","r",stdin);
43     freopen("game.out","w",stdout);
44     scanf("%d%d",&n,&m); scanf("%s",S);
45     scanf("%lf",&p[1]);
46     for(int i=2;i<=n;i++)scanf("%lf %lf",&p[i],&q[i]);
47     solve();
48     return 0;
49 }

 

 

方法二:正解

 1 #include<set>
 2 #include<cstdio>
 3 #include<cstring>
 4 #include<algorithm>
 5 #define ls (x<<1)
 6 #define rs (ls|1)
 7 #define rep(i,l,r) for (int i=(l); i<=(r); i++)
 8 using namespace std;
 9 
10 const int N=200100;
11 int n,m,p,c,type,S2[N];
12 double ans,P[N],Q[N];
13 set<int>S1;
14 char op[10];
15 
16 struct M{ double v[2][2]; M(){ memset(v,0,sizeof(v)); }; };
17 M operator +(M a,const M &b){ rep(i,0,1) rep(j,0,1) a.v[i][j]+=b.v[i][j]; return a; }
18 M operator *(const M &a,const M &b){ M c; rep(i,0,1) rep(j,0,1) rep(k,0,1) c.v[i][j]+=a.v[i][k]*b.v[k][j]; return c; }
19 
20 struct inf{ M f,g; inf(const M &_f=M(),const M &_g=M()):f(_f),g(_g){}; };
21 inf operator +(const inf &a,const inf &b){ return inf(a.f*b.f,a.g*b.f+a.f*b.g); }
22 struct node{ int l,r,mid; inf val; }seg[N<<2];
23 
24 void build(int x,int l,int r){
25     if (l==r){
26         seg[x].val.f.v[1][1]=P[l]; seg[x].val.f.v[1][0]=1.-P[l];
27         seg[x].val.f.v[0][1]=Q[l]; seg[x].val.f.v[0][0]=1.-Q[l];
28         seg[x].val.g.v[1][1]=P[l]; seg[x].val.g.v[0][1]=Q[l];
29         return;
30     }
31     int mid=(l+r)>>1;
32     build(ls,l,mid); build(rs,mid+1,r); seg[x].val=seg[ls].val+seg[rs].val;
33 }
34 
35 inf que(int x,int L,int R,int l,int r){
36     if (L==l && r==R) return seg[x].val;
37     int mid=(L+R)>>1;
38     if (r<=mid) return que(ls,L,mid,l,r);
39     else if (l>mid) return que(rs,mid+1,R,l,r);
40         else return que(ls,L,mid,l,mid)+que(rs,mid+1,R,mid+1,r);
41 }
42 
43 double ask(int l,int r){ inf v=que(1,0,n+1,l+1,r); return v.g.v[S2[l]][S2[r]]/v.f.v[S2[l]][S2[r]]; }
44 
45 int main(){
46     freopen("game.in","r",stdin);
47     freopen("game.out","w",stdout);
48     scanf("%d%d",&n,&m); scanf("%s",op); scanf("%lf",&P[1]);
49     rep(i,2,n) scanf("%lf%lf",&P[i],&Q[i]);
50     S1.insert(0); S2[0]=1; S1.insert(n+1); S2[n+1]=0; P[n+1]=Q[n+1]=0.;
51     build(1,0,n+1); ans=ask(0,n+1);
52     while (m--){
53         scanf("%s",op);
54         if (*op=='a'){
55             scanf("%d%d",&p,&c); set<int>::iterator nxt=S1.lower_bound(p),lst=nxt; lst--;
56             S2[p]=c; ans-=ask(*lst,*nxt); ans+=ask(*lst,p)+ask(p,*nxt); S1.insert(p);
57         }else{
58             scanf("%d",&p); set<int>::iterator mid=S1.find(p),lst,nxt;
59             lst=nxt=mid; lst--; nxt++; ans-=ask(*lst,p)+ask(p,*nxt); ans+=ask(*lst,*nxt); S1.erase(p);
60         }
61         printf("%.6lf\n",ans);
62     }
63     return 0;
64 }

 

posted @ 2018-04-18 19:22  HocRiser  阅读(280)  评论(0编辑  收藏  举报