BZOJ 1030 [JSOI2007]文本生成器(ac自动机+dp)
题意
给你n个串,求长度为m的串的方案数,使得这个串至少包含一个这n个任意一个串
思路
转化为求一个也不包含的方案数
\(dp[i][j]\)为第i个字符,匹配到ac自动机上的j号节点的方案数,显然不能匹配到有结束点的地方
fail节点的性质:指向当前串的最长后缀所在的节点
所以当前节点的fail节点为结束点也不行
所有原本没有的空节点在build的时候会指向根节点
所以空节点都在dp[i][0]里
代码
空间开小了一直wa
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<string>
#include<stack>
#include<queue>
#include<deque>
#include<set>
#include<vector>
#include<map>
#define fst first
#define sc second
#define pb push_back
#define mem(a,b) memset(a,b,sizeof(a))
#define lson l,mid,root<<1
#define rson mid+1,r,root<<1|1
#define lc root<<1
#define rc root<<1|1
using namespace std;
typedef double db;
typedef long double ldb;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> PI;
typedef pair<ll,ll> PLL;
const db eps = 1e-6;
const int mod = 10007;
const int maxn = 5e5+100;
const int maxm = 2e6+100;
const int inf = 0x3f3f3f3f;
const db pi = acos(-1.0);
int dp[111][7777];
int n,m;
struct AC{
//局部变量没有默认0!
int trie[maxn][26];
int num[maxn];//单词出现次数
int fail[maxn];
int vis[maxn];//ask函数用到
int tot;
//多测可写个init
void init(){tot=0;mem(vis,-1);mem(trie,0);}
void add(char *s){
int root = 0, len = strlen(s+1);
for(int i = 1; i <= len; i++){
int x = s[i]-'A';
if(!trie[root][x])trie[root][x]=++tot;
root=trie[root][x];
}
num[root]++;
vis[root]=0;
}
void build(){
queue<int>q;
for(int i = 0; i < 26; i++){
if(trie[0][i]){
fail[trie[0][i]]=0;
q.push(trie[0][i]);
}
}
while(!q.empty()){
int now = q.front();
q.pop();
for(int i = 0; i < 26; i++){
if(trie[now][i]){
fail[trie[now][i]]=trie[fail[now]][i];
q.push(trie[now][i]);
}
else trie[now][i]=trie[fail[now]][i];
}
vis[now]&=vis[fail[now]];
//if(vis[fail[now]]==0)vis[now]=0;
}
}
void solve(){
dp[0][0]=1;
for(int i = 0; i < m; i++){
for(int j = 0; j <= tot; j++){
if(vis[j]==0)continue;
for(int k = 0; k < 26; k++){
int to = trie[j][k];
if(vis[to]==0)continue;
(dp[i+1][to]+=dp[i][j])%=mod;
}
}
}
}
}ac;
char a[maxn];
int main(){
ac.init();
scanf("%d %d", &n, &m);
for(int i = 1; i <= n; i++){
scanf("%s",a+1);
ac.add(a);
}
ac.build();
ac.solve();
int ans = 1;
for(int i = 1; i <= m; i++){
ans=ans*26%mod;
}/*
for(int j = 0; j <= ac.tot; j++){
printf("%d %d\n",j,ac.vis[j]);
}
for(int i = 0; i <= ac.tot; i++){
printf("%d:: \n",i);
for(int j = 0; j < 26; j++){
printf("%d %d\n",j,ac.trie[i][j]);
}
}*/
/*for(int i = 1; i <= m; i++){
for(int j = 0; j <= ac.tot; j++){
printf("%d %d == %d\n",i,j,dp[i][j]);
}
}*/
for(int i = 0; i <= ac.tot; i++){
ans=(ans+mod-dp[m][i])%mod;
}
printf("%d",ans);
return 0;
}
/*
10 18
DCSDG
DSSF
SADAV
DSATWYYH
FDHFIS
DFGDFGGD
SASSSS
HHEBB
SFTWRTWW
ZSDFZDS
*/