「笔记」数位DP
写在前面
19 年前听 zlq 讲课的时候学的东西,当时只会抄板子,现在来重学一波= =
一个板子水一天题(不事
引入
给定参数 \(l,r\),求 \([l,r]\) 中不含前导零且相邻两个数字之差至少为 \(2\) 的正整数的个数。
\(1\le l\le r\le 2\times 10^9\)。
1S,512MB。
这是一个经典的数位 DP 的例子。其模型一般是给定一些对于数的限制条件,求在给定范围内满足限制的数的贡献。
通过数位 DP 一般可以在 \(O(m\log_{10}{(n)})\) 的时间内解决此问题,其中 \(m\) 是数码种类数,\(n\) 是取值的最大值。
求解
首先将询问 \([l,r]\) 内合法的数的个数拆成询问 \([0\sim l-1]\) 和 \([0, r]\) 内合法的数的个数,之后考虑数位 DP。
数位 DP 有递推 和 记忆化搜索两种写法,由于记忆化搜索更容易理解与实现,我们一般采用记忆化搜索解决此类问题。以下也仅介绍记忆化搜索的解法。
先考虑爆搜。考虑枚举所有范围内的数,搜索的同时检查是否满足给定的限制条件。注意考虑前导零与是否达到枚举的上界,其代码如下所示:
int numlth, num[kN]; //储存给定值的从高位到低位的十进制拆分。
//now_:当前填到第几位; last_:now_ - 1 位填的数;
//zero_:前 now_ - 1 位是否均为 0; lim_:前 now_ - 1 位是否达到枚举的上界(与 num 相同)
int Dfs(int now_, int last_, bool zero_, bool lim_) {
if (now_ > numlth) return 1; //当前枚举的数合法
int ret = 0;
//枚举第 now_ 位填的数,up 为该位填数的上界
for (int i = 0, up = lim_ ? num[now_] : 9; i <= up; ++ i) {
if (abs(i - last_) < 2) continue ;
if (zero_ && !i) ret += Dfs(now_ + 1, 11, true, lim_ && i == up); //前 now_ 位均为 0
else ret += Dfs(now_ + 1, i, false, lim_ &&i == up);
}
return ret;
}
//ans[0, x] = Dfs(1, 11, true, true);
发现当枚举的数前缀的性质相同,即 dfs 的四个参数相同时,dfs 的返回值相同。
比如当枚举到 \(020\underline{?}??\) 和 \(010\underline{?}??\) 时,dfs 的参数均为 (4, 0, false, false)
。表示它们前缀的性质相同,枚举之后位数得到的答案显然也相同。
简单记忆化即可避免重复枚举过程。
//f[i][j][0/1][0/1] 表示 dfs(i, j, 0/1, 0/1) 的答案。
int numlth, num[kN], f[kN][kN][2][2];
int Dfs(int now_, int last_, bool zero_, bool lim_) {
if (now_ > numlth) return 1;
if (f[now_][last_][zero_][lim_] != -1) return f[now_][last_][zero_][lim_];
int ret = 0;
for (int i = 0, up = lim_ ? num[now_] : 9; i <= up; ++ i) {
if (abs(i - last_) < 2) continue ;
if (zero_ && !i) ret += Dfs(now_ + 1, 11, true, lim_ && i == up);
else ret += Dfs(now_ + 1, i, false, lim_ && i == up);
}
return f[now_][last_][zero_][lim_] = ret;
}
//ans[0, x] = Dfs(1, 11, true, true);
特判优化
发现上述 dfs 的过程中,\(\operatorname{lim} = 1\) 或 \(\operatorname{zero} = 1\) 的状态只会被枚举到 1 次,即只会重复调用 dfs(now_, last_, 0, 0)
。对这两维的记忆化对减少枚举次数是做负功的。
于是可以通过特判去除这两维,如下所示:
//f[i][j] 表示 dfs(i, j, 0, 0) 的答案。
int Dfs(int now_, int last_, bool zero_, bool lim_) {
if (now_ > numlth) return 1;
if (!lim_ && f[now_][last_] != -1) return f[now_][last_];
int ret = 0;
for (int i = 0, up = lim_ ? num[now_] : 9; i <= up; ++ i) {
if (abs(i - last_) < 2) continue ;
if (zero_ && !i) ret += Dfs(now_ + 1, 11, true, lim_ && i == up);
else ret += Dfs(now_ + 1, i, false, lim_ && i == up);
}
if (!lim_) f[now_][last_] = ret;
return ret;
}
可以感性理解特判的实际意义。若 dfs 的参数 \(\operatorname{lim} = 0\) 时,表示前缀比上界小,后面的位数可以随意填。因此前缀性质相同的所有子问题是完全等价的,因此可以记忆化。
\(\operatorname{zero} = 1\) 与 \(\operatorname{lim} = 0\) 一定是配套出现的,因此也可以特判掉。
这样时空复杂度均变为了原来的 \(\frac{1}{4}\)。在其他题目中也可以套用此模板,将 0/1 维特判掉,减小时空复杂度。
可能有___出题人卡直接记忆化的写法,比如这题:
代码
引入问题的完整代码。
//知识点:数位 DP
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <vector>
#define LL long long
const int kN = 15;
//=============================================================
int numlth, f[kN][kN];
std::vector <int> num;
//=============================================================
inline int read() {
int f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
void Chkmax(int &fir, int sec) {
if (sec > fir) fir = sec;
}
void Chkmin(int &fir, int sec) {
if (sec < fir) fir = sec;
}
int Dfs(int now_, int last_, bool zero_, bool lim_) {
if (now_ > numlth) return 1;
if (!lim_ && f[now_][last_] != -1) return f[now_][last_];
int ret = 0;
for (int i = 0, up = lim_ ? num[now_] : 9; i <= up; ++ i) {
if (abs(i - last_) < 2) continue ;
if (zero_ && !i) ret += Dfs(now_ + 1, 11, true, lim_ && i == up);
else ret += Dfs(now_ + 1, i, false, lim_ && i == up);
}
if (!lim_) f[now_][last_] = ret;
return ret;
}
int Calc(int val_) {
num.clear();
num.push_back(0);
for (int tmp = val_; tmp; tmp /= 10) num.push_back(tmp % 10);
for (int i = 1, j = num.size() - 1; i < j; ++ i, -- j) {
std::swap(num[i], num[j]);
}
numlth = num.size() - 1;
memset(f, -1, sizeof (f));
return Dfs(1, 11, true, true);
}
//=============================================================
int main() {
int a = read(), b = read();
printf("%d\n", Calc(b) - Calc(a - 1));
return 0;
}
例题
「ZJOI2010」数字计数
给定两个正整数 \(a\) 和 \(b\),求在 \([a,b]\) 中的所有整数中,每个数码各出现了多少次。
\(1\le a\le b\le 10^{12}\)。
1S,512MB。
与引入问题不同的是,这题要求的是数码的数量,限制了每个数的贡献,求贡献和。
套路类似,考虑对每个数码分开求解,dfs 时记录已枚举前缀的贡献量。
设 Dfs(int now_, LL sum_, bool zero_, bool lim_, int digit_)
表示前 \(\operatorname{now} - 1\) 位含有数码 \(\operatorname{digit}\) 的数量为 \(\operatorname{sum}\)、前缀是否全为前导零、前缀是否达到上界,满足上述条件的所有数中数码 \(\operatorname{digit}\) 的数量。
边界是搜索到第 \(\operatorname{length}+1\) 位,此时返回 \(\operatorname{sum}\) 的值。
与套路类似地,发现一些 \(\operatorname{now}\) 和 \(\operatorname{sum}\) 相等的搜索状态会被重复访问,简单记忆化即可。
总复杂度 \(O(10^2\log_{10}(n))\) 级别。
//知识点:数位 DP
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <vector>
#define LL long long
const int kN = 20;
//=============================================================
LL numlth, f[kN][kN];
std::vector <int> num;
//=============================================================
inline LL read() {
LL f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
void Chkmax(int &fir, int sec) {
if (sec > fir) fir = sec;
}
void Chkmin(int &fir, int sec) {
if (sec < fir) fir = sec;
}
LL Dfs(int now_, LL sum_, bool zero_, bool lim_, int digit_) {
if (now_ > numlth) return sum_;
if (!lim_ && f[now_][sum_] != -1) return f[now_][sum_];
LL ret = 0;
for (int i = 0, up = lim_ ? num[now_] : 9; i <= up; ++ i) {
if (zero_ && !i) ret += Dfs(now_ + 1, sum_, true, lim_ && i == up, digit_);
else ret += Dfs(now_ + 1, sum_ + (i == digit_), false, lim_ && i == up, digit_);
}
if (!lim_) f[now_][sum_] = ret;
return ret;
}
LL Calc(LL val_, int digit_) {
num.clear();
num.push_back(0);
for (LL tmp = val_; tmp; tmp /= 10) num.push_back(tmp % 10);
for (int i = 1, j = num.size() - 1; i < j; ++ i, -- j) std::swap(num[i], num[j]);
numlth = num.size() - 1;
memset(f, -1, sizeof (f));
return Dfs(1, 0, true, true, digit_);
}
//=============================================================
int main() {
LL a = read(), b = read();
for (int i = 0; i <= 9; ++ i) printf("%lld ", Calc(b, i) - Calc(a - 1, i));
return 0;
}
还有一种考虑每个位置填入指定数码后对应的数的个数的无脑写法,看代码就能看懂。
//知识点:暴力
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define LL long long
const int kN = 13;
//=============================================================
LL f[kN];
//=============================================================
inline LL read() {
LL f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
void Chkmax(int &fir, int sec) {
if (sec > fir) fir = sec;
}
void Chkmin(int &fir, int sec) {
if (sec < fir) fir = sec;
}
LL Calc(LL val_, LL digit_) {
LL ret = (!digit_);
for (LL tmp = val_, pow10 = 1; tmp; tmp /= 10ll, pow10 *= 10ll) {
LL pre = tmp / 10ll + 1;
if (! digit_) {
if (pre == 1) continue;
if (0 < tmp % 10) ret += (pre - 1ll) * pow10;
if (0 == tmp % 10) ret += (pre - 2ll) * pow10 + val_ % pow10 + 1;
continue;
}
if (digit_ > tmp % 10) ret += (pre - 1ll) * pow10;
if (digit_ == tmp % 10) ret += (pre - 1ll) * pow10 + val_ % pow10 + 1;
if (digit_ < tmp % 10) ret += pre * pow10;
}
return ret;
}
//=============================================================
int main() {
LL a = read(), b = read();
for (int i = 0; i <= 9; ++ i) printf("%lld ", Calc(b, i) - Calc(a - 1, i));
return 0;
}
「AHOI2009」同类分布
给定两个正整数 \(a\) 和 \(b\),求在 \([a,b]\) 中的所有整数中,各位数之和能整除原数的数的个数。
\(1\le a\le b\le 10^{18}\)。
3S,512MB。
考虑到各位数之和与原数在 dfs 中都是变量,不易检验合法性。但发现各位数之和不大于 \(9\times 12\),考虑先枚举各位数之和,再在 dfs 时维护前缀的余数,以检查是否合法。
同样设 Dfs(int now_, int sum_, int p_, bool zero_, bool lim_, int val_)
,其中 \(\operatorname{sum}\) 为前缀的各数位之和,\(p\) 为原数模 \(\operatorname{val}\) 的余数。
边界是搜索到第 \(\operatorname{length}+1\) 位,此时返回 \([\operatorname{sum}=\operatorname{val} \land \, p = 0]\)。
对数位和和余数简单记忆化即可,总复杂度 \(O(2\cdot10^2\log_{10}^3(n))\) 级别。
//知识点:数位 DP
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <vector>
#define LL long long
const int kN = 20;
//=============================================================
int numlth;
LL f[kN][9 * kN][9 * kN];
std::vector <int> num;
//=============================================================
inline LL read() {
LL f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
void Chkmax(int &fir, int sec) {
if (sec > fir) fir = sec;
}
void Chkmin(int &fir, int sec) {
if (sec < fir) fir = sec;
}
LL Dfs(int now_, int sum_, int p_, bool zero_, bool lim_, int val_) {
if (now_ > numlth) return (sum_ == val_ && !p_);
if (!lim_ && f[now_][sum_][p_] != -1) return f[now_][sum_][p_];
LL ret = 0;
for (int i = 0, up = lim_ ? num[now_] : 9; i <= up; ++ i) {
if (zero_ && !i) ret += Dfs(now_ + 1, sum_, 10 * p_ % val_, true, lim_ && i == up, val_);
else ret += Dfs(now_ + 1, sum_ + i, (10 * p_ + i) % val_, false, lim_ && i == up, val_);
}
if (!zero_ && !lim_) f[now_][sum_][p_] = ret;
return ret;
}
LL Calc(LL val_) {
num.clear();
num.push_back(0);
for (LL tmp = val_; tmp; tmp /= 10) num.push_back(tmp % 10);
for (int i = 1, j = numlth = num.size() - 1; i < j; ++ i, -- j) {
std::swap(num[i], num[j]);
}
LL ret = 0;
for (int i = 1; i <= 9 * numlth; ++ i) {
memset(f, -1, sizeof (f));
ret += Dfs(1, 0, 0, true, true, i);
}
// printf("%lld %lld\n", val_, ret);
return ret;
}
//=============================================================
int main() {
LL a = read(), b = read();
printf("%lld\n", Calc(b) - Calc(a - 1));
return 0;
}
套路题们
给定两个正整数 \(a\) 和 \(b\),求在 \([a,b]\) 中的所有整数中,存在长度至少为2的回文子串的数的个数。
\(1\le a< b\le 10^{1000}\)。
1S,128MB。
存在长度至少为2的回文子串等价于没有连续相等的三位,dfs 时记录前两位即可。代码 Link。
给定两个正整数 \(a\) 和 \(b\),求在 \([a,b]\) 中的所有整数中,至少有三个相邻的相同数字,且 8 和 4 不同时存在的数的个数。
\(10^{10}\le a\le b\le 10^{11}\)。
1S,256MB。
状态多设几维即可,记录前两位,前缀中是否有有三个相邻的相同数字,前缀中是否有 8,前缀中是否有 4。代码 Link。
给定一正整数 \(a\),求在 \([1,a]\) 中的所有整数的二进制拆分中 1 的个数的乘积。
\(1\le a \le 10^{15}\)。
1S,128MB。
二进制拆分 \(a\),同「AHOI2009」同类分布,枚举二进制中 1 的个数 dfs 即可。
注意不要乱取模。代码 Link。
「SDOI2014」数数
给定一个整数 \(n\),一大小为 \(m\) 的数字串集合 \(s\)。
求不以 \(s\) 中任意一个数字串作为子串的,不大于 \(n\) 的数字的个数。
\(1\le n\le 10^{1201}\),\(1\le m\le 100\),\(1\le \sum |s_i|\le 1500\)。\(n\) 没有前导零,\(s_i\) 可能存在前导零。
1S,128MB。
题目要求不以 \(s\) 中任意一个数字串作为子串,想到这题:「JSOI2007」文本生成器。首先套路地对给定集合的串构建 ACAM,并在 ACAM 上标记所有包含集合内的子串的状态。
之后考虑在 ACAM 上模拟串匹配的过程做数位 DP。发现前缀所在状态储存了前缀的所有信息,可以将其作为 dfs 的参数。
设 Dfs(int now_, int pos_, bool zero_, bool lim_) {
表示前缀匹配到的 ACAM 的状态为 \(\operatorname{pos}\) 时,合法的数字的数量。转移时沿 ACAM 上的转移函数转移,避免转移到被标记的状态。
存在 \(\operatorname{trans}(0, 0) = 0\),这样直接 dfs 也能顺便处理不同长度的数字串。
总复杂度 \(O(\log_{10}(n)\sum |s_i|)\) 级别。
//知识点:ACAM,数位 DP
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <queue>
#define LL long long
const int kN = 1500 + 10;
const int mod = 1e9 + 7;
//=============================================================
int n, m, ans;
char num[kN], s[kN];
//=============================================================
inline int read() {
int f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
void Chkmax(int &fir, int sec) {
if (sec > fir) fir = sec;
}
void Chkmin(int &fir, int sec) {
if (sec < fir) fir = sec;
}
namespace ACAM {
const int kSigma = 10;
int node_num, tr[kN][kSigma], last[kN], fail[kN];
int f[kN][kN];
bool tag[kN];
void Insert(char *s_) {
int u_ = 0, lth = strlen(s_ + 1);
for (int i = 1; i <= lth; ++ i) {
if (! tr[u_][s_[i] - '0']) tr[u_][s_[i] - '0'] = ++ node_num;
u_ = tr[u_][s_[i] - '0'];
last[u_] = s_[i] - '0';
}
tag[u_] = true;
}
void Build() {
std:: queue <int> q;
for (int i = 0; i < kSigma; ++ i) {
if (tr[0][i]) q.push(tr[0][i]);
}
while (!q.empty()) {
int u_ = q.front(); q.pop();
tag[u_] |= tag[fail[u_]];
for (int i = 0; i < kSigma; ++ i) {
int v_ = tr[u_][i];
if (v_) {
fail[v_] = tr[fail[u_]][i];
q.push(v_);
} else {
tr[u_][i] = tr[fail[u_]][i];
}
}
}
}
int Dfs(int now_, int pos_, bool zero_, bool lim_) {
if (now_ > n) return 1;
if (!zero_ && !lim_ && f[now_][pos_] != -1) return f[now_][pos_];
int ret = 0;
for (int i = 0, up = lim_ ? num[now_] - '0': 9; i <= up; ++ i) {
int v_ = tr[pos_][i];
if (tag[v_]) continue;
if (zero_ && !i) ret += Dfs(now_ + 1, 0, true, lim_ && i == num[now_] - '0');
else ret += Dfs(now_ + 1, v_, false, lim_ && i == num[now_] - '0');
ret %= mod;
}
if (!zero_ && !lim_) f[now_][pos_] = ret;
return ret;
}
int DP() {
memset(f, -1, sizeof (f));
return Dfs(1, 0, true, true);
}
}
//=============================================================
int main() {
scanf("%s", num + 1);
n = strlen(num + 1);
m = read();
for (int i = 1; i <= m; ++ i) {
scanf("%s", s + 1);
ACAM::Insert(s);
}
ACAM::Build();
printf("%d\n", ACAM::DP());
return 0;
}
写在最后
鸣谢: