数位DP
概述
数位DP是一类与数字枚举有关的DP问题,往往无法暴力破解。
通常这类问题是要统计某个区间内符合一些性质的数,具有如下的特征:
- 要统计区间[l, r]内满足要求的数,往往可以转换成求[0,r] - [0,l)
- 对于求区间[0,n)有通用的方法
- 对于一个小于n的数,从高位到低位一定会出现有一位数字小于n的这一位数字
数字枚举存在大量的重复性,使用DP的方法可以减少很多计算。
我们使用dp[len][st]
数组来存储中间的结果。第一位len表示当前处理的位数,通常从高位到低位遍历。第二位st表示状态空间。状态空间的设计是数位DP的难点,可以不止是二维。
一般步骤
首先是初始化,包括dp数组的预处理、每一位极限值的处理等。
一般dp数组的处理可以直接重置成-1,而极限值则是n的各个数字,保证枚举时只会枚举比n小的数字。
有时根据题目也会做更多的预处理工作。
完成初始化之后便是dfs递归求解,枚举数的每一位,按位DP。一个通用的模板如下:
int dfs(int pos, int s, int e)
{
if(!pos) return s == target;
if(!e && f[pos][s] != -1)
return f[pos][s];
int ans = 0;
int u = e ? bit[pos] : 9;
for(int i = (first ? 1 : 0); i <= u; i++)
{
ans += dfs(pos - 1, new_s(s,i), e && i == u);
}
return e ? ans : f[pos][s] = ans;
}
这里我们假设初始化dp数组为-1,bit数组记录n的每一位数字。pos表示枚举到第几位了(从最高位向最低位枚举),s表示传递的状态值,e则是极限标识。
有几个注意的地方:
- line 3: pos 为 0 时表示枚举完了所有的位数,此时判断是否符合条件
- line 4: 中途遇到先前计算的结果时,返回结果
- line 7~8: 当前位的计算值与枚举上限
- line 9~12: 递归枚举下一位数字,同时需要更新s与e的值
- line 14: 根据e的值选择是否保存当前值,并返回这一步的计算结果ans
数位DP的难点在于状态空间的设计。DP的初衷在于减少重复的计算,状态空间便是用来保存中间值的,当我们觉得有些问题用现有的设计说不清楚的时候,便可以考虑增加状态,以细分状态空间。
Example 1:Amount of Degrees
题目:求给定区间[X, Y]
中满足下列条件的整数个数:这个数恰好等于K个互不相等的B的整数次幂的和。其中 2 <= B <= 10。
如指定[15, 20],K=2,B=2,有三个数满足:
17=24+20 18=24+21 20=24+22因此输出 3
这个例子来自于刘聪的《浅谈数位类统计问题》,特点在于额外的初始化工作。
思路:所求的数为互不相等的幂的和,那么其B进制表示的数字各位都只是0或者1,我们只需考虑二进制的情况,其他进制可以转化为二进制求解。
以 n=13,K=3为例,即是求出0到13(1101)中二进制表示含3个1的数的个数。为了方便思考,我们先画一棵完全二叉树:
树根用0表示,这颗二叉树便可以表示4位的所有二进制数,橙色路径是上限13,而小于n的数组成了三棵子树,统计小于13的数只需要统计这三棵子树。
对于一棵高度为i的完全二叉树,统计二进制表示含有j个1的树可以递推求出。每下一层,相对于上一层的数要么添0要么添1:
- 初始时有:
dp[0][0] = 1, dp[0][1] = 0
- 第i层至多有i个1
- 第i层j个1的计数为 i−1 层j个1添0或者 i−1 层 j−1 个1添1
接下来是进制转换的问题,对于询问n,我们需要找到不超过n的最大B进制表示只含0、1的数(同一个幂次只计算一次):找到n左起第一位非0、1的数位,将其变为1,并将右侧所有数位设置为1,就可以将B进制表示视为二进制查询。
- 如 1028(10),变换为1011,是最接近的不同的10的幂次的和
- 如20 (3),对应10进制的6,变换为11
- 如202 (3), 对应10进制的20,变换为101
于是有下面的代码:(默认输入合法)
#include <cstdio>
#include <cstring>
#include <string>
#include <iostream>
using namespace std;
int dp[32][32];
void init()
{
memset(dp, 0, sizeof(dp));
dp[0][0] = 1;
for(int i = 1; i < 32; i++)
{
dp[i][0] = dp[i - 1][0];
for(int j = 1; j <= i; j++)
{
dp[i][j] = dp[i - 1][j] + dp[i - 1][j - 1];
}
}
}
int calc(int x, int k)
{
if(x <= 0) return 0;
int tot = 0, ans = 0;
for(int i = 31; i > 0; i--)
{
if(x & (1<<i)) {="" tot="" ++;="" if(tot=""> k)
break;
x = x ^ (1 << i); // 将这一位置0
}
if((1 << (i - 1)) <= x)
{
ans += dp[i - 1][k - tot];
}
}
if(tot + x == k)
ans ++;
return ans;
}
int change(string s, int k, int sign)
{
if(sign)
{
int last = s.length() - 1;
for(; last > 0; last --)
{
if(s[last] > '0')
{
s[last] -= 1;
break;
}
else s[last] += k - 1;
}
}
if(k != 2)
{
int pos = 0;
while(s[pos] == '0' || s[pos] == '1')
pos ++;
for(; pos < s.length(); pos++)
s[pos] = '1';
}
return stoi(s, 0, 2); //返回成十进制的数以与calc函数对接
}
int main()
{
string a, b;
int k, B;
init();
// 输入B进制表示的a、b, 要求k个B的不同幂次
while(cin >> a >> b >> k >> B)
{
cout << calc(change(b, B, 0), k) - calc(change(a, B, 1), k) << endl;
}
return 0;
}
Example 2:hihocoder 1033 交错和
题目:给定一个数 x,设它十进制从高位到低位上的数位依次是 a0, a1, ..., an − 1,定义交错和函数:
f(x)=a0−a1+a2−a3+…+(−1)n−1an−1
例如:
f(3214567)=3−2+1−4+5−6+7=4
给定区间 [l,r] 和 k,0<=l<=r<=1018,|k|<=100,求区间内所有f(x)=k的数的和 (模1000000007)。
思路:这一道题目的状态空间应该和k的值有关,18位极限情况下k最多可以有81,考虑负数,dp[len][sum]
第二维开200即可。状态空间即是当前遍历到第len位时,已有的和为sum时的状态。
问题1:存在前导0的问题。
比如枚举01、001虽然都是1,但是计算结果却会是-1和+1,因此交错符号的更迭应该从第一个非零位开始。这样的话枚举状态就不止于len有关,还与当前的有效位有关。于是我们可以设计状态空间dp[len][len0][sum]
,来保存遍历到第len位,有效位为len0时和为sum的状态。
问题2:每个状态的值是什么?
最终我们要求和为k的数的和,最容易想到的思路是在状态里直接存和。然而由于题目的性质,我们枚举下一位时只考虑新的这位对总和的影响,而当前的总和可能有多种方式得到,比如123____
和402____
,它们的当前和都是6,下一位都该枚举第四位(最高位为第七位)。
对于当前sum和一样的情况,后续的处理是等效的,可以利用DP减少计算。最后计算时,是从最低位逐步上升到最高位的,因此这里可以继承的是一定sum值时后续的情况。
就像是把茫茫多的大数按位拆分了,这些大数直接求和的结果必然和按位分组求和的和一样。而从最低位逐步分组求和时便可以节省同样sum值的计算,直接获取先前的结果。
可以看到,基于DP的思路,是不会具体到枚举每一个数的,也就无法直接得到这些数的和,因为每一步汇聚都是汇聚的拥有公共前段sum和的数的后段。
于是我们将dp设计为结构体,存储当前的后段和与符合这种后段的数的个数。
最终代码如下:
#include <stdio.h>
#include <string.h>
#include <algorithm>
const long long mod = 1000000007;
const int maxlen = 20;
const int radiu = 100;
struct node
{
long long s, n;
} dp[maxlen][maxlen][radiu << 1];
long long a, b, base[maxlen];
int bit[maxlen], k;
node dfs(int len, int len0, int sum, bool zero, bool fp)
{
node ans = {0, 0}, temp;
if(len < 0)
{
if(sum == k) ans.n = 1;
return ans;
}
if(!zero && !fp && dp[len][len0][sum + radiu].n != -1)
return dp[len][len0][sum + radiu];
int fpmax = fp ? bit[len] : 9;
for(int i = 0; i <= fpmax; i++)
{
if(zero)
{
if(!i)
temp = dfs(len - 1, 0, 0, true, fp && (i == fpmax));
else
temp = dfs(len - 1, 1, i, false, fp && (i == fpmax));
}
else
{
temp = dfs(len - 1, len0 + 1, sum + (len0 & 1 ? -i : i), false, fp && (i == fpmax));
}
ans.n = (ans.n + temp.n) % mod;
ans.s = (ans.s + temp.s + temp.n * i * base[len] % mod) % mod;
}
if(!fp && !zero)
dp[len][len0][sum + radiu] = ans;
return ans;
}
long long solve(long long n)
{
if(n <= 0) return 0;
int len = 0;
while(n)
{
bit[len ++] = n % 10;
n /= 10;
}
return dfs(len - 1, 0, 0, true, true).s;
}
int main()
{
base[0] = 1;
for(int i = 1; i < maxlen; i++)
base[i] = base[i - 1] * 10 % mod;
memset(dp, -1, sizeof(dp));
while(scanf("%lld %lld %d", &a, &b, &k) != EOF)
{
printf("%lld\n", (solve(b) - solve(a - 1) + mod) % mod);
}
return 0;
}
Example 3:神奇的数
题目:给定区间[l,r],0<=l<=r<=1018,求区间内同时满足下列条件的数的个数:
- 至少含有2、3、5中的一个数字
- 不包含
18
- 能够被7整除
这道题目来自 网易互娱2017实习生招聘游戏研发工程师在线笔试第二场
,由于没有加入题库,无法测试是否AC,所以下面的讨论和代码可能有误。
思路:三个条件相对独立,每个条件单独的都有数位DP的例题,所以这里的处理方式是针对不同的条件设置了不同的状态空间,每个条件的判定相对容易。
其余的步骤比较典型,没有什么特别的处理。
代码:
#include <stdio.h>
#include <string.h>
#include <algorithm>
using namespace std;
int dp[20][2][2][7];
int digit[20];
long long dfs(int len,bool state, int has, int sum, bool fp)
{
if(!len)
{
if(has == 1 && sum == 0)
{
return 1;
}
else return 0;
}
if(!fp && dp[len][state][has][sum] != -1)
{
return dp[len][state][has][sum];
}
long long ret = 0 , fpmax = fp ? digit[len] : 9;
for(int i = 0; i <= fpmax; i++)
{
// 不含18
if(state && i == 8)
continue;
//含有2,3,5
int prehas = has;
int presum = sum;
sum *= 10;
sum += i;
sum %= 7;
if(!has && (i == 2 || i == 3 || i == 5))
{
has = 1;
}
ret += dfs(len - 1,i == 1, has, sum, fp && i == fpmax);
has = prehas;
sum = presum;
}
if(!fp)
{
dp[len][state][has][sum] = ret;
}
return ret;
}
long long f(long long n)
{
long long len = 0;
while(n)
{
digit[++len] = n % 10;
n /= 10;
}
return dfs(len,false, 0, 0, true);
}
int main()
{
long long a,b;
memset(dp, -1, sizeof(dp));
while(scanf("%lld %lld", &a, &b) != EOF)
{
printf("%lld\n", f(b) - f(a - 1));
}
return 0;
}
https://agatelee.cn/2016/04/%E6%95%B0%E4%BD%8Ddp/