初探数位dp
前言:这是蒟蒻第一次写算法系列,请诸位大佬原谅文笔与排版。
一、导入
在刷题的时候,我们有时会见到这样一类问题:在区间$[l,r]$内,共有多少个整数满足某种条件。如果$l$和$r$间的差很小,我们可以考虑暴力枚举直接判断。然而,若$l<=r<=10^9$甚至更大呢?
这时往往就可以用到一种dp方式:数位dp。
二、做法:
这里先放一道例题:1026: [SCOI2009]windy数。
题意:求在区间$[l,r]$内,满足相邻数位的差>=2的整数的个数。
首先我们可以发现,$[l,r]$的答案=$[1,r]$的答案-$[1,l)$的答案。于是我们可以把问题转化为求$[1,r]$和$[1,l-1]$的答案。因为$l<=r<=2*10^9$,所以暴力枚举肯定行不通。但是我们可以发现这道题中整数需满足的条件只与相邻数位有关,这启示我们,也许可以按位dp?
我们先来看一张经典的图(表示区间$[0,22]$):
这幅图中把正整数按位用树的形式表示,那么区间$[0,x]$(这里x=22)就可以拆成多棵满10叉树(即图中的蓝圈),而且因为每层所用的树个数不会超过10棵(0~9),总共有$\log_{10}{x}$层,则树的个数规模为$O(10\log{x})$。
那么单棵满10叉树的答案怎么求呢?我们仔细观察这棵树,那么就可以发现每棵满10叉树表示的数是位数相同(等于它从下往上数所处的层数),最高位相同(等于根节点表示的数),且该树的答案只与以树根的10个儿子为根的,10棵子树的答案有关。并且在整棵树中,处在同一层的,且子树根节点表示的数相同的树(即位数相同,最高位相同),它们是等价的。于是我们就可以直接设$f[i][j]$表示有i位,最高位为j的满足条件的整数的个数,然后xjb转移。于是就可以优哉游哉地dp, 然后按图统计答案了。
不过这道题还是比较麻烦,因为需要排除前导零的影响,不过核心思想还是上面的那样,然后再分位数统计就好了。
代码(时间复杂度$O(10^2\log{r})$):
#include<cstdio> #include<cmath> #include<cstdlib> #include<cstring> #include<ctime> #include<algorithm> #include<queue> #include<vector> #include<map> #define ll long long #define ull unsigned long long #define max(a,b) (a>b?a:b) #define min(a,b) (a<b?a:b) #define lowbit(x) (x& -x) #define mod 1000000007 #define inf 0x3f3f3f3f #define eps 1e-18 #define maxn 100020 inline ll read(){ll tmp=0; char c=getchar(),f=1; for(;c<'0'||'9'<c;c=getchar())if(c=='-')f=-1; for(;'0'<=c&&c<='9';c=getchar())tmp=(tmp<<3)+(tmp<<1)+c-'0'; return tmp*f;} inline ll power(ll a,ll b){ll ans=0; for(;b;b>>=1){if(b&1)ans=ans*a%mod; a=a*a%mod;} return ans;} inline ll gcd(ll a,ll b){return b?gcd(b,a%b):a;} inline void swap(int &a,int &b){int tmp=a; a=b; b=tmp;} using namespace std; int f[20][10]; int l,r; int dp(int x) { int num[20],len=0; for(;x;x/=10)num[++len]=x%10; for(int i=0;i<=9;i++)f[1][i]=1;//预处理 for(int i=2;i<=len;i++) for(int j=0;j<=9;j++){ f[i][j]=0; for(int k=0;k<=9;k++) if(abs(j-k)>=2) f[i][j]+=f[i-1][k]; }//处理f数组,f[i][j]表示有i位,最高位为j的的windy数个数*/ int ans=0; for(int i=1;i<len;i++) for(int j=1;j<=9;j++) ans+=f[i][j];//统计位数小于len的数一定小于n,直接加上 for(int i=len;i;i--){ int l=(i==len)?1:0,r=(i==1)?num[i]:num[i]-1;//不含前导零,所以最高位不能取0,从1开始枚举,否则从0开始 //除个位以外,因当前位若取num[i]可能超出1~n的范围,所以只能取到num[i]-1;因为询问1~n时包含n,所以个位的上限要取num[i] for(int j=l;j<=r;j++) if(i==len||abs(j-num[i+1])>=2)ans+=f[i][j]; if(i<len&&abs(num[i]-num[i+1])<2)break;//统计下一位时,这一位取的是num[i],若会和上一位num[i+1]发生冲突,则不可能出现windy数,直接break掉 } return ans; } int main() { l=read(); r=read(); printf("%d\n",dp(r)-dp(l-1)); }
三、归纳:
数位dp的特征:数据规模大,统计整数个数,被统计的数满足的条件往往与数位之间的关系或数位间的运算有关。
基本做法:差分,先按位dp出所需数据($f[i][S]$->i位数,状态为S),然后再拆分原区间,用dp出的数据统计。
四、其他例题:
1、【bzoj1833】[ZJOI2010] count 数字计数
裸的数位dp,分别计算每个数字出现的次数,做法和上面类似。
代码:
#include<cstdio> #include<cmath> #include<cstdlib> #include<cstring> #include<ctime> #include<iostream> #include<algorithm> #include<queue> #include<vector> #include<map> #define ll long long #define ull unsigned long long #define max(a,b) (a>b?a:b) #define min(a,b) (a<b?a:b) #define lowbit(x) (x& -x) #define mod 1000000007 #define inf 0x3f3f3f3f #define eps 1e-18 #define maxn 100010 inline ll read(){ll tmp=0; char c=getchar(),f=1; for(;c<'0'||'9'<c;c=getchar())if(c=='-')f=-1; for(;'0'<=c&&c<='9';c=getchar())tmp=(tmp<<3)+(tmp<<1)+c-'0'; return tmp*f;} inline ll power(ll a,ll b){ll ans=1; for(;b;b>>=1){if(b&1)ans=ans*a%mod; a=a*a%mod;} return ans;} inline ll gcd(ll a,ll b){return b?gcd(b,a%b):a;} inline void swap(int &a,int &b){int tmp=a; a=b; b=tmp;} using namespace std; ll f[20][10][10],base[20]; ll l,r; void prework() { base[0]=1; for(int i=1;i<=13;i++) base[i]=base[i-1]*10; for(int i=0;i<=9;i++) f[1][i][i]=1; for(int i=2;i<=13;i++) for(int j=0;j<=9;j++){ ll x=f[i-1][0][0]+f[i-1][0][1]*9; for(int k=0;k<=9;k++){ f[i][j][k]=(j==k?base[i-1]:0)+x; } } } ll solve(ll n,int num) { if(n<0)return 0; ll tmp=++n;//这里++n是为了把闭区间转化为开区间,因为下面求解时1~n的答案并不包括n。。 int a[20],len=0; for(;tmp;tmp/=10)a[++len]=tmp%10; for(int i=1;i<len;i++) for(int j=1;j<=9;j++) ans+=f[i][j][num]; for(int i=len;i;i--){ for(int j=(i==len?1:0);j<a[i];j++) ans+=f[i][j][num]; n-=a[i]*base[i-1]; if(a[i]==num)ans+=n; } return ans; } int main() { prework(); l=read(); r=read(); for(int i=0;i<9;i++) printf("%lld ",solve(r,i)-solve(l-1,i)); printf("%lld\n",solve(r,9)-solve(l-1,9)); }