hdoj4507(数位dp)

题目链接:https://vjudge.net/problem/HDU-4507

题意:定义如果一个整数符合下面3个条件之一,那么我们就说这个整数和7有关——
  1、整数中某一位是7;
  2、整数的每一位加起来的和是7的整数倍;
  3、这个整数是7的整数倍;

  给定l,r,求[l,r] 区间与7无关的数的平方和。

思路:这3条定义都是常规的数位dp,但题目求的并不是与7无关的数的个数,而是平方和,这也是该题的难点。这里需要数位dp维护3个值:

  1. 与7无关的数的个数num;

  2. 与7无关的数的和sum;

  3. 与7无关的数的平方和sum;

  (用结构体组织上述3个属性,假设当前求的结点是ans,当前点取i,递归得到的结点为tmp)

  第1条很简单,就是常规的数位dp:

    ans.num+=tmp.num;
    ans.num%=Mod;

  第2条需要用到第一条,求当前的所有数的和,即后面位的所有和+当前值×后面的数的个数:

    ans.sum+=(tmp.sum+(i*key[pos])%Mod*tmp.num%Mod)%Mod;
    ans.sum%=Mod;

  第3条需要用到前2条,求当前所有数的平方和。先考虑一个数,设该数后面的位数的值为b,当前要加的值为a(a=i*key[pos]),则该数平方和为(a+b)^2=a^2+2*a*b+b^2,然后考虑对所有数的平方和,即上述式子求和tmp.num次,上述式子有3项。第一项:即a^2*tmp.num。第二项:即2*a*tmp.sum。第三项:即tmp.sqsum。

        ans.sqsum+=tmp.num*i%Mod*i%Mod*key[pos]%Mod*key[pos]%Mod;
        ans.sqsum%=Mod;
        ans.sqsum+=2*i*key[pos]%Mod*tmp.sum%Mod;
        ans.sqsum%=Mod;
        ans.sqsum+=tmp.sqsum;
        ans.sqsum%=Mod;    

 

还有要注意的是这道题的数据,因为num,sum,sqsum还有key[i]都可能超过Mod,所以每乘一次就要%Mod,不然会出现乘法溢出。

AC代码:

#include <cstdio>
using namespace std;
typedef long long LL;
const LL Mod=1000000007;

struct node{
    LL num,sum,sqsum;
}dp[20][10][10];

int T,a[20];
LL key[20];

node dfs(int pos,int pre1,int pre2,bool limit){
    if(pos==-1){
        node tmp;
        tmp.num=(pre1!=0&&pre2!=0);
        tmp.sum=tmp.sqsum=0;
        return tmp;
    }
    if(!limit&&dp[pos][pre1][pre2].num!=-1)
        return dp[pos][pre1][pre2];
    int up=limit?a[pos]:9;
    node ans;
    ans.num=ans.sum=ans.sqsum=0;
    for(int i=0;i<=up;++i){
        if(i==7) continue;
        node tmp=dfs(pos-1,(pre1+i)%7,(pre2*10+i)%7,limit&&i==a[pos]);
        
        ans.num+=tmp.num;
        ans.num%=Mod;

        ans.sum+=(tmp.sum+(i*key[pos])%Mod*tmp.num%Mod)%Mod;
        ans.sum%=Mod;

        ans.sqsum+=tmp.num*i%Mod*i%Mod*key[pos]%Mod*key[pos]%Mod;
        ans.sqsum%=Mod;
        ans.sqsum+=2*i*key[pos]%Mod*tmp.sum%Mod;
        ans.sqsum%=Mod;
        ans.sqsum+=tmp.sqsum;
        ans.sqsum%=Mod;
    }
    if(!limit) dp[pos][pre1][pre2]=ans;
    return ans;
}

LL solve(LL x){
    int pos=0;
    while(x){
        a[pos++]=x%10;
        x/=10;
    }
    return dfs(pos-1,0,0,true).sqsum;
}

int main()
{
    key[0]=1;
    for(int i=1;i<=18;++i)
        key[i]=(key[i-1]*10)%Mod;
    for(int i=0;i<20;++i)
        for(int j=0;j<10;++j)
            for(int k=0;k<10;++k)
                dp[i][j][k].num=-1;
    scanf("%d",&T);
    while(T--){
        LL l,r,ans;
        scanf("%lld%lld",&l,&r);
        ans=solve(r)-solve(l-1);
        ans=(ans%Mod+Mod)%Mod;
        printf("%lld\n",ans);
    }
    return 0;
}

 

posted @ 2019-05-08 22:02  Frank__Chen  阅读(187)  评论(0编辑  收藏  举报