P10958 启示录 解题报告

更好的阅读体验

用记忆化搜索写数位 dp 真的很好写!

题目传送门

题目大意:

\(T\) 组数据,每次询问第 \(x\) 个含有至少 \(3\) 个连续 \(6\) 的数是什么。

思路:

考虑数位 dp。

一般数位 dp 问题有两种常见形式:

  1. 询问 \([l, r]\) 内有多少个符合条件的数;
  2. 询问满足条件的第 \(k\) 大(小)的数是什么。

很显然这道题是第二种形式。

首先问题 \(1\) 很简单,那我们考虑将第二个问题转化成第一个问题来做。

因为答案具有单调性,于是可以二分判定。

每次二分到一个值 \(mid\),计算 \([1, mid]\) 的魔鬼数个数,若大于等于 \(x\),则说明所求在 \(mid\) 左侧,否则在 \(mid\) 右侧。

接着考虑问题 \(1\),这里采用记忆化搜索的方式,注释在代码中。

//pos 记录当前填到了哪一位,cnt 记录当前末尾有几个连续的 6,flag 记录当前数是否满足条件
//limit 记录当前有没有顶上界
//因为这道题有没有前导零无影响,遂不记录
int dfs(int pos, int cnt, bool flag, bool limit) {
    //边界,若填完了就检查一下是否符合条件
    if(pos < 0) return flag;
    //若不顶上界就记忆化,因为顶上界是特殊情况,满足条件的数可能和普通情况不同
    if(!limit && f[pos][cnt][flag] != -1) return f[pos][cnt][flag];
    //看一下当前这位需不需要顶上界,若前面填的数都是贴着上界的,这一位最多只能填到 num[pos],否则不受限
    int mx = (limit ? num[pos] : 9);
    int res = 0;
    //枚举第 pos 位填什么
    for(int i = 0; i <= mx; i++) {
        //处理连续的 6
        int ncnt;
        if(i == 6) ncnt = cnt + 1;
        else ncnt = 0;
        res += dfs(pos - 1, ncnt, flag || (ncnt >= 3), limit && (i == num[pos]));
    }
    //若不顶上界就记忆化
    if(!limit) f[pos][cnt][flag] = res;
    return res;
}

这里我直接把二分值域拉满了,但是实测发现第 \(50000000\) 个魔鬼数只有 \(6668056399\)

时间复杂度为:\(O(N^2MT\log V)\),这里 \(N\) 表示数字位数,\(V\) 表示二分值域,\(M\) 表示每次枚举填的数的个数,可看作 \(10\)

\(\texttt{Code:}\)

#include <vector>
#include <cstring>
#include <iostream>

using namespace std;
typedef long long ll;

const int N = 20;

int T;
int x;
ll f[N][N][2];
vector<int> num;

ll dfs(int pos, int cnt, bool flag, bool limit) {
    if(pos < 0) return flag;
    if(!limit && f[pos][cnt][flag] != -1) return f[pos][cnt][flag];
    int mx = (limit ? num[pos] : 9);
    ll res = 0;
    for(int i = 0; i <= mx; i++) {
        int ncnt;
        if(i == 6) ncnt = cnt + 1;
        else ncnt = 0;
        res += dfs(pos - 1, ncnt, flag || (ncnt >= 3), limit && (i == num[pos]));
    }
    if(!limit) f[pos][cnt][flag] = res;
    return res;
}

ll calc(ll x) {
    num.clear();
    ll tmp = x;
    while(tmp) {
        num.push_back(tmp % 10);
        tmp /= 10;
    }
    return dfs(num.size() - 1, 0, 0, 1);
}

void solve() {
    scanf("%d", &x);
    ll l = 1, r = 5e18;
    while(l < r) {
        ll mid = l + r >> 1;
        if(calc(mid) >= x) r = mid;
        else l = mid + 1;
    }
    printf("%lld\n", l);
}

int main() {
    scanf("%d", &T);
    memset(f, -1, sizeof f);
    while(T--) {
        solve();
    }
    return 0;
}
posted @ 2024-09-01 11:39  Brilliant11001  阅读(4)  评论(0编辑  收藏  举报