HDU4507 吉哥系列故事——恨7不成妻 题解 数位DP
题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4507
题目大意:
找到区间 \([L,R]\) 范围内所有满足如下条件的数的 平方和 :
- 不包含‘7’;
- 不能被 7 整除;
- 各位之和不能被 7 整除。
注意:求的是满足条件的数的 平方和 !
解题思路:
使用 数位DP 尽情求解。
但是因为这里求的是满足要求的元素的平方和,而不是元素的个数。
所以我们不能简单地开long long来存放结果,
而是开一个结构体来存放结果,结构体中需要包含三个元素:
- 满足要求的元素个数(用
cnt
表示); - 满足要求的元素的和(用
sum
表示); - 满足要求的元素的平方和(用
sum2
表示)
定义状态 \(f[pos][pre][ts]\) 表示:
- 当前在第 \(pos\) 位;
- 前一位(即第 \(pos+1\) 位)放置的数 mod 7为 \(pre\)
- 前面所有位上的数的和 mod 7 的结果为 \(ts\)
时对应的信息(包括元素个数、和、平方和)。
那么如何获得状态转移方程呢?
每次查找的数字的前几位都是一样的,比如对于三位数:
\[abc
\]
假设第一位枚举了2;即
\[2|b|c
\]
那么之后枚举到的数字就是 \(200+x\) 。
如果 \(200\) - \(300\) 之间只有 \(231\) ,\(230\) 和 \(233\) 满足条件,那么 \(200\) - \(300\) 之间的数的平方和就为
\[230^2+231^2+233^2
\]
也就等于
\[(200+30)^2+(200+31)^2+(200+33)^2
\]
展开后得到:
\[3 \times 200^2+2 \times 200 \times (30+31+33)+(30^2+31^2+3^32)
\]
并且我们可以据此得到状态转移了。
(上述思路来自 Ender的博客 ,不过他的公式推到错了,多乘了一个3,所以还是看我的计算公式即可)
实现的代码如下:
#include <bits/stdc++.h>
using namespace std;
const long long MOD = 1000000007LL;
struct Node {
long long cnt; // 数量
long long sum; // 和
long long sum2; // 平方和
Node () {};
Node (long long _cnt, long long _sum, long long _sum2) {
cnt = _cnt; sum = _sum; sum2 = _sum2;
}
} f[22][7][7];
int a[22];
bool vis[22][7][7];
long long pow10[22];
void init() {
memset(vis, 0, sizeof(vis));
pow10[0] = 1;
for (int i = 1; i < 22; i ++) pow10[i] = pow10[i-1] * 10 % MOD;
}
Node dfs(int pos, int pre, int ts, bool limit) {
if (pos < 0) {
int cnt = 1;
if (pre == 0) cnt = 0; // 不能被7整除
if (ts == 0) cnt = 0; // 所有数位的和不能被7整除
return Node(cnt, 0, 0);
}
if (!limit && vis[pos][pre][ts]) return f[pos][pre][ts];
int up = limit ? a[pos] : 9;
Node res = Node(0, 0, 0);
for (int i = 0; i <= up; i ++) {
if (i == 7) continue; // 不能包含7
Node tmp = dfs(pos-1, (pre*10+i)%7, (ts+i)%7, limit && i==up);
long long tmp_cnt = tmp.cnt;
long long tmp_sum = tmp.sum;
long long tmp_sum2 = tmp.sum2;
long long t = pow10[pos] * i % MOD;
long long now_sum = (t * tmp_cnt + tmp_sum) % MOD;
long long now_sum2 = (tmp_cnt * t % MOD * t % MOD + 2LL * t % MOD * tmp_sum % MOD + tmp_sum2) % MOD;
res.cnt = (res.cnt + tmp_cnt) % MOD;
res.sum = (res.sum + now_sum) % MOD;
res.sum2 = (res.sum2 + now_sum2) % MOD;
}
if (!limit && !vis[pos][pre][ts]) {
vis[pos][pre][ts] = true;
f[pos][pre][ts] = res;
}
return res;
}
long long get_num(long long x) {
int pos = 0;
while (x) {
a[pos++] = x % 10;
x /= 10;
}
return dfs(pos-1, 0, 0, true).sum2;
}
int T;
long long L, R;
int main() {
init();
scanf("%d", &T);
while (T --) {
scanf("%lld%lld", &L, &R);
printf("%lld\n", (get_num(R) - get_num(L-1) + MOD) % MOD);
}
return 0;
}