初探数位dp

  前言:这是蒟蒻第一次写算法系列,请诸位大佬原谅文笔与排版。

 


 

一、导入

  在刷题的时候,我们有时会见到这样一类问题:在区间$[l,r]$内,共有多少个整数满足某种条件。如果$l$和$r$间的差很小,我们可以考虑暴力枚举直接判断。然而,若$l<=r<=10^9$甚至更大呢?

  这时往往就可以用到一种dp方式:数位dp。


 

二、做法:

  这里先放一道例题:1026: [SCOI2009]windy数

  题意:求在区间$[l,r]$内,满足相邻数位的差>=2的整数的个数。

  首先我们可以发现,$[l,r]$的答案=$[1,r]$的答案-$[1,l)$的答案。于是我们可以把问题转化为求$[1,r]$和$[1,l-1]$的答案。因为$l<=r<=2*10^9$,所以暴力枚举肯定行不通。但是我们可以发现这道题中整数需满足的条件只与相邻数位有关,这启示我们,也许可以按位dp?

  我们先来看一张经典的图(表示区间$[0,22]$):

  

 

  这幅图中把正整数按位用树的形式表示,那么区间$[0,x]$(这里x=22)就可以拆成多棵满10叉树(即图中的蓝圈),而且因为每层所用的树个数不会超过10棵(0~9),总共有$\log_{10}{x}$层,则树的个数规模为$O(10\log{x})$。

  那么单棵满10叉树的答案怎么求呢?我们仔细观察这棵树,那么就可以发现每棵满10叉树表示的数是位数相同(等于它从下往上数所处的层数),最高位相同(等于根节点表示的数),且该树的答案只与以树根的10个儿子为根的,10棵子树的答案有关。并且在整棵树中,处在同一层的,且子树根节点表示的数相同的树(即位数相同,最高位相同),它们是等价的。于是我们就可以直接设$f[i][j]$表示有i位,最高位为j的满足条件的整数的个数,然后xjb转移。于是就可以优哉游哉地dp, 然后按图统计答案了。

  不过这道题还是比较麻烦,因为需要排除前导零的影响,不过核心思想还是上面的那样,然后再分位数统计就好了。

  代码(时间复杂度$O(10^2\log{r})$):

#include<cstdio>
#include<cmath>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<algorithm>
#include<queue>
#include<vector>
#include<map>
#define ll long long
#define ull unsigned long long
#define max(a,b) (a>b?a:b)
#define min(a,b) (a<b?a:b)
#define lowbit(x) (x& -x)
#define mod 1000000007
#define inf 0x3f3f3f3f
#define eps 1e-18
#define maxn 100020
inline ll read(){ll tmp=0; char c=getchar(),f=1; for(;c<'0'||'9'<c;c=getchar())if(c=='-')f=-1; for(;'0'<=c&&c<='9';c=getchar())tmp=(tmp<<3)+(tmp<<1)+c-'0'; return tmp*f;}
inline ll power(ll a,ll b){ll ans=0; for(;b;b>>=1){if(b&1)ans=ans*a%mod; a=a*a%mod;} return ans;}
inline ll gcd(ll a,ll b){return b?gcd(b,a%b):a;}
inline void swap(int &a,int &b){int tmp=a; a=b; b=tmp;}
using namespace std;
int f[20][10];
int l,r;
int dp(int x)
{
    int num[20],len=0;
    for(;x;x/=10)num[++len]=x%10;
    for(int i=0;i<=9;i++)f[1][i]=1;//预处理 
    for(int i=2;i<=len;i++)
        for(int j=0;j<=9;j++){
            f[i][j]=0;
            for(int k=0;k<=9;k++)
                if(abs(j-k)>=2)
                    f[i][j]+=f[i-1][k];
        }//处理f数组,f[i][j]表示有i位,最高位为j的的windy数个数*/ 
    int ans=0;
    for(int i=1;i<len;i++)
        for(int j=1;j<=9;j++)
            ans+=f[i][j];//统计位数小于len的数一定小于n,直接加上 
    for(int i=len;i;i--){
        int l=(i==len)?1:0,r=(i==1)?num[i]:num[i]-1;//不含前导零,所以最高位不能取0,从1开始枚举,否则从0开始
        //除个位以外,因当前位若取num[i]可能超出1~n的范围,所以只能取到num[i]-1;因为询问1~n时包含n,所以个位的上限要取num[i]
        for(int j=l;j<=r;j++)
            if(i==len||abs(j-num[i+1])>=2)ans+=f[i][j];
        if(i<len&&abs(num[i]-num[i+1])<2)break;//统计下一位时,这一位取的是num[i],若会和上一位num[i+1]发生冲突,则不可能出现windy数,直接break掉 
    }
    return ans;
}
int main()
{
    l=read(); r=read();
    printf("%d\n",dp(r)-dp(l-1));
}
bzoj1026

 


 

三、归纳:

  数位dp的特征:数据规模大,统计整数个数,被统计的数满足的条件往往与数位之间的关系或数位间的运算有关。

  基本做法:差分,先按位dp出所需数据($f[i][S]$->i位数,状态为S),然后再拆分原区间,用dp出的数据统计。

 


 

四、其他例题:

1、【bzoj1833】[ZJOI2010] count 数字计数

  裸的数位dp,分别计算每个数字出现的次数,做法和上面类似。

  代码:

#include<cstdio>
#include<cmath>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<iostream> 
#include<algorithm>
#include<queue>
#include<vector>
#include<map>
#define ll long long
#define ull unsigned long long
#define max(a,b) (a>b?a:b)
#define min(a,b) (a<b?a:b)
#define lowbit(x) (x& -x)
#define mod 1000000007
#define inf 0x3f3f3f3f
#define eps 1e-18
#define maxn 100010
inline ll read(){ll tmp=0; char c=getchar(),f=1; for(;c<'0'||'9'<c;c=getchar())if(c=='-')f=-1; for(;'0'<=c&&c<='9';c=getchar())tmp=(tmp<<3)+(tmp<<1)+c-'0'; return tmp*f;}
inline ll power(ll a,ll b){ll ans=1; for(;b;b>>=1){if(b&1)ans=ans*a%mod; a=a*a%mod;} return ans;}
inline ll gcd(ll a,ll b){return b?gcd(b,a%b):a;}
inline void swap(int &a,int &b){int tmp=a; a=b; b=tmp;}
using namespace std;
ll f[20][10][10],base[20];
ll l,r;
void prework()
{
    base[0]=1;
    for(int i=1;i<=13;i++)
        base[i]=base[i-1]*10;
    for(int i=0;i<=9;i++)
        f[1][i][i]=1;
    for(int i=2;i<=13;i++)
        for(int j=0;j<=9;j++){
            ll x=f[i-1][0][0]+f[i-1][0][1]*9;
            for(int k=0;k<=9;k++){
                f[i][j][k]=(j==k?base[i-1]:0)+x;
            }
        }
}
ll solve(ll n,int num)
{
    if(n<0)return 0;
    ll tmp=++n;//这里++n是为了把闭区间转化为开区间,因为下面求解时1~n的答案并不包括n。。
    int a[20],len=0;
    for(;tmp;tmp/=10)a[++len]=tmp%10;
    for(int i=1;i<len;i++)
        for(int j=1;j<=9;j++)
            ans+=f[i][j][num];
    for(int i=len;i;i--){
        for(int j=(i==len?1:0);j<a[i];j++)
            ans+=f[i][j][num];
        n-=a[i]*base[i-1];
        if(a[i]==num)ans+=n;
    }
    return ans;
}
int main()
{
    prework();
    l=read(); r=read();
    for(int i=0;i<9;i++)
        printf("%lld ",solve(r,i)-solve(l-1,i));
    printf("%lld\n",solve(r,9)-solve(l-1,9));
}
bzoj1833

 

posted @ 2018-06-30 21:28  QuartZ_Z  阅读(271)  评论(5编辑  收藏  举报