bzoj4820[SDOI2017]硬币游戏
题意
给出n个长度均为m的不同01串,随机生成一个无限长的01串,对n个01串中的每个,求出它最先在随机串中出现的概率.
分析
写这个题的题解比写这个题还难...我可能学了假的概率DP...假装我理解清楚了把坑填了算了....
一眼AC自动机,然后一直在想把AC自动机上的做法优化到线性,GG.
标算的思路是,考虑AC自动机上所有到达终止状态的节点,一共有n个,我们需要求出在这些节点终止的概率,也相当于这些节点期望经过的次数.
设第i个串最先出现的概率为P[i],这也是第i个串的终止节点的期望经过次数.一共有n个变量.
关键的想法是,设另一个变量H表示所有非终止状态节点的期望经过次数.(也就是:随机串中出现了n个串中的一个时,随机串的期望长度).
然后我们可以列出n+1个方程,高斯消元解出n+1个变量.
首先可以列一个方程:P[1]+P[2]+...+P[n]=1.0
然后我们对P[1]到P[n]每个变量列一个方程.
在一个AC自动机上非终止状态的节点,我们有0.5m的概率,使得后面连续的m个字符恰好组成第i个串.第i个串在n个串中首先出现的情况,一定都被包含在这个情况里(必然是从一个非终止状态节点后面接上这个串).看似每个节点都有0.5m的概率走到终止状态,但是某些非终止状态的节点后接上这个串可能会在接上不到m个字符的时候就到达其他串或者这个串本身的终止节点.而且每个非终止状态节点经过的概率不是一样的,我们不能直接求出期望经过多少个节点满足后面接上这个串m个字符后恰好到达这个串的终点.所以我们考虑从n*0.5^m里去除不是加上m个字符后恰好终止的概率.
如果在某个节点加上字符串i的前k个字符后就已经到达了字符串j的终止节点(j可以等于i),那么j的后k个字符必然等于i的前k个字符.在匹配上j后,我们还要继续生成字符使得接下来m-k的字符等于串i的后m-k个字符(我们认为生成的串无限长,不妨认为我们把AC自动机上每个终止节点的出边都指向一个虚拟汇点,虚拟汇点出发的每条边指向自身,那么P[1]...P[n]分别表示经过第i个串终止节点到达汇点的概率(也是第i个终止节点经过次数的期望),这样就可以认为在到达终止节点后继续生成字符了).
任意一种匹配上j的情况必然会匹配上j的后k个字符,因此所有经过第j个终止节点的情况都会在这里出现,但必须保证到达终止节点之后生成的连续m-k个字符和i的后m-k个字符相同,这个概率是0.5(m-k),因此我们对于每个可行的k,P[i]-=P[j]*0.5(m-k).我们枚举i,j,算出两两之间的贡献,可以列出n个方程,所有终止点经过的概率之和为1也是一个方程,然后高斯消元即可.
感觉还是没说明白...有没有更清真的理解方式啊QAQ
#include<cstdio>
#include<cmath>
#include<algorithm>
using namespace std;
char str[305][305];
char S[605];
int fail[605];
void mp(int n){
fail[1]=0;
int j=0;
for(int i=2;i<=n;++i){
while(j&&S[j+1]!=S[i])j=fail[j];
if(S[j+1]==S[i])++j;
fail[i]=j;
}
}
double f[305][305];
int sz;
void Swap(int a,int b){
for(int i=1;i<=sz+1;++i)swap(f[a][i],f[b][i]);
}
void multplus(int a,int b,double t){
for(int i=1;i<=sz+1;++i)f[a][i]+=f[b][i]*t;
}
void gauss(){
for(int i=1;i<=sz;++i){
for(int j=i+1;j<=sz;++j){
if(fabs(f[j][i])>fabs(f[i][i]))Swap(i,j);
}
for(int j=i+1;j<=sz;++j)multplus(j,i,-f[j][i]/f[i][i]);
for(int j=i-1;j>=1;--j)multplus(j,i,-f[j][i]/f[i][i]);
}
for(int j=1;j<=sz;++j)f[j][sz+1]/=f[j][j];
}
int main(){
int n,m;scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i){
scanf("%s",str[i]+1);
}
for(int i=1;i<=n;++i){
for(int j=1;j<=n;++j){
for(int k=1;k<=m;++k){
S[k]=str[i][k];S[m+k]=str[j][k];
}
mp(2*m);
for(int t=fail[2*m];t;t=fail[t]){
if(t<m){
f[i][j]+=pow(0.5,m-t);
}
}
}
}
for(int i=1;i<=n;++i)f[i][i]+=1.0;
for(int i=1;i<=n;++i)f[i][n+1]=-pow(0.5,m);
for(int i=1;i<=n;++i)f[n+1][i]=1.0;f[n+1][n+2]=1.0;
sz=n+1;
// for(int i=1;i<=sz;++i){
// for(int j=1;j<=sz+1;++j){
// printf("%.3f ",f[i][j]);
// }printf("\n");
// }
gauss();
for(int i=1;i<=n;++i)printf("%.10f\n",f[i][n+2]);
return 0;
}