bzoj2553: [BeiJing2011]禁忌

传送门:http://www.lydsy.com:808/JudgeOnline/problem.php?id=2553

思路:第一件事当然是建立AC自动机。。。

现在我们建好了AC自动机,那么我们就在AC自动机上走,走到一个终止节点就算我们找到一个禁忌串,然后返回根节点重新匹配。

和bzoj1030类似,考虑DP,设f[i][j]为现在长度为i,走到j号节点的期望。

转移就是枚举下一个字符。下一个字符是终止节点就跳回root下次重新走,ans的期望就可以增加1/字符集大小。


因为长度最大有10^9,显然直接DP会无论空间还是时间都会爆炸。。。

所以用矩阵乘法+快速幂加速转移


现在考虑怎么处理出初始的转移矩阵

先算出a[i][j]表示i一步到j的概率

用bfs就可以实现,如果j是i的儿子,那么a[i][j]+=1/字符集大小

为了方便我们新建一个节点n=cnt(总结点数)+1

每次转移root时也转移到它

那么a[i][n]就是i走一步匹配到禁忌串的概率。

要把所有步都累加出来,把a[n][n]赋为1就可以了

因为这样下一次计算时b[root][n]=....+b[root][n]*a[n][n]+....

就可以把上次的答案都累加起来了。

自乘x次后,因为贡献永远是1,所以a[root][n]就表示root走x步遇到禁忌串的期望,也就是答案。


最后吐槽一句:卡精度简直丧心病狂...不开long double就不让过...


#include<cstdio>
#include<cstring>
#include<algorithm>
const int maxn=110;
using namespace std;
struct matrix{
	long double a[maxn][maxn];
	void clear(){for (int i=0;i<maxn;i++) for (int j=0;j<maxn;j++) a[i][j]=0.0;}
}ans,f;
int n,K,dsiz,num;char s[maxn];bool bo[maxn];
inline matrix operator*(matrix a,matrix b){
	matrix res;res.clear();
	for (int i=0;i<=n;i++)
		for (int j=0;j<=n;j++)
			for (int k=0;k<=n;k++)
				res.a[i][j]+=a.a[i][k]*b.a[k][j];
	return res;
}
void qpow(){for (;K;K>>=1,f=f*f) if (K&1) ans=ans*f;}

struct AC_DFA{
	int tot,ch[maxn][26],fail[maxn],q[maxn],head,tail;bool end[maxn];
	void insert(){
		int len=strlen(s),p=0;
		for (int i=0;i<len;p=ch[p][s[i]-'a'],i++) if (!ch[p][s[i]-'a']) ch[p][s[i]-'a']=++tot;
		end[p]=1;
	}
	void getfail(){
		head=0,q[tail=1]=0,fail[0]=-1;
		while (head!=tail){
			int x=q[++head];
			for (int i=0;i<dsiz;i++)
				if (ch[x][i]) q[++tail]=ch[x][i],fail[ch[x][i]]=x==0?0:ch[fail[x]][i];
				else ch[x][i]=x==0?0:ch[fail[x]][i];
			end[x]|=end[fail[x]];
		}
	}
	void build(){
		head=0,q[tail=1]=0,bo[0]=1;
		long double tmp=1.0/dsiz;
		while (head!=tail){
			int x=q[++head];
			for (int i=0;i<dsiz;i++){
				if (!bo[ch[x][i]]) bo[ch[x][i]]=1,q[++tail]=ch[x][i];
				if (end[ch[x][i]]) f.a[x][n]+=tmp,f.a[x][0]+=tmp;
				else f.a[x][ch[x][i]]+=tmp;
			}
		}
	}
}T;

int main(){
	scanf("%d%d%d",&num,&K,&dsiz);
	for (int i=1;i<=num;i++) scanf("%s",s),T.insert();
	n=T.tot+1,T.getfail(),T.build();
	for (int i=0;i<=n;i++) ans.a[i][i]=1;
	f.a[n][n]=1;qpow();
	printf("%.7f\n",(double)ans.a[0][n]);
	return 0;
}



posted @ 2015-07-27 17:23  orzpps  阅读(159)  评论(0编辑  收藏  举报