P2359 三素数数 , 线性dp
题目背景
蛟川书院的一道练习题QAQ
题目描述
如果一个数的所有连续三位数字都是大于100的素数,则该数称为三素数数。比如113797是一个6位的三素数数。
输入格式
一个整数n(3 ≤ n ≤ 10000),表示三素数数的位数。
输出格式
一个整数,表示n位三素数的个数m,要求输出m除以10^9 + 9的余数。
输入输出样例
输入 #1复制
4
输出 #1复制
204
说明/提示
区域动归QAQ
解析:
第一次的错误做法:
f[i] 表示前 i 为的三素数的个数,f[i]=f[i-3]*t+f[i-2]+f[i-1], t 表示 1 到 1e3 内的素数的个数
这个做法是错误的,题目的意思应该是任意三个连续的数组成的三位数一定是素数,上述的做法只考虑了当前连续的三个数,而非任意任意三个连续的数,所以上述做法是错误的
正确的做法:
最容易,最直接的划分方式:f[i][j][k][l] 表示前 i 位,最近的三位数,百位为 j ,十位为 k,个位为 l 的三素数个的个数
状态转移方程:f[i][j][k][l]=(f[i-1][k][l][p]+f[i][j][k][l])%mod;
初始化 f[3][j][k][l]=1;
时间复杂度为O(1e3*n),最坏情况为 1e8
优化:
我们可以发现上述划分集合的最后一维是可以省去的:
f[i][j][k] 表示: 最近的三位数,百位为 j ,十位为 k,个位为 l 的三素数个的个数,这里 l 省去了,
f[i][j][k]=(f[i][j][k]+f[i-1][k][l])%mod;
初始化可以不改变,也可以改为:f[2][j][k]=1;
优化前的代码:
#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<ctime>
#include<algorithm>
#include<utility>
#include<stack>
#include<queue>
#include<vector>
#include<set>
#include<math.h>
#include<map>
using namespace std;
typedef long long LL;
const int N = 1e4 + 3, M = 1e3,mod=1e9+9;
int n;
LL f[N][11][11][11];
int an[M];
vector<int>prime;
void init() {
an[1] = 1;
for (int i = 2; i < M; i++) {
if (an[i] == 0) {
prime.push_back(i);
}
for (int j = 0; j < prime.size() && prime[j] * i < M; j++) {
an[prime[j] * i] = 1;
}
}
}
int main() {
init();
cin >> n;
for (int j = 1; j <= 9; j++) {
for (int k = 0; k <= 9; k++) {
for (int l = 0; l <= 9; l++) {
f[3][j][k][l] = 1;
}
}
}
int t=0,tt=0;
for (int i = 4; i <= n; i++) {
for (int j = 1; j <= 9; j++) {
for (int k = 0; k <= 9; k++) {
for (int l = 0; l <= 9; l++) {
for (int p = 0; p <= 9; p++) {
t = j * 100 + k * 10 + l;
tt = k * 100 + l * 10 + p;
if (!an[t]&&!an[tt]) {
f[i][j][k][l] = (f[i - 1][k][l][p] + f[i][j][k][l])%mod;
}
}
}
}
}
}
LL ans = 0;
for (int j = 0; j <= 9; j++) {
for (int k = 0; k <= 9; k++) {
for (int l = 1; l <= 9; l++) {
t = j * 100 + k * 10 + l;
if (!an[t])
ans = (ans + f[n][j][k][l]) % mod;
}
}
}
cout << ans << endl;
return 0;
}
优化后的代码
#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<ctime>
#include<algorithm>
#include<utility>
#include<stack>
#include<queue>
#include<vector>
#include<set>
#include<math.h>
#include<map>
using namespace std;
typedef long long LL;
const int N = 1e4 + 3, M = 1e3, mod = 1e9 + 9;
int n;
LL f[N][11][11];
int an[M];
vector<int>prime;
void init() {
an[1] = 1;
for (int i = 2; i < M; i++) {
if (an[i] == 0) {
prime.push_back(i);
}
for (int j = 0; j < prime.size() && prime[j] * i < M; j++) {
an[prime[j] * i] = 1;
}
}
}
int main() {
init();
cin >> n;
for (int j = 0; j <= 9; j++) {
for (int k = 0; k <= 9; k++) {
f[2][j][k] = 1;
}
}
int t = 0, tt = 0;
for (int i = 3; i <= n; i++) {
for (int j = 1; j <= 9; j++) {
for (int k = 0; k <= 9; k++) {
for (int l = 0; l <= 9; l++) {
t = j * 100 + k * 10 + l;
if (!an[t]) {
f[i][j][k] = (f[i - 1][k][l] + f[i][j][k]) % mod;
}
}
}
}
}
LL ans = 0;
for (int j = 0; j <= 9; j++) {
for (int k = 0; k <= 9; k++) {
ans = (ans + f[n][j][k]) % mod;
}
}
cout << ans << endl;
return 0;
}