BZOJ1559[JSOI2009]密码——AC自动机+DP+搜索
题目描述
输入
输出
样例输入
10 2
hello
world
hello
world
样例输出
2
helloworld
worldhello
helloworld
worldhello
提示
这题算是一个套路题了,多个串求都包含它们的长为L的串的方案数。
显然是一个在AC自动机(trie图)上DP,常规DP状态是f[i][j]表示在AC自动机上走了i步到达了j节点的方案数。
但这道题还要求包含所有模式串,而且模式串最多10个,因此再加一维f[i][j][k]表示在AC自动机上走了i步到达了j节点,已经包含的字符串状态为k的方案数,其中k是一个二进制状态。
但我们发现如果一个串x是另一个串y的子串,那么只要包含y就一定包含x,因此在DP之前还要去掉被包含的串。
我去掉被包含串的方法是当一个终止节点有子节点(在找fail指针之前)或者一个终止节点被其他点通过fail指针指向(在找fail指针之后),那么说明这个串被包含,就将他的终止标记删掉。
剩下还有输出方案,因为只在方案数<=42时输出,所以方案一定是由模式串组成并且相邻模式串首尾重复部分一定要去重。
为什么?
因为假如有一个随机字符,只有一个模式串,那么他们的方案数就是2*26=52>42,所以一定不包含随机字符。
而如果不将相邻模式串去重就能到达长度为L,那么去重之后就会出现随机字符,方案数还是会超过42。
综上所述,密码串就是由所有模式串(不包括是其他串子串的串)的排列组成,最多就10个串,预处理出任意两个模式串的重叠长度,爆搜一下就好了。
#include<set> #include<map> #include<stack> #include<queue> #include<cmath> #include<vector> #include<cstdio> #include<cstring> #include<iostream> #include<algorithm> using namespace std; int s[120][30]; int fail[120]; int num[120]; long long f[3][120][1025]; int n,L,m; int cnt; char ch[30][30];; int vis[120]; long long ans; char res[50][30]; int lk[30][30]; int q[30]; int tot; int v[30]; int rank[30]; int que[30]; void build(char *ch,int k) { int len=strlen(ch); int now=0; for(int i=0;i<len;i++) { int x=ch[i]-'a'; if(!s[now][x]) { s[now][x]=++cnt; } now=s[now][x]; } vis[now]=k; } void get_fail() { queue<int>q; for(int i=0;i<26;i++) { if(s[0][i]) { q.push(s[0][i]); fail[s[0][i]]=0; } } while(!q.empty()) { int now=q.front(); q.pop(); for(int i=0;i<26;i++) { if(s[now][i]) { fail[s[now][i]]=s[fail[now]][i]; q.push(s[now][i]); } else { s[now][i]=s[fail[now]][i]; } } } } void find_end() { for(int i=1;i<=cnt;i++) { if(vis[i]) { for(int j=0;j<26;j++) { if(s[i][j]) { vis[i]=0; break; } } } } get_fail(); for(int i=1;i<=cnt;i++) { if(vis[fail[i]]) { vis[fail[i]]=0; } } for(int i=1;i<=cnt;i++) { if(vis[i]) { m++; q[m]=vis[i]; num[i]=1<<(m-1); } } } void dp() { f[0][0][0]=1; for(int i=0;i<L;i++) { memset(f[(i+1)&1],0,sizeof(f[(i+1)&1])); for(int j=0;j<=cnt;j++) { for(int k=0;k<=(1<<m)-1;k++) { if(f[i&1][j][k]) { for(int l=0;l<26;l++) { int x=s[j][l]; f[(i+1)&1][x][k|num[x]]+=f[i&1][j][k]; } } } } } for(int i=0;i<=cnt;i++) { ans+=f[L&1][i][(1<<m)-1]; } } int get_lk(int x,int y) { int i,j; bool flag; int lx=strlen(ch[x]); int ly=strlen(ch[y]); for(i=min(lx,ly);i>0;i--) { flag=1; for(j=0;j<i;j++) { if(ch[x][lx-i+j]!=ch[y][j]) { flag=0; break; } } if(flag) { break; } } return i; } void dfs(int dep) { if(dep>m) { tot++; int l=0; for(int i=1;i<dep;i++) { int len=strlen(ch[que[i]]); for(int j=lk[que[i-1]][que[i]];j<len;j++) { res[tot][l]=ch[que[i]][j]; l++; } } if(l!=L) { tot--; } return ; } for(int i=1;i<=m;i++) { if(!v[i]) { v[i]=1; que[dep]=q[i]; dfs(dep+1); v[i]=0; } } } int cmp(int x,int y) { for(int i=0;i<L;i++) { if(res[x][i]!=res[y][i]) { return res[x][i]<res[y][i]; } } return 0; } int main() { scanf("%d%d",&L,&n); for(int i=1;i<=n;i++) { scanf("%s",ch[i]); build(ch[i],i); } find_end(); dp(); printf("%lld\n",ans); for(int i=1;i<=m;i++) { for(int j=1;j<=m;j++) { lk[q[i]][q[j]]=get_lk(q[i],q[j]); } } if(ans<=42) { dfs(1); for(int i=1;i<=tot;i++) { rank[i]=i; } sort(rank+1,rank+tot+1,cmp); for(int i=1;i<=tot;i++) { for(int j=0;j<L;j++) { printf("%c",res[rank[i]][j]); } printf("\n"); } } }