kuangbin专题十五:数位DP
CodeForces55D Beautiful numbers
思路:经典题。1至9的最小公倍数为2520,因此只需记录前缀数对2520取模的余数,以及前缀中出现过的非零数字的最小公倍数;枚举完所有数位后,判断该余数能否被这个最小公倍数整除。
#include <iostream>
#include <vector>
#include <cstring>
using namespace std;

typedef long long ll;

// CF 55D: count x in [l, r] that are divisible by every one of their nonzero digits.
// Every digit 1..9 divides 2520, so the prefix value only needs to be tracked modulo
// 2520, together with the LCM of the nonzero digits placed so far (a divisor of 2520).
const int LCM = 2520;

vector<int> digits;          // digits[1] is the least-significant digit; digits[0] is a sentinel
int lcmIndex[LCM + 10];      // compresses each divisor of 2520 into a small table index
ll memo[20][LCM + 10][50];   // memo[pos][prefix % 2520][lcmIndex[lcm]]; -1 means "not computed"

// Greatest common divisor (iterative Euclid).
ll gcd(int a, int b) {
    while (a != 0) {
        int rest = b % a;
        b = a;
        a = rest;
    }
    return b;
}

// Least common multiple of two small positive integers.
ll lcm(int a, int b) {
    const ll g = gcd(a, b);
    return a * b / g;
}

// Builds the divisor index table and clears the memo (done once; the memo does not
// depend on the query bound because only unbounded suffixes are stored).
void init() {
    int next = 0;
    for (int d = 1; d <= LCM; ++d) {
        if (LCM % d != 0) continue;
        lcmIndex[d] = next++;
    }
    memset(memo, -1, sizeof(memo));
}

// Number of ways to fill digits pos..1 such that the full number is divisible by the LCM
// of its nonzero digits. `rem` is the prefix mod 2520, `curLcm` the LCM so far, and
// `tight` says whether the prefix equals the bound's prefix.
ll countFrom(int pos, int rem, int curLcm, int tight) {
    if (pos == 0) return rem % curLcm == 0;
    ll &cell = memo[pos][rem][lcmIndex[curLcm]];
    if (!tight && cell != -1) return cell;

    const int top = tight ? digits[pos] : 9;
    ll total = 0;
    for (int d = 0; d <= top; ++d) {
        // A zero digit imposes no divisibility requirement, so the LCM is unchanged.
        const int nextLcm = (d == 0) ? curLcm : (int)lcm(curLcm, d);
        total += countFrom(pos - 1, (rem * 10 + d) % LCM, nextLcm, tight && d == top);
    }
    if (!tight) cell = total;
    return total;
}

// Count of beautiful numbers in [0, n] (0 is counted, which cancels in differences).
ll solve(ll n) {
    digits.assign(1, -1);
    for (; n != 0; n /= 10) digits.push_back(n % 10);
    return countFrom((int)digits.size() - 1, 0, 1, 1);
}

int main() {
    init();
    int T;
    cin >> T;
    while (T--) {
        ll lo, hi;
        cin >> lo >> hi;
        cout << solve(hi) - solve(lo - 1) << endl;
    }
    return 0;
}
思路:三维dp。当时无法提交,未验证能否AC。
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;
typedef long long ll;

// Count integers in [l, r] whose decimal digit string has a longest strictly increasing
// subsequence (LIS) of length exactly k. Output format: "Case #t: answer".
//
// State: bitmask of LIS "tails" (patience sorting). Bit d is set when d is the smallest
// possible last digit of an increasing subsequence of some length; popcount(mask) is the
// LIS length of the digits read so far. mask == 0 means only leading zeros so far.

vector<int> ary;          // ary[1] = least-significant digit, ary[0] = sentinel
ll dp[20][1 << 10][11];   // dp[pos][mask][k] for unbounded suffixes; -1 = not computed.
                          // k is part of the key, so the table is safely reused across cases.

// Number of set bits among the 10 digit flags.
int popcount10(int mask) {
    int c = 0;
    for (; mask; mask &= mask - 1) ++c;
    return c;
}

// Adds digit d to the tails set: replaces the smallest tail >= d (strict LIS), else appends.
int insertDigit(int mask, int d) {
    for (int i = d; i < 10; ++i)
        if (mask & (1 << i)) return (mask ^ (1 << i)) | (1 << d);
    return mask | (1 << d);
}

// Ways to fill digits pos..1 so that the whole number's LIS length equals k.
ll dfs(int pos, int mask, int k, int op) {
    if (!pos) return popcount10(mask) == k;
    if (!op && ~dp[pos][mask][k]) return dp[pos][mask][k];
    int maxx = op ? ary[pos] : 9;
    ll res = 0;
    for (int i = 0; i <= maxx; i++) {
        // Leading zeros are not digits of the number and must not enter the LIS.
        int next = (mask == 0 && i == 0) ? 0 : insertDigit(mask, i);
        res += dfs(pos - 1, next, k, op & (i == maxx));
    }
    if (!op) dp[pos][mask][k] = res;
    return res;
}

// Count of numbers in [0, x] whose LIS length is exactly k.
ll solve(ll x, int k) {
    if (k < 0 || k > 10) return 0;   // an LIS over 10 distinct digits is at most 10 long
    ary.clear();
    ary.push_back(-1);
    while (x) {
        ary.push_back(x % 10);
        x /= 10;
    }
    return dfs(ary.size() - 1, 0, k, 1);
}

int main() {
    memset(dp, -1, sizeof(dp));
    int T;
    if (scanf("%d", &T) != 1) return 0;
    for (int t = 1; t <= T; t++) {
        ll l, r, k;
        scanf("%lld%lld%lld", &l, &r, &k);
        printf("Case #%d: %lld\n", t, solve(r, k) - solve(l - 1, k));
    }
    return 0;
}
思路:数位DP。
#include <iostream>
#include <cstring>
#include <vector>
using namespace std;

// Count integers in [n, m] that contain neither the digit 4 nor the adjacent pair "62".
const int maxn = 12;
const int NONE = 10;   // "previous digit" sentinel: only leading zeros placed so far

int dp[maxn][15];      // dp[pos][prev]: valid completions of the remaining pos digits; -1 = unknown
vector<int> v;         // v[1] = least-significant digit, v[0] = sentinel

// Ways to fill digits pos..1 given the previous digit `prev`; `bounded` is nonzero while the
// prefix still equals the bound's prefix.
int countValid(int pos, int prev, int bounded) {
    if (pos == 0) return prev != 4;
    if (!bounded && dp[pos][prev] != -1) return dp[pos][prev];

    const int top = bounded ? v[pos] : 9;
    int total = 0;
    for (int d = 0; d <= top; ++d) {
        if (d == 4 || (prev == 6 && d == 2)) continue;   // forbidden digit or forbidden "62"
        const int nextPrev = (prev == NONE && d == 0) ? NONE : d;
        total += countValid(pos - 1, nextPrev, bounded && d == top);
    }
    if (!bounded) dp[pos][prev] = total;
    return total;
}

// Count of valid numbers in [0, n].
int solve(int n) {
    memset(dp, -1, sizeof(dp));
    v.assign(1, -1);
    for (; n != 0; n /= 10) v.push_back(n % 10);
    return countValid((int)v.size() - 1, NONE, 1);
}

int main() {
    int n, m;
    while (cin >> n >> m) {
        if (m == 0) break;   // input stops at a line whose m is 0
        cout << solve(m) - solve(n - 1) << endl;
    }
    return 0;
}
思路:注意数据范围,输入与计数都使用unsigned long long。先统计[0, n]中不含“49”的数的个数,再用 n+1 减去它得到答案。
#include <iostream>
#include <cstring>
using namespace std;

// Count integers in [1, n] whose decimal form contains "49". The DP counts the complement
// (numbers in [0, n] WITHOUT "49"), and the answer is n + 1 minus that count.
typedef unsigned long long ull;

const int maxn = 20;
const int START = 10;   // "previous digit" sentinel used before any digit is placed
ull dp[maxn][15];       // dp[pos][prev]: "49"-free completions of digits pos..0; all ones = unknown
int v[maxn];            // v[0] = least-significant digit

// Ways to fill digits pos..0 without forming "49", given the previous digit `prev`.
ull countWithout49(int pos, int prev, int bounded) {
    if (pos < 0) return 1;
    if (!bounded && dp[pos][prev] != ~0ULL) return dp[pos][prev];

    const int top = bounded ? v[pos] : 9;
    ull total = 0;
    for (int d = 0; d <= top; ++d) {
        if (prev == 4 && d == 9) continue;   // would complete "49"
        const int nextPrev = (prev == START && d == 0) ? START : d;
        total += countWithout49(pos - 1, nextPrev, bounded && d == top);
    }
    if (!bounded) dp[pos][prev] = total;
    return total;
}

// Count of "49"-free numbers in [0, n].
ull solve(ull n) {
    int len = 0;
    for (; n != 0; n /= 10) v[len++] = n % 10;
    return countWithout49(len - 1, START, 1);
}

int main() {
    memset(dp, -1, sizeof(dp));
    ull T, n;
    cin >> T;
    while (T--) {
        cin >> n;
        cout << n + 1 - solve(n) << endl;
    }
    return 0;
}
#include <iostream>
#include <cstring>
#include <vector>
using namespace std;

// Count integers in [a, b] whose binary representation (without leading zeros) has at
// least as many 0 bits as 1 bits.
const int maxn = 36;
int dp[maxn][maxn][maxn];   // dp[pos][zeros][ones] for unbounded suffixes; -1 = unknown
vector<int> ary;            // ary[1] = least-significant bit, ary[0] = sentinel

// Ways to fill bits pos..1 so that the final zero count >= one count.
// While `leading` is set, no 1 bit has been placed yet (zeros/ones are both 0), so the
// memo key (pos, 0, 0) for that case never collides with a non-leading state.
int countRound(int pos, int zeros, int ones, int leading, int bounded) {
    if (pos == 0) return zeros >= ones;
    if (!bounded && dp[pos][zeros][ones] != -1) return dp[pos][zeros][ones];

    const int top = bounded ? ary[pos] : 1;
    int total = 0;
    for (int bit = 0; bit <= top; ++bit) {
        const int nextBounded = bounded && bit == top;
        if (leading && bit == 0) {
            // Still inside the leading-zero run: these zeros are not part of the number.
            total += countRound(pos - 1, 0, 0, 1, nextBounded);
        } else {
            total += countRound(pos - 1, zeros + (bit == 0), ones + (bit == 1), 0, nextBounded);
        }
    }
    if (!bounded) dp[pos][zeros][ones] = total;
    return total;
}

// Count of round numbers in [0, n] (0 is counted and cancels in differences).
int solve(int n) {
    ary.assign(1, -1);
    for (; n != 0; n /= 2) ary.push_back(n % 2);
    return countRound((int)ary.size() - 1, 0, 0, 1, 1);
}

int main() {
    int a, b;
    memset(dp, -1, sizeof(dp));
    cin >> a >> b;
    cout << solve(b) - solve(a - 1) << endl;
    return 0;
}
思路:选取一个位置为pivot。
(此处原为网页中折叠的“View Code”代码块,本题代码缺失。)
思路:普通数位dp。
#include <iostream>
#include <vector>
#include <cstring>
using namespace std;
typedef long long ll;

// Count integers in [0, n] that contain the substring "13" and are divisible by 13.
vector<int> ary;        // ary[1] = least-significant digit, ary[0] = sentinel
ll dp[12][10][13][2];   // dp[pos][prev digit][value mod 13][contains "13"]; -1 = unknown

// Ways to fill digits pos..1 given the previous digit, the prefix value mod 13 and whether
// "13" has already appeared.
ll countB(int pos, int prev, int rem, bool has13, int leading, int bounded) {
    if (pos == 0) return has13 && rem == 0;
    if (!bounded && dp[pos][prev][rem][has13] != -1) return dp[pos][prev][rem][has13];

    const int top = bounded ? ary[pos] : 9;
    ll total = 0;
    for (int d = 0; d <= top; ++d) {
        const int nextBounded = bounded && d == top;
        const int nextRem = (rem * 10 + d) % 13;
        if (prev == 1 && d == 3) {
            total += countB(pos - 1, d, nextRem, true, 0, nextBounded);
        } else if (leading && d == 0) {
            total += countB(pos - 1, d, 0, has13, 1, nextBounded);   // still a leading zero
        } else {
            total += countB(pos - 1, d, nextRem, has13, 0, nextBounded);
        }
    }
    if (!bounded) dp[pos][prev][rem][has13] = total;
    return total;
}

ll solve(ll n) {
    ary.assign(1, -1);
    for (; n != 0; n /= 10) ary.push_back(n % 10);
    return countB((int)ary.size() - 1, 0, 0, false, 1, 1);
}

int main() {
    memset(dp, -1, sizeof(dp));
    ll n;
    while (cin >> n) cout << solve(n) << endl;
    return 0;
}
思路:强行数位DP。
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;
typedef long long ll;

// F(x): for a number with decimal digits d_k ... d_1 (d_1 least significant),
// F(x) = sum of d_i * 2^(i-1). For each case (A, B), count i in [0, B] with F(i) <= F(A).

const int MAXPOS = 10;                     // supports numbers with up to 10 decimal digits
const int MAXF = 9 * ((1 << MAXPOS) - 1);  // largest F value of a 10-digit number (9207)

int tot = 0;         // F(A) for the current test case
vector<int> ary;     // ary[1] = least-significant digit, ary[0] = sentinel
// dp[pos][budget]: number of ways to fill the lowest `pos` digits freely so that their
// weighted sum is <= budget. Keyed by the remaining budget (tot - sum) instead of sum, so
// entries do not depend on tot and are reused across test cases.
// NOTE(review): assumes A and B have at most MAXPOS digits.
ll dp[MAXPOS + 1][MAXF + 1];

// Weighted digit sum F(x).
int F(ll x) {
    int res = 0, w = 1;
    while (x) {
        res += (int)(x % 10) * w;
        x /= 10;
        w <<= 1;
    }
    return res;
}

// Ways to fill digits pos..1 given the weighted sum `sum` of the higher digits.
ll dfs(int pos, int sum, int op) {
    if (sum > tot) return 0;   // over budget; also keeps tot - sum a valid (non-negative) index
    if (!pos) return 1;
    if (!op && ~dp[pos][tot - sum]) return dp[pos][tot - sum];
    int maxx = op ? ary[pos] : 9;
    ll res = 0;
    for (int i = 0; i <= maxx; i++)
        res += dfs(pos - 1, sum + i * (1 << (pos - 1)), op & (i == maxx));
    if (!op) dp[pos][tot - sum] = res;
    return res;
}

// Count of i in [0, x] with F(i) <= tot.
ll solve(ll x) {
    ary.clear();
    ary.push_back(-1);
    while (x) {
        ary.push_back(x % 10);
        x /= 10;
    }
    return dfs(ary.size() - 1, 0, 1);
}

int main() {
    memset(dp, -1, sizeof(dp));
    int T;
    if (scanf("%d", &T) != 1) return 0;
    for (int t = 1; t <= T; t++) {
        ll a, b;
        scanf("%lld%lld", &a, &b);
        tot = F(a);
        printf("Case #%d: %lld\n", t, solve(b));
    }
    return 0;
}
#include <iostream>
#include <vector>
#include <cstring>
using namespace std;
typedef long long ll;

// A number is "related to 7" if it contains the digit 7, its digit sum is a multiple of 7,
// or it is itself a multiple of 7. For each query [l, r], print the sum of the squares of
// all integers in [l, r] that are NOT related to 7, modulo 1e9+7.
const ll mod = 1e9 + 7;

// Aggregate over every valid way to fill the remaining low digits.
struct Node {
    ll cnt;    // number of valid completions (mod p)
    ll sum;    // sum of the completions' values (mod p)
    ll sqsum;  // sum of the squares of the completions' values (mod p)
};

vector<int> ary;      // ary[1] = least-significant digit, ary[0] = sentinel
ll pw[20];            // pw[i] = 10^i mod p
Node dp[20][7][7];    // dp[pos][digit sum mod 7][value mod 7] for unbounded suffixes
bool vis[20][7][7];   // vis[...] is true once dp[...] holds a computed value

// Aggregates over digits pos..1, given the prefix digit sum mod 7 (s) and prefix value
// mod 7 (r). Results do not depend on the prefix value itself, which is what makes the
// memo valid: the prefix is folded back in by the caller through `base`.
Node dfs(int pos, int s, int r, int op) {
    if (!pos) {
        Node leaf = {0, 0, 0};
        leaf.cnt = (s != 0 && r != 0) ? 1 : 0;   // digit 7 is excluded during enumeration
        return leaf;
    }
    if (!op && vis[pos][s][r]) return dp[pos][s][r];
    int maxx = op ? ary[pos] : 9;
    Node res = {0, 0, 0};
    for (int i = 0; i <= maxx; i++) {
        if (i == 7) continue;
        Node t = dfs(pos - 1, (s + i) % 7, (r * 10 + i) % 7, op & (i == maxx));
        ll base = i * pw[pos - 1] % mod;   // value contributed by digit i at this position
        res.cnt = (res.cnt + t.cnt) % mod;
        // sum of (base + x) = cnt * base + sum x
        res.sum = (res.sum + t.sum + base * t.cnt) % mod;
        // sum of (base + x)^2 = cnt * base^2 + 2 * base * sum x + sum x^2
        res.sqsum = (res.sqsum + t.sqsum + 2 * base % mod * t.sum + base * base % mod * t.cnt) % mod;
    }
    if (!op) {
        vis[pos][s][r] = true;
        dp[pos][s][r] = res;
    }
    return res;
}

// Sum of squares (mod p) of numbers in [0, n] not related to 7.
ll solve(ll n) {
    ary.clear();
    ary.push_back(-1);
    while (n) {
        ary.push_back(n % 10);
        n /= 10;
    }
    return dfs(ary.size() - 1, 0, 0, 1).sqsum;
}

int main() {
    pw[0] = 1;
    for (int i = 1; i < 20; i++) pw[i] = pw[i - 1] * 10 % mod;
    memset(vis, 0, sizeof(vis));
    int t;
    cin >> t;
    while (t--) {
        ll l, r;
        cin >> l >> r;
        // The difference of two residues can be negative; normalise into [0, mod).
        cout << ((solve(r) - solve(l - 1)) % mod + mod) % mod << endl;
    }
    return 0;
}
SPOJ - BALNUM Balanced Numbers
思路:状态压缩。
#include <iostream>
#include <vector>
#include <cstring>
using namespace std;
typedef unsigned long long ull;

// A number is balanced when every odd digit in it occurs an even number of times and every
// even digit in it occurs an odd number of times. Bounds can reach 10^19, which exceeds
// long long, so inputs and counts use unsigned 64-bit integers.

vector<int> ary;     // ary[1] = least-significant digit, ary[0] = sentinel
ull dp[22][60000];   // dp[pos][st], st a base-3 state (3^10 = 59049); all ones = unknown

// st holds one base-3 place per digit i: 0 = absent, 1 = seen an odd number of times,
// 2 = seen an even (nonzero) number of times. Returns true when st is balanced.
bool check(int st) {
    for (int i = 0; i <= 9; i++) {
        int val = st % 3;
        st /= 3;
        if (val == 0) continue;
        // odd digit with odd count, or even digit with even count, breaks the balance
        if ((i + val) % 2 == 0) return false;
    }
    return true;
}

// Returns st after one more occurrence of digit num (0 -> 1, 1 -> 2, 2 -> 1).
int change(int st, int num) {
    int newst = 0, w = 1;
    for (int i = 0; i <= 9; i++) {
        int val = st % 3;
        st /= 3;
        if (num == i) {
            if (val == 1) val = 2;
            else val = 1;
        }
        newst += val * w;
        w *= 3;
    }
    return newst;
}

// Ways to fill digits pos..1 so that the resulting number is balanced.
// While lzero is set, st is 0; a non-leading prefix always has st != 0, so the memo
// key does not need lzero.
ull dfs(int pos, int st, int lzero, int op) {
    if (!pos) return check(st);
    if (!op && ~dp[pos][st]) return dp[pos][st];
    int maxx = op ? ary[pos] : 9;
    ull res = 0;
    for (int i = 0; i <= maxx; i++) {
        if (lzero && i == 0)
            res += dfs(pos - 1, st, lzero, op & (i == maxx));   // leading zero: not a digit
        else
            res += dfs(pos - 1, change(st, i), 0, op & (i == maxx));
    }
    if (!op) dp[pos][st] = res;
    return res;
}

// Count of balanced numbers in [0, n] (0 is counted and cancels in differences).
ull solve(ull n) {
    ary.clear();
    ary.push_back(-1);
    while (n) {
        ary.push_back(n % 10);
        n /= 10;
    }
    return dfs(ary.size() - 1, 0, 1, 1);
}

int main() {
    memset(dp, -1, sizeof(dp));
    int t;
    cin >> t;
    while (t--) {
        ull l, r;
        cin >> l >> r;
        // Guard l == 0 so that l - 1 cannot wrap around to the maximum unsigned value.
        ull below = (l == 0) ? 0 : solve(l - 1);
        cout << solve(r) - below << endl;
    }
    return 0;
}