洛谷P4571 [JSOI2009] 瓶子和燃料
题目
https://www.luogu.com.cn/problem/P4571
思路
首先观察并且简单模拟一下火星人取燃料的过程,发现最终燃料的量一定是他选的k个瓶子容量的线性组合(观察操作3就知道)。火星人的抠门,就是说他会找到这些线性组合当中最小的正整数结果\(y\)。
让我们用式子描述一下,设k个瓶子的容量分别为\(x_1,x_2,...x_k\),那么有不定方程\(a_1x_1+a_2x_2+...+a_kx_k=y\),我们要求的就是使方程有解的最小的\(y\)。
然后根据裴蜀定理,\(y_{min}=gcd(x_1,x_2,...,x_k)\),题目就变成了要选k个数使得他们gcd值最大。
先不考虑时间复杂度,看看怎么保证正确性,首先gcd一定是某个\(a[i]\)的因数,那我们就对每个\(a[i]\)求一遍因数,然后该因数所对应的计数器++(我们先不考虑怎样计数)。最后查一下计数器\(\geq k\)的值。
考虑下总共有多少因数,一个数的除数函数\(\sigma_0(a)\)有个不紧的上界\(2\sqrt a\),总共1000个数,每个数因数个数小于1e5(事实上肯定远远小于),所以我们写个哈希再开个数组计数就行了。
然后就是时间复杂度的问题。求一个数的因数要实打实的\(\sqrt a\),总计是\(10^8\)级别,就很容易挂,得想办法优化一下。
①对于质数,它只有本身是有贡献的,根本不用继续试除,我们写个miller_rabin把素数判掉。
②维护一个全局ans值保存当前最大合法gcd值,如果在试除过程中不可能再有因子大于ans了就跳出。
③预先将a数组大到小排序,可能有助于提前找到很大的最优解。
加了这几个优化之后跑得飞快(287ms)。
代码
点击查看代码
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cmath>
#define ll long long
using namespace std;
int prime[10]={0,2,3,5,7,11};
int n,k,ans=0;
int a[1010];
const int zzd=233333;
struct Hashmap{
int fst[zzd],nxt[1000000],cnt,val[1000000];
Hashmap(){cnt=0;}
int insert(int x){
nxt[++cnt]=fst[x%zzd];
val[cnt]=x;
fst[x%zzd]=cnt;
return cnt;
}
int find(int y){
int i=fst[y%zzd];
while(i){
if(val[i]==y) return i;
i=nxt[i];
}
return 0;
}
} H;
int c[1000000];
ll qpow(ll x,ll p,ll m){
ll ans=1,base=x;
for(;p;p>>=1){
if(p&1) ans=ans*base%m;
base=base*base%m;
}
return ans;
}
int miller_rabin(ll x){
if(x==1) return 0;
ll y=x-1;
int i,j,u=0;
while(!(y&1ll)) y>>=1,u++;
for(i=1;i<=5;++i){
if(x==prime[i]) return 1;
if(!(x%prime[i])) return 0;
ll q=qpow(prime[i],y,x);
for(j=0;j<=u;++j){
if((q*q)%x==1&&q!=1&&q!=x-1) return 0;
if(j==u&&q!=1) return 0;
q=q*q%x;
}
}
return 1;
}
bool cmp(ll x,ll y){
return x>y;
}
void dec(int y){
int i,z;
int x=y,lim=(int)sqrt(x)+1;
for(i=1;i<=lim;++i){
if(x%i) continue;
if(x/i<=ans) break;
z=H.find(i);
if(!z) z=H.insert(i);
if(++c[z]>=k) ans=max(ans,i);
z=H.find(x/i);
if(!z) z=H.insert(x/i);
if(++c[z]>=k) ans=max(ans,x/i);
if(i==1&&miller_rabin(x)) break;
}
return;
}
int main(){
int i,j;
scanf("%d%d",&n,&k);
for(i=1;i<=n;++i) scanf("%d",&a[i]);
sort(a+1,a+n+1,cmp);
for(i=1;i<=n;++i){
if(a[i]<=ans) break;
dec(a[i]);
}
printf("%d",ans);
// system("pause");
return 0;
}