HDU4507 吉哥系列故事——恨7不成妻 题解 数位DP

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4507

题目大意:
找到区间 \([L,R]\) 范围内所有满足如下条件的数的 平方和

  • 不包含‘7’;
  • 不能被 7 整除;
  • 各位之和不能被 7 整除。

注意:求的是满足条件的数的 平方和

解题思路:
使用 数位DP 尽情求解。
但是因为这里求的是满足要求的元素的平方和,而不是元素的个数。
所以我们不能简单地开long long来存放结果,
而是开一个结构体来存放结果,结构体中需要包含三个元素:

  • 满足要求的元素个数(用 cnt 表示);
  • 满足要求的元素的和(用 sum 表示);
  • 满足要求的元素的平方和(用 sum2 表示)

定义状态 \(f[pos][pre][ts]\) 表示:

  • 当前在第 \(pos\) 位;
  • 前一位(即第 \(pos+1\) 位)放置的数 mod 7为 \(pre\)
  • 前面所有位上的数的和 mod 7 的结果为 \(ts\)

时对应的信息(包括元素个数、和、平方和)。

那么如何获得状态转移方程呢?

每次查找的数字的前几位都是一样的,比如对于三位数:

\[abc \]

假设第一位枚举了2;即

\[2|b|c \]

那么之后枚举到的数字就是 \(200+x\)

如果 \(200\) - \(300\) 之间只有 \(231\)\(230\)\(233\) 满足条件,那么 \(200\) - \(300\) 之间的数的平方和就为

\[230^2+231^2+233^2 \]

也就等于

\[(200+30)^2+(200+31)^2+(200+33)^2 \]

展开后得到:

\[3 \times 200^2+2 \times 200 \times (30+31+33)+(30^2+31^2+3^32) \]

并且我们可以据此得到状态转移了。

(上述思路来自 Ender的博客 ,不过他的公式推到错了,多乘了一个3,所以还是看我的计算公式即可)

实现的代码如下:

#include <bits/stdc++.h>
using namespace std;
const long long MOD = 1000000007LL;
struct Node {
    long long cnt;    // 数量
    long long sum;    // 和
    long long sum2;   // 平方和
    Node () {};
    Node (long long _cnt, long long _sum, long long _sum2) {
        cnt = _cnt; sum = _sum; sum2 = _sum2;
    }
} f[22][7][7];
int a[22];
bool vis[22][7][7];
long long pow10[22];
void init() {
    memset(vis, 0, sizeof(vis));
    pow10[0] = 1;
    for (int i = 1; i < 22; i ++) pow10[i] = pow10[i-1] * 10 % MOD;
}
Node dfs(int pos, int pre, int ts, bool limit) {
    if (pos < 0) {
        int cnt = 1;
        if (pre == 0) cnt = 0;  // 不能被7整除
        if (ts == 0) cnt = 0;   // 所有数位的和不能被7整除
        return Node(cnt, 0, 0);
    }
    if (!limit && vis[pos][pre][ts]) return f[pos][pre][ts];
    int up = limit ? a[pos] : 9;
    Node res = Node(0, 0, 0);
    for (int i = 0; i <= up; i ++) {
        if (i == 7) continue;   // 不能包含7
        Node tmp = dfs(pos-1, (pre*10+i)%7, (ts+i)%7, limit && i==up);
        long long tmp_cnt = tmp.cnt;
        long long tmp_sum = tmp.sum;
        long long tmp_sum2 = tmp.sum2;
        long long t = pow10[pos] * i % MOD;
        long long now_sum = (t * tmp_cnt + tmp_sum) % MOD;
        long long now_sum2 = (tmp_cnt * t % MOD * t % MOD + 2LL * t % MOD * tmp_sum % MOD + tmp_sum2) % MOD;
        res.cnt = (res.cnt + tmp_cnt) % MOD;
        res.sum = (res.sum + now_sum) % MOD;
        res.sum2 = (res.sum2 + now_sum2) % MOD;
    }
    if (!limit && !vis[pos][pre][ts]) {
        vis[pos][pre][ts] = true;
        f[pos][pre][ts] = res;
    }
    return res;
}
long long get_num(long long x) {
    int pos = 0;
    while (x) {
        a[pos++] = x % 10;
        x /= 10;
    }
    return dfs(pos-1, 0, 0, true).sum2;
}
int T;
long long L, R;
int main() {
    init();
    scanf("%d", &T);
    while (T --) {
        scanf("%lld%lld", &L, &R);
        printf("%lld\n", (get_num(R) - get_num(L-1) + MOD) % MOD);
    }
    return 0;
}
posted @ 2019-12-03 14:04  quanjun  阅读(181)  评论(0编辑  收藏  举报