BZOJ 3530: [Sdoi2014]数数 [AC自动机 数位DP]

3530: [Sdoi2014]数数

题意:\(\le N\)的不含模式串的数字有多少个,\(n=|N| \le 1200\)


考虑数位DP

对于长度\(\le n\)的,普通套路DP\(g[i][j]\)即可

对于长度\(=n\)的,需要考虑天际线,\(f[i][j][0/1]\)表示从高开始i位走到节点j,是否卡上界的方案数

需要注意的是前导0的处理,不能出现前导0,所以\(f[0]\)往外转移的时候不能走0

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
const int N=2005, P=1e9+7;
typedef long long ll;
inline int read(){
    char c=getchar();int x=0,f=1;
    while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();}
    return x*f;
}

int n, m;
char a[N], s[N];
inline void mod(int &x) {if(x>=P) x-=P;}
namespace ac{
	struct meow{int ch[10], fail, val;}t[N];
	int sz;
	void insert(char *s) { 
		int len=strlen(s+1), u=0;
		for(int i=1; i<=len; i++) {
			int c=s[i]-'0';
			if(!t[u].ch[c]) t[u].ch[c] = ++sz;
			u=t[u].ch[c];
		}
		t[u].val=1;
	}
	int q[N], head, tail;
	void build() {
		head=tail=1;
		for(int i=0; i<10; i++) if(t[0].ch[i]) q[tail++]=t[0].ch[i];
		while(head!=tail) {
			int u=q[head++];
			t[u].val |= t[t[u].fail].val;
			for(int i=0; i<10; i++) {
				int &v=t[u].ch[i];
				if(!v) v = t[t[u].fail].ch[i];
				else t[v].fail = t[t[u].fail].ch[i], q[tail++]=v;
			}
		}
	}
	int f[N][N][2], g[N][N], ans;
	void dp() {
		g[0][0]=1;
		for(int i=0; i<n; i++) 
			for(int u=0; u<=sz; u++) if(!t[u].val) {
				for(int k=0; k<10; k++) if(!t[t[u].ch[k]].val) { 
					if(i==0 && k==0) continue;
					mod(g[i+1][ t[u].ch[k] ] += g[i][u]);
				}
			}
		for(int i=1; i<n; i++) for(int j=0; j<=sz; j++) mod(ans += g[i][j]);

		f[0][0][1]=1; //f[0][0][0]=1;
		for(int i=0; i<n; i++) {  //printf("\niii %d  %d\n",i, a[i+1]-'0');
			for(int u=0; u<=sz; u++) if(!t[u].val) { //printf("uuu %d  %d %d\n",u,f[i][u][0],f[i][u][1]);
				for(int k=0; k<10; k++) if(!t[t[u].ch[k]].val) { 
					if(i==0 && k==0) continue;
					int v=t[u].ch[k];  //printf("v %d  %d\n",k,v);
					mod(f[i+1][v][0] += f[i][u][0]);
					if(k < a[i+1]-'0') mod(f[i+1][v][0] += f[i][u][1]);
					if(k == a[i+1]-'0') mod(f[i+1][v][1] += f[i][u][1]);
				}
			}
		}
		//for(int i=1; i<=n; i++) for(int j=0; j<=sz; j++) printf("f %d %d  %d %d\n",i,j,f[i][j][0],f[i][j][1]);
		for(int i=0; i<=sz; i++) {
			mod(ans += f[n][i][0]);
			mod(ans += f[n][i][1]);
		}
		printf("%d", ans);
	}
}
int main() {
	freopen("in","r",stdin);
	scanf("%s",a+1); n=strlen(a+1);
	m=read();
	for(int i=1; i<=m; i++) scanf("%s",s+1), ac::insert(s);
	ac::build();
	ac::dp();
}

posted @ 2017-04-04 21:44  Candy?  阅读(652)  评论(0编辑  收藏  举报