uvalive 6343 - Dyslexic Gollum AC自动机
一开始完全看不出这是AC自动机啊。
求长度为N且不包含长度大于等于K的回文子串的二进制串有多少种(答案对1000000007取模)。
枚举长度为K和K+1的全部回文串,保存在trie里面,然后构造AC自动机,在自动机上做DP。
因为长度大于K+1的回文串去掉首尾各一个字符后仍然是回文串,反复去掉后必然得到长度为K或K+1的回文子串,所以只需禁止长度为K和K+1的回文串即可。
// UVALive 6343 - Dyslexic Gollum.
//
// For each test case (n, k), prints the number of binary strings of length n
// that contain no palindromic substring of length >= k, modulo 1e9+7.
//
// Approach: every palindrome longer than k+1 contains a palindrome of length
// k or k+1 (strip one character from each end repeatedly), so it is enough to
// forbid exactly the palindromes of length k and k+1. Those are inserted into
// a trie, the trie is turned into an Aho-Corasick automaton, and a memoized DP
// counts the walks of length n that never enter a forbidden state.
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<iostream>
#include<sstream>
#include<cmath>
#include<climits>
#include<string>
#include<map>
#include<queue>
#include<vector>
#include<stack>
#include<set>

const int MAX_NODE = 1 << 11;   // automaton node capacity
const int SIGMA_SIZE = 2;       // binary alphabet: '0' and '1'
const int MAX_K = 10;           // largest k supported (pattern length is at most MAX_K + 1)
const int MAX_LEN = 443;        // largest n supported by the dp table
const int mod = 1000000007;

int ch[MAX_NODE][SIGMA_SIZE];   // trie edges; after construct() the full goto function
int val[MAX_NODE];              // 1 if reaching this node means a forbidden pattern was matched
int fail[MAX_NODE];             // Aho-Corasick failure links
int sz;                         // number of nodes currently in use
int dp[MAX_NODE][MAX_LEN + 1];  // memo for f(); -1 means "not computed yet"

// Resets the trie to a single root node.
void init()
{
    memset(ch[0], 0, sizeof(ch[0]));
    val[0] = 0;
    sz = 1;
}

// Maps the character '0' / '1' to the edge index 0 / 1.
int idx(char c) { return c - '0'; }

// Inserts the NUL-terminated binary string s into the trie and marks its
// final node as forbidden.
void insert(const char *s)
{
    int u = 0;
    for (int i = 0; s[i] != '\0'; i++) {
        const int v = idx(s[i]);
        if (!ch[u][v]) {
            // Allocate a fresh node with no children.
            memset(ch[sz], 0, sizeof(ch[sz]));
            val[sz] = 0;
            ch[u][v] = sz++;
        }
        u = ch[u][v];
    }
    val[u] = 1;
}

// Builds failure links by BFS and completes missing edges so that ch[u][c]
// is defined for every node and symbol. A node also becomes forbidden when
// the node its failure link points to is forbidden.
void construct()
{
    fail[0] = 0;
    std::queue<int> q;
    for (int c = 0; c < SIGMA_SIZE; c++) {
        const int u = ch[0][c];
        if (u) {
            fail[u] = 0;
            q.push(u);
        }
    }
    while (!q.empty()) {
        const int r = q.front();
        q.pop();
        for (int c = 0; c < SIGMA_SIZE; c++) {
            const int u = ch[r][c];
            if (!u) {
                // Missing edge: reuse the transition of the failure state,
                // which was completed earlier in BFS order.
                ch[r][c] = ch[fail[r]][c];
                continue;
            }
            q.push(u);
            int v = fail[r];
            while (v && !ch[v][c]) v = fail[v];
            fail[u] = ch[v][c];
            val[u] |= val[fail[u]];
        }
    }
}

// Returns, modulo mod, the number of ways to append n more symbols starting
// from automaton state u without ever stepping onto a forbidden state.
// Requires 0 <= n <= MAX_LEN and the used rows of dp reset to -1 beforehand.
int f(int u, int n)
{
    if (n == 0) return 1;
    int &res = dp[u][n];
    if (res >= 0) return res;
    res = 0;
    for (int c = 0; c < SIGMA_SIZE; c++) {
        const int next = ch[u][c];
        if (!val[next]) res = (res + f(next, n - 1)) % mod;
    }
    return res;
}

// Writes the low `len` bits of st (least significant bit first) into s as a
// NUL-terminated string of '0'/'1' and returns 1 if that string is a
// palindrome, 0 otherwise. s must have room for len + 1 characters.
int check(int st, int len, char *s)
{
    for (int i = 0; i < len; i++) {
        s[i] = st % 2 + '0';
        st /= 2;
    }
    s[len] = '\0';
    for (int i = 0, j = len - 1; i < j; i++, j--) {
        if (s[i] != s[j]) return 0;
    }
    return 1;
}

// Inserts every binary palindrome of length k and k+1 into the trie.
// Requires 0 <= k <= MAX_K so that the local buffer is large enough.
void prework(int k)
{
    for (int len = k; len <= k + 1; len++) {
        for (int st = 0; st < (1 << len); st++) {
            char s[MAX_K + 2];
            if (check(st, len, s)) {
                insert(s);
            }
        }
    }
}

int main()
{
    int t;
    if (scanf("%d", &t) != 1) return 0;  // no input: nothing to answer
    for (int ca = 1; ca <= t; ca++) {
        int n, k;
        if (scanf("%d%d", &n, &k) != 2) break;  // truncated input
        // Out-of-range values would overflow the pattern buffer or the dp table.
        if (n < 0 || n > MAX_LEN || k < 0 || k > MAX_K) {
            fprintf(stderr, "input out of range: n=%d k=%d\n", n, k);
            return 1;
        }
        init();
        prework(k);
        construct();
        // Only the first sz rows of dp are ever read for this automaton.
        memset(dp, -1, sizeof(dp[0]) * sz);
        printf("%d\n", f(0, n));
    }
    return 0;
}