找丑数
额,刚开始比较 naive,直接用 STL 大法(优先队列 + set),结果 T 了(超时)。
#pragma GCC optimize(2)
#include <cstdio>
#include <iostream>
#include <queue>
#include <set>
#include <vector>
using namespace std;

typedef long long ll;

// Attempt 1 (correct but too slow): generate the values in increasing order.
// A min-priority-queue holds the pending candidates, and a set remembers every
// product that has ever been enqueued so each product is queued only once.
// The m-th value popped from the queue is printed.
//
// Input:  n m, followed by n factors.
// Output: the m-th popped value, with no trailing newline.
int main() {
    ll n, m;
    scanf("%lld%lld", &n, &m);

    vector<ll> factor(n + 1);  // 1-indexed, like the original array
    set<ll> seen;
    priority_queue<ll, vector<ll>, greater<ll> > pending;

    for (int idx = 1; idx <= n; idx++) {
        ll x;
        scanf("%lld", &x);
        factor[idx] = x;
        seen.insert(x);
        pending.push(x);  // the input values are queued unconditionally
    }

    for (ll popped = 1;; ++popped) {
        const ll cur = pending.top();
        pending.pop();
        if (popped == m) {
            cout << cur;
            return 0;
        }
        for (int idx = 1; idx <= n; idx++) {
            const ll prod = cur * factor[idx];
            // insert(...).second is true only when prod was not seen before.
            if (seen.insert(prod).second) {
                pending.push(prod);
            }
        }
    }
}
然后把优先队列改成手写堆,set 改成 map,结果还是 T(超时)。。。QAQ。。。
#pragma GCC optimize(2)
#include <cstdio>
#include <map>
#include <utility>
#include <vector>
using namespace std;

typedef long long ll;

// Attempt 2 (still too slow): same generation scheme as attempt 1, but with a
// hand-written binary min-heap instead of std::priority_queue, and a std::map
// (used purely as a set) instead of std::set.
//
// Bug fix: the heap used to live in a fixed array `dui[1001000]` with no
// bounds check. Every pop may enqueue up to n new products, so the worst-case
// heap size (about n * m) is far larger than that array. Large inputs could
// write past its end, which is undefined behaviour. The heap now grows on
// demand.

// 1-indexed binary min-heap: heap[0] is an unused placeholder, so the
// children of node p are at 2p and 2p + 1.
static vector<ll> heap(1);

// Appends v and sifts it up while it is smaller than its parent.
inline void insert(ll v) {
    heap.push_back(v);
    size_t k = heap.size() - 1;
    while (k > 1 && heap[k] < heap[k >> 1]) {
        swap(heap[k], heap[k >> 1]);
        k >>= 1;
    }
}

// Removes the minimum element heap[1]. Precondition: the heap is non-empty.
inline void del() {
    heap[1] = heap.back();
    heap.pop_back();
    const size_t last = heap.size() - 1;  // index of the last live element
    size_t p = 1;
    while (p * 2 <= last) {
        size_t child = p * 2;
        // Take the right child only when it is strictly smaller than the left.
        if (child + 1 <= last && heap[child + 1] < heap[child]) ++child;
        if (!(heap[child] < heap[p])) break;
        swap(heap[p], heap[child]);
        p = child;
    }
}

// Reads n and m, then n factors. Prints, with no trailing newline, the m-th
// value popped from the heap. Each popped value is multiplied by every factor
// and each new product is enqueued once. Exits silently on malformed input or
// if the heap runs empty.
int main() {
    ll n, m;
    if (scanf("%lld%lld", &n, &m) != 2) return 0;

    vector<ll> factor(n > 0 ? n + 1 : 1);  // 1-indexed
    map<ll, int> seen;  // key present <=> value has already been enqueued
    for (ll i = 1; i <= n; ++i) {
        ll v;
        if (scanf("%lld", &v) != 1) return 0;
        factor[i] = v;
        insert(v);
        seen[v] = 1;
    }

    for (ll j = 1; j <= m; ++j) {
        if (heap.size() <= 1) return 0;  // nothing left to pop (degenerate input)
        if (j == m) {
            printf("%lld", heap[1]);
            return 0;
        }
        const ll cur = heap[1];
        del();
        for (ll i = 1; i <= n; ++i) {
            const ll prod = cur * factor[i];
            // A single map lookup: .second is true only if prod was new.
            if (seen.insert(make_pair(prod, 1)).second) {
                insert(prod);
            }
        }
    }
    return 0;
}
最后换了个思路就过了:堆里的每个元素额外记录一个编号 ji,表示它最后乘上的那个质数的编号;出堆后只再乘编号 ≥ ji 的质数。这样每个丑数只会沿着质因子编号不降的顺序被生成一次,不再需要 set/map 判重。很妙吧,看看代码应该就懂了。
#include <cstdio>
#include <algorithm>
using namespace std;

typedef long long ll;

// Final version (accepted). Each heap entry remembers the index of the factor
// that was multiplied in last. When an entry is popped, it is multiplied only
// by factors with that index or higher. Every product is therefore built
// along a non-decreasing sequence of factor indices. Presumably the factors
// are distinct primes, as in the ugly-number problem; then each value is
// generated at most once and no set/map de-duplication is needed.

struct Node {
    ll value;    // candidate value
    int minIdx;  // smallest factor index this value may still be multiplied by
};

static ll factor[131203];      // 1-indexed input factors
static Node heapArr[8001000];  // 1-indexed binary min-heap keyed on value
static ll heapSize = 0;        // number of live entries in heapArr
// NOTE(review): heapPush has no bounds check; it relies on the heap never
// holding more than 8000999 entries.

// Places node at the end of the heap and bubbles it towards the root while it
// is strictly smaller than its parent.
inline void heapPush(const Node& node) {
    heapArr[++heapSize] = node;
    for (ll pos = heapSize; pos > 1; pos >>= 1) {
        const ll parent = pos >> 1;
        if (!(heapArr[pos].value < heapArr[parent].value)) break;
        swap(heapArr[pos], heapArr[parent]);
    }
}

// Drops the root by moving the last entry into slot 1 and sinking it.
// Intended for a non-empty heap.
inline void heapPop() {
    heapArr[1] = heapArr[heapSize--];
    ll pos = 1;
    for (;;) {
        ll child = pos << 1;
        if (child > heapSize) break;
        // Use the right sibling only when the left one is strictly larger.
        if (child != heapSize && heapArr[child].value > heapArr[child + 1].value) {
            ++child;
        }
        if (!(heapArr[pos].value > heapArr[child].value)) break;
        swap(heapArr[pos], heapArr[child]);
        pos = child;
    }
}

// Reads n and m, then n factors. Prints, with no trailing newline, the value
// at the top of the heap after m - 1 pops.
int main() {
    ll n, m;
    scanf("%lld%lld", &n, &m);

    for (int i = 1; i <= n; ++i) {
        ll x;
        scanf("%lld", &x);
        factor[i] = x;
        Node leaf;
        leaf.value = x;
        leaf.minIdx = i;
        heapPush(leaf);
    }

    for (ll rank = 1; rank <= m; ++rank) {
        if (rank == m) {
            printf("%lld", heapArr[1].value);
            return 0;
        }
        const Node top = heapArr[1];
        heapPop();
        // Only factors at or after top.minIdx, so each product is built once.
        for (int i = top.minIdx; i <= n; ++i) {
            Node child;
            child.value = top.value * factor[i];
            child.minIdx = i;
            heapPush(child);
        }
    }
    return 0;
}