bzoj3530
AC自动机+数位dp
先建立AC自动机trie图,然后在上面跑dp。
dp[i][j][0]表示到了天际线,也就是j==n时<n的方案数
dp[i][j][1]表示到了天际线并且==n的方案数
f[i][j]表示j<n时合法的方案数
因为如果位数比n小,那么我们不用在乎每位选什么,只要没有前导0就行了,因为位数小肯定比原数小,那么我们就dp一下,计算不经过危险节点的方案数
但是如果到了天际线,也就是j==n,那么这时候要小心统计,具体转移看代码,就是分这一位比原数对应位的大小,然后所有的dp[i][j][0/1],f[i][j]加起来就是答案
对于前导零的处理,就是dp[i][0][0/1]=f[i][1]=0,i代表第一个0节点,这样就不会统计前导0了 dp数组只有在天际线的时候才统计
1A~
#include<bits/stdc++.h> using namespace std; const int N = 1510, mod = 1000000007; int m, n; char s[N], num[N]; struct ac_automation { int cnt, root, pos; int ch[N][11], fail[N], danger[N]; long long dp[N][N][2], f[N][N]; void insert(char s[]) { int now = root, len = strlen(s); for(int i = 0; i < len; ++i) { int &v = ch[now][s[i] - '0']; if(v == 0) v = ++cnt; now = v; } danger[now] = 1; } void construct_fail() { queue<int> q; for(int i = 0; i < 10; ++i) if(ch[root][i] == 0) { ++cnt; if(i == 0) pos = cnt; ch[root][i] = cnt; } for(int i = 0; i < 10; ++i) q.push(ch[root][i]); while(!q.empty()) { int u = q.front(); q.pop(); for(int i = 0; i < 10; ++i) { int &v = ch[u][i]; if(v == 0) v = ch[fail[u]][i]; else { fail[v] = ch[fail[u]][i]; danger[v] |= danger[fail[v]]; q.push(v); } } } } void solve() { for(int i = 1; i < 10; ++i) { int x = ch[root][i]; if(danger[x]) continue; if(i < num[1] - '0') dp[x][1][0] = 1; else if(i == num[1] - '0') dp[x][1][1] = 1; f[x][1] = 1; } long long ans = 0; for(int j = 1; j < n; ++j) for(int i = 1; i <= cnt; ++i) if(!danger[i]) for(int k = 0; k < 10; ++k) { int x = ch[i][k]; if(danger[x]) continue; if(k < num[j + 1] - '0') dp[x][j + 1][0] = (dp[x][j + 1][0] + dp[i][j][0] + dp[i][j][1]) % mod; else if(k == num[j + 1] - '0') { dp[x][j + 1][0] = (dp[x][j + 1][0] + dp[i][j][0]) % mod; dp[x][j + 1][1] = (dp[x][j + 1][1] + dp[i][j][1]) % mod; } else dp[x][j + 1][0] = (dp[x][j + 1][0] + dp[i][j][0]) % mod; if(j + 1 < n) f[x][j + 1] = (f[x][j + 1] + f[i][j]) % mod; } for(int i = 1; i <= cnt; ++i) for(int j = 1; j <= n; ++j) { if((i == pos && j == 1) || danger[i]) continue; ans = (ans + f[i][j]) % mod; if(j == n) ans = (ans + dp[i][j][0] + dp[i][j][1]) % mod; } printf("%lld\n", ans); } } ac; int main() { scanf("%s%d", num + 1, &m); n = strlen(num + 1); for(int i = 1; i <= m; ++i) scanf("%s", s), ac.insert(s); ac.construct_fail(); ac.solve(); return 0; }