洛谷P3413 SAC#1 - 萌数 题解 数位DP

题目链接:https://www.luogu.com.cn/problem/P3413

题目大意:
定义萌数指:满足“存在长度至少为2的回文子串”的数。
求区间 \([L,R]\) 范围内萌数的数量。

解题思路:
使用 数位DP 进行求解。
定义状态 \(f[pos][p1][p2]\) 表示满足如下条件时的方案数:

  • 当期数位在第 \(pos\) 位;
  • 前面那个数的前面那个数是 \(p1\)
  • 前面那个数是 \(p2\)

则可以开函数 dfs(int pos, int p1, int p2, bool limit) 进行求解,其中:

  • \(pos,p1,p2\) 的含义同上;
  • \(limit\) 表示当前是否处于限制状态。

注意:数的位数是1000位,所以一开始的输入得用字符串输入,然后再转换。

实现代码如下:

#include <bits/stdc++.h>
using namespace std;
const long long MOD = 1000000007;
long long f[1010][10][10], pow10[1010];
int a[1010];
char ch[1010];
void init() {
    memset(f, -1, sizeof(f));
    pow10[0] = 1;
    for (int i = 1; i <= 1000; i ++) pow10[i] = pow10[i-1] * 10 % MOD;
}
long long dfs(int pos, int p1, int p2, bool limit) {
    if (pos < 0) return 0;  // 因为我一旦找到回文子串会返回,所有到pos<0时还没有找到就直接返回0了
    if (!limit && p1!=-1 && p2!=-1 && f[pos][p1][p2] != -1) return f[pos][p1][p2];
    int up = limit ? a[pos] : 9;
    long long tmp = 0;
    for (int i = 0; i <= up; i ++) {
        if (p1 == i || p2 == i) {
            if (limit && i==up) {
                // tmp += num % pow10[pos] + 1; // 不能这么算,因为是大数
                long long t = 0;
                for (int j = pos-1; j >= 0; j --)
                    t = (t * 10 + a[j]) % MOD;
                tmp += t + 1;
            }
            else tmp += pow10[pos] % MOD;
        }
        else
            tmp += dfs(pos-1, p2, (p2==-1&&i==0&&pos>0)?-1:i, limit && i==up);
        tmp %= MOD;
    }
    if (!limit && p1!=-1 && p2!=-1) f[pos][p1][p2] = tmp;
    // printf("dfs pos=%d, p1=%d, p2=%d, limit=%d, tmp = %lld\n", pos, p1, p2, limit, tmp);
    return tmp;
}
long long get_num(bool minus1) {
    cin >> ch;
    int len = strlen(ch);
    for (int i = 0; i < len; i ++) a[i] = ch[len-1-i] - '0';
    // 判断是否为0
    bool all0 = true;
    for (int i = 0; i < len; i ++) if (a[i] != 0) { all0 = false; break; }
    if (all0) return 0;
    // 判断是否要减1
    if (minus1) {
        a[0] --;
        for (int i = 0; i < len; i ++) {
            if (a[i] < 0) { a[i] += 10; a[i+1] --; }
            else break;
        }
    }
    return dfs(len-1, -1, -1, true);
}
int main() {
    init();
    long long num_l = get_num(true);
    long long num_r = get_num(false);
    cout << (num_r - num_l + MOD) % MOD << endl;
    return 0;
}
posted @ 2019-12-07 16:40  quanjun  阅读(180)  评论(0编辑  收藏  举报