CF1103D Professional layer
我们上手能有个思路:
先求出所有\(a_i\)的\(gcd\),再分解质因数,得到的素数个数\(m\)一定满足\(m\leq11\)。
之后只需要关注这些素数即可。若对一个数进行操作,一定是直接消除某些素数,所以至多选\(m\)个数。
所以只要根据\(k\)处理出每个数可以消除的素数集合,就得到一个\(O(n*m*3^m)\)的\(dp\),显然过不了的。
所以我们要在这个基础上进行剪枝。
剪枝:
将不是目标素数的素数全除掉得到一组新的\(a_i\),然后把相同\(a_i\)压缩起来,对应的所有\(c_i\)只保留最便宜的\(m\)个,那么至多只剩下\(O(M ∗ m)\)个物品(不同的 \(a_i\)只有\(M\)个)。
\(M\)取最大时,一定是取最小的\(m\)个素数,因为\(a_i\leq10^{12}\) ,可以通过程序暴力得到:\(M<12000\)
\(O(M*m^2*3^m)\) 仍然不能接受
因为最终至多只选\(m\)个,大部分被选中的机会很少甚至根本不会被选,这由它们每个管辖的素数集合和花费有关。可以预处理\(ok[S]\)表示素数集合为\(S\)时那些数可能被选,而对于每个\(S\),至多存储\(m\)个最便宜的数即可。然后按数字顺序做\(dp\)。
预处理\(ok[S]\) ,此处复杂度是 \(O(M*2^m*m^2)\)
这样所有的转移方向至多只有\(O(m*2^m)\) ,再枚举当前选了几个数,总的转移复杂度是 \(O(m^2∗3^m)\)
#define B cout << "BreakPoint" << endl;
#define O(x) cout << #x << " " << x << endl;
#define O_(x) cout << #x << " " << x << " ";
#define Msz(x) cout << "Sizeof " << #x << " " << sizeof(x)/1024/1024 << " MB" << endl;
#include<cstdio>
#include<cmath>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<queue>
#include<map>
#define fr first
#define sc second
#define LL long long
using namespace std;
inline LL read() {
LL s = 0,w = 1;
char ch = getchar();
while(ch < '0' || ch > '9') {
if(ch == '-')
w = -1;
ch = getchar();
}
while(ch >= '0' && ch <= '9') {
s = s * 10 + ch - '0';
ch = getchar();
}
return s * w;
}
LL gcd(LL a,LL b){return b ? gcd(b,a % b) : a;}
#define st first
#define nd second
#define pb push_back
const int N = 1e6 + 7;
struct node{
LL a,w;
bool operator <(const node a)const{return w < a.w;}
}c[N];
vector<LL>ok;
LL n,m,K,res,f[12][1 << 11],g[1 << 11];
map<LL,vector<LL>>mp;
void init(){
n = read(),K = read();
for(int i = 1;i <= n;i++){
c[i].a = read();
res = gcd(res,c[i].a);
}
for(int i = 1;i <= n;i++) c[i].w = read();
if(res == 1) puts("0"),exit(0);
for(int i = 2;i <= 1e6 && 1ll * i * i <= res;i++)
if(res % i == 0) {
ok.pb(i);
while(res % i == 0) res /= i;
}
if(res - 1) ok.pb(res);
m = ok.size();
for(int i = 1;i <= n;i++){
LL z = 1;
for(auto j : ok) while(c[i].a % j == 0) c[i].a /= j,z *= j;
mp[z].pb(c[i].w);
}
}
void solve(){
memset(f,2,sizeof(f));
f[0][0] = 0;
for(auto t : mp){
LL x = t.st;
sort(t.nd.begin(),t.nd.end());
if(t.nd.size() > m) t.nd.resize(m);
vector<LL>a;
for(int i = 0;i < 1<<m;i++){
LL times = 1 , p = x;
for(int j = 0;j < m;j++)
if(i & (1 << j)) while(p % ok[j] == 0) p /= ok[j],times *= ok[j];
a.pb(times);
}
for(auto p : t.nd){
bool flg = 0;
for(int i = m - 1;i >= 0;i--)
for(int j = 0;j < 1 << m;j++)
if(f[i][j] <= 1e12){
int tmp = (1 << m) - 1 - j;
for(int k = tmp;k;k = (k - 1) & tmp)
if(a[k] <= K) if(f[i + 1][j | k] > f[i][j] + p) f[i + 1][j | k] = f[i][j] + p,flg = 1;
}
if(!flg) break;
}
}
LL ans = 1e18;
for(int i = 1;i <= m;i++) ans = min(ans,f[i][(1 << m) - 1] * i);
if(ans > 1e12) puts("-1");
else printf("%lld\n",ans);
}
int main(){
init();
solve();
return 0;
}