很明显，我们很难直接求出区间内“包含长度大于等于2的回文子串”的数的个数，但却可以较为容易地求出“不包含任何长度大于等于2的回文子串”的数的个数：任何长度不小于2的回文串，其中心处必然是一个长度为2或3的回文串，所以一个数不含这样的回文子串，当且仅当它的每一位都与前一位以及前两位不同，这正好可以用数位DP统计。因此我们不妨采用正难则反的策略，用总数减去不含回文子串的数的个数，得到的就是所求的数的个数了。
#include <iostream>
#include <cstdio>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <limits>
#include <vector>
#include <stack>
#include <queue>
#include <set>
#include <map>
#include <bitset>
#include <unordered_map>
#include <unordered_set>
using namespace std;
typedef long long ll;

// Reads two decimal strings l and r and prints, modulo 1e9+7, how many numbers
// in [l, r] contain a palindromic substring of length >= 2.
//
// Key fact: a digit string contains a palindrome of length >= 2 iff it contains
// one of length 2 or 3 (the centre of any longer palindrome is such a one).
// So a number is palindrome-free iff every digit differs from the previous
// digit and from the one before that. We count palindrome-free numbers with a
// digit DP and subtract that from the total count.

const int maxN = 1e3 + 7;
const ll mod = 1e9 + 7;
char l[maxN], r[maxN];
int dig[maxN];  // dig[1] = least significant digit, dig[len] = most significant
void MOD(ll &x) { x >= mod ? x %= mod : x; }
// dp[pos][lx][x]: ways to fill positions pos-1 .. 1 with no upper bound, given
// that the last two placed digits lx, x are both real (not leading zeros).
// -1 means "not computed yet".
ll dp[maxN][10][10];

// Number of ways to fill positions pos-1 .. 1 so that the whole number stays
// palindrome-free.
//   x    - digit placed at position pos (the most recent one)
//   lx   - digit placed at position pos+1
//   top  - all digits so far equal the bound, so position pos-1 is capped by
//          dig[pos-1] instead of 9
//   zero - every digit placed before x is a leading zero, so lx is not a real
//          digit (and x is real only if it is non-zero)
// The all-zero string (the number 0) is counted exactly once.
ll dfs(int pos, int x, int lx, bool top, bool zero)
{
    if(pos == 1) return 1;
    // The memo key does not include `zero`, so it only describes states in
    // which lx and x are both real digits; read and write under one guard.
    bool cacheable = !top && !zero;
    if(cacheable && (~dp[pos][lx][x])) return dp[pos][lx][x];
    ll sum = 0;
    int u = top ? dig[pos - 1] : 9;
    for(int i = 0; i <= u; i ++)
    {
        if(i)
        {
            // A non-zero digit must differ from both predecessors; leading
            // zeros are stored as 0 and therefore never collide with it.
            if(i == x || i == lx) continue;
        }
        else if(!zero && (i == lx || i == x))
        {
            // A zero only conflicts once the predecessors are real digits.
            continue;
        }
        sum += dfs(pos - 1, i, x, top && (i == u), zero && (!x));
        MOD(sum);
    }
    if(cacheable) dp[pos][lx][x] = sum;
    return sum;
}

// Returns, modulo `mod`, how many numbers in [0, s] contain a palindromic
// substring of length >= 2. s is a non-empty decimal string; leading zeros
// are allowed.
ll solve(char *s)
{
    memset(dig, 0, sizeof(dig));
    int len = (int)strlen(s);
    for(int i = 0; i < len; i ++) dig[len - i] = s[i] - '0';
    ll all = 0;  // value of s modulo `mod`
    for(int i = len; i >= 1; i --)
    {
        all = all * 10 + dig[i];
        MOD(all);
    }
    memset(dp, -1, sizeof(dp));
    // Start just above the most significant digit so that every digit of s is
    // taken into account, whatever its length.
    ll good = dfs(len + 1, 0, 0, true, true);  // palindrome-free numbers in [0, s], 0 included
    // [0, s] contains s + 1 numbers; the +1 keeps the count exact so that the
    // l == 0 case in main needs no correction.
    return ((all + 1 - good) % mod + mod) % mod;
}

int main()
{
    scanf("%s%s", l, r);
    int len = (int)strlen(l);
    bool zero = true;
    for(int i = 0; zero && i < len; i ++) if(l[i] != '0') zero = false;
    ll x = 0;  // count for [0, l - 1]; that range is empty when l == 0
    if(!zero)
    {
        // l -= 1 in place, borrowing leftwards. Since l > 0 the borrow stops
        // at or before the first character. Any leading zero this produces is
        // harmless: solve() treats leading zeros as absent.
        int tmp = len - 1;
        l[tmp] -= 1;
        while(l[tmp] < '0')
        {
            l[tmp] = '9';
            l[--tmp] --;
        }
        x = solve(l);
    }
    ll y = solve(r);
    printf("%lld\n", (y - x + mod) % mod);
    return 0;
}