BZOJ2757 - 淘金(数位dp,多路归并)
题目
小Z在玩一个叫做《淘金者》的游戏。游戏的世界是一个二维坐标。X轴、Y轴坐标范围均为1..N。初始的时候,所有的整数坐标点上均有一块金子,共N*N块。
一阵风吹过,金子的位置发生了一些变化。细心的小Z发现,初始在(i,j)坐标处的金子会变到(f(i),f(j))坐标处。其中f(x)表示x各位数字的乘积,例如f(99)=81,f(12)=2,f(10)=0。如果金子变化后的坐标不在1..N的范围内,我们认为这块金子已经被移出游戏。同时可以发现,对于变化之后的游戏局面,某些坐标上的金子数量可能不止一块,而另外一些坐标上可能已经没有金子。这次变化之后,游戏将不会再对金子的位置和数量进行改变,玩家可以开始进行采集工作。
小Z很懒,打算只进行K次采集。每次采集可以得到某一个坐标上的所有金子,采集之后,该坐标上的金子数变为0。
现在小Z希望知道,对于变化之后的游戏局面,在采集次数为K的前提下,最多可以采集到多少块金子?
答案可能很大,小Z希望得到对1000000007(10^9+7)取模之后的答案。
题解
这里两个坐标是互相独立的,主要求出一维的情况,二维的答案就是相乘的结果。(牛客2020多校有道数位dp就是两个互相关联的数之间的数位dp)。
突破口在于f的值域大于0的情况非常少,题目范围下只有不到1万个数。所以可以先预处理出这些数,然后离散一下。分别求出每个乘积是多少个数的乘积(该位置下有多少金块),假设位cnt数组。
最后问题转变成求两个有序cnt数组,分别取出一个数相乘的前k大个数。把这排序后看成k个有序序列,然后优先队列多路归并取前k个数即可。
#include <bits/stdc++.h>
#define endl '\n'
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define FILE freopen(".//data_generator//in.txt","r",stdin),freopen("res.txt","w",stdout)
#define FI freopen(".//data_generator//in.txt","r",stdin)
#define FO freopen("res.txt","w",stdout)
#define pb push_back
#define mp make_pair
#define seteps(N) fixed << setprecision(N)
typedef long long ll;
using namespace std;
/*-----------------------------------------------------------------*/
ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
#define INF 0x3f3f3f3f
const int N = 1e5 + 10;
const int M = 1e9 + 7;
const double eps = 1e-5;
set<ll> num[2][10];
vector<ll> pd;
ll f[20][N];
ll cnt[N];
bool vis[20][N];
int di[N];
int get(ll v) {
int p = lower_bound(pd.begin(), pd.end(), v) - pd.begin() + 1;
assert(pd[p - 1] == v);
return p;
}
ll dfs(int p, ll st, int lmt, int lead) {
if(!p) {
return (!lead) && (st == 1);
}
if(!lead && !lmt && f[p][get(st)] >= 0) return f[p][get(st)];
ll res = 0;
int maxx = lmt ? di[p] : 9;
for(int i = 0; i <= maxx; i++) {
if(lead) {
if(!i) res += dfs(p - 1, st, lmt && i == maxx, i == 0 && lead);
else if(st % i == 0) res += dfs(p - 1, st / i, lmt && i == maxx, i == 0 && lead);
} else if(i && st % i == 0)
res += dfs(p - 1, st / i, lmt && i == maxx, i == 0 && lead);
}
if(!lmt && !lead) f[p][get(st)] = res;
return res;
}
void solve(ll x) {
int tot = 0;
while(x) {
di[++tot] = x % 10;
x /= 10;
}
for(int i = 0; i < pd.size(); i++) {
cnt[i] = dfs(tot, pd[i], 1, 1);
cnt[i];
}
}
void init() {
for(int i = 1; i <= 9; i++) {
num[1][i].insert(i);
}
for(int i = 2; i <= 12; i++) {
for(int j = 1; j <= 9; j++) {
for(int k = 1; k <= j; k++) {
for(auto v : num[!(i % 2)][k]) {
num[i % 2][j].insert(v * j);
}
}
}
}
set<ll> tmp;
for(int i = 1; i <= 9; i++) {
for(auto v : num[0][i]) {
tmp.insert(v);
}
for(auto v : num[1][i]) {
tmp.insert(v);
}
}
for(auto v : tmp) pd.push_back(v);
}
struct node {
ll res;
int p1, p2;
bool operator < (const node& a) const {
return res < a.res;
}
};
priority_queue<node> q;
int main() {
IOS;
memset(f, -1, sizeof f);
init();
ll n;
int k;
cin >> n >> k;
solve(n);
vector<ll> ans;
sort(cnt, cnt + pd.size(), greater<ll>());
int len = min(k, (int)pd.size());
for(int i = 0; i < len; i++) {
q.push(node{cnt[i] * cnt[0], i, 0});
}
ll res = 0;
for(int i = 0; i < k; i++) {
auto mi = q.top();
q.pop();
res += mi.res;
res %= M;
int p1 = mi.p1;
int p2 = mi.p2 + 1;
q.push(node{cnt[p1] * cnt[p2], p1, p2});
}
cout << res << endl;
}