题解 [LNOI2022] 吃
考虑先全部做乘法，再将其中一些乘法改成加法。加法显然都应放在乘法之前进行，于是答案等于全部 \(a_i\) 的乘积乘上下面这个分式（分子的求和、分母的求积都只针对改成加法的那些物品）
那么就是求一个最大的 \(\frac{1+\sum b_i}{\prod a_i}\)
先将 \(a_i=1\) 的全都选了
剩下的发现分母是指数级的,所以最多选 log 个
那么每次贪心地选一个能使分式值最大（且比当前值更大）的物品，直到分式无法再变大为止即可
复杂度 \(O(n\log n)\)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
// Shorthand macros used throughout the file.
#define INF 0x3f3f3f3f
// Capacity of the 1-indexed item arrays (a, b, sta); indices 1..n are used.
#define N 500010
// Short names for std::pair members.
#define fir first
#define sec second
#define ll long long
//#define int long long
// 2 MiB input buffer: p1 is the read cursor, p2 marks the end of the bytes currently loaded.
char buf[1<<21], *p1=buf, *p2=buf;
// Buffered replacement for getchar(): when the buffer is exhausted it refills it with fread
// from stdin; once fread yields no bytes it returns EOF. Otherwise it returns the next byte
// as a plain char value (negative for bytes >= 0x80 where char is signed).
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
// Reads one decimal integer from input.
// Non-digit characters before the number are skipped; every '-' among them flips the sign
// (same rule as before). If input ends before any digit is seen, returns 0 instead of
// spinning forever on EOF.
// Digits are tested with an explicit range check rather than isdigit(), because the
// buffered getchar() may yield negative char values, which are UB for isdigit().
inline int read() {
    int ans=0, f=1;
    int c=getchar();
    while (c!=EOF && !(c>='0' && c<='9')) {
        if (c=='-') f=-f;
        c=getchar();
    }
    while (c>='0' && c<='9') {
        ans=ans*10+(c-'0');
        c=getchar();
    }
    return ans*f;
}
// Problem input: n items; item i (1-indexed) has multiplier a[i] and addend b[i].
int n;
ll a[N];
ll b[N];
// Modulus for the printed answer (prime, so qpow(x, mod-2) is a modular inverse).
const ll mod=1000000007;
// Modular exponentiation by repeated squaring: returns a^b modulo `mod`.
// The base is not reduced first; the visible callers pass values already below mod.
inline ll qpow(ll a, ll b) {
    ll result=1;
    while (b!=0) {
        if (b&1) result=result*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return result;
}
namespace force{
    int rec;      // bitmask of the best add-set found so far (bit i-1 <-> item i)
    double ans;   // best final value found so far, in floating point
    // Exhaustive reference solver: for every subset mask, items in the mask are added
    // first (starting from 1), then every other item multiplies the value.
    // Only usable for small n (the mask is an int), and the printed value is exact only
    // while the optimum is representable in a double.
    void solve() {
        const int masks=1<<n;
        for (int mask=0; mask<masks; ++mask) {
            double value=1;
            for (int i=1; i<=n; ++i) {
                if (mask&(1<<(i-1))) value+=b[i];
            }
            for (int i=1; i<=n; ++i) {
                if (!(mask&(1<<(i-1)))) value*=a[i];
            }
            if (value>ans) {
                rec=mask;
                ans=value;
            }
        }
        cout<<(ll)fmod(ans, mod)<<endl;
        // Debug: list the items chosen for addition.
        // for (int i=1; i<=n; ++i) if (rec&(1<<i-1)) cout<<i<<' '; cout<<endl;
    }
}
namespace task1{
ll ans=1;
pair<ll, ll> sta[N];
inline ll gcd(ll a, ll b) {return !b?a:gcd(b, a%b);}
void solve() {
random_device seed;
mt19937 rand(seed());
for (int i=1; i<=n; ++i) sta[i]={a[i], b[i]};
sort(sta+1, sta+n+1, [](pair<ll, ll> a, pair<ll, ll> b){
pair<ll, ll> t1={a.fir/gcd(a.fir, a.sec), a.sec/gcd(a.fir, a.sec)};
pair<ll, ll> t2={b.fir/gcd(b.fir, b.sec), b.sec/gcd(b.fir, b.sec)};
if (t1==t2) return a.sec<b.sec;
else return (double)a.sec/(double)a.fir>(double)b.sec/(double)b.fir;
});
// sort(sta+1, sta+n+1, [](pair<ll, ll> a, pair<ll, ll> b){return a.fir==b.fir?a.sec>b.sec:a.fir<b.fir;});
cout<<"sta: "; for (int i=1; i<=n; ++i) cout<<"("<<sta[i].fir<<','<<sta[i].sec<<") "; cout<<endl;
pair<double, ll> sumb={1, 1}, proda={1, 1};
for (int i=1; i<=n; ++i) ans=ans*sta[i].fir%mod;
for (int i=1; i<=n; ++i) if (sta[i].fir<=sta[i].sec) {
if ((sumb.fir+sta[i].sec)/(proda.fir*sta[i].fir) > sumb.fir/proda.fir) {
sumb.fir+=sta[i].sec, sumb.sec=(sumb.sec+sta[i].sec)%mod;
proda.fir*=sta[i].fir, proda.sec=proda.sec*sta[i].fir%mod;
cout<<"("<<sta[i].fir<<','<<sta[i].sec<<") ";
}
}
// cout<<endl;
printf("%lld\n", ans*qpow(proda.sec, mod-2)%mod*sumb.sec%mod);
}
}
namespace task2{
int top;
ll ans=1;
pair<ll, ll> sta[N];
inline ll gcd(ll a, ll b) {return !b?a:gcd(b, a%b);}
void solve() {
pair<long double, ll> sumb={1, 1}, proda={1, 1};
for (int i=1; i<=n; ++i) ans=ans*a[i]%mod;
for (int i=1; i<=n; ++i)
if (a[i]==1) sumb.fir+=b[i], sumb.sec=(sumb.sec+b[i])%mod;
else sta[++top]={a[i], b[i]};
// cout<<"sta: "; for (int i=1; i<=n; ++i) cout<<"("<<sta[i].fir<<','<<sta[i].sec<<") "; cout<<endl;
while (1) {
int maxi=0;
pair<long double, ll> maxb=sumb, maxa=proda;
for (int i=1; i<=top; ++i) if (sta[i].fir<=sta[i].sec) {
if ((sumb.fir+sta[i].sec)/(proda.fir*sta[i].fir) > maxb.fir/maxa.fir) {
maxb.fir=sumb.fir+sta[i].sec, maxb.sec=(sumb.sec+sta[i].sec)%mod;
maxa.fir=proda.fir*sta[i].fir, maxa.sec=proda.sec*sta[i].fir%mod;
maxi=i;
// cout<<"("<<sta[i].fir<<','<<sta[i].sec<<") ";
}
}
if (!maxi) break;
else swap(sta[maxi], sta[top--]), sumb=maxb, proda=maxa; //, cout<<"once"<<endl, cout<<maxi<<endl;
}
// cout<<endl;
printf("%lld\n", ans*qpow(proda.sec, mod-2)%mod*sumb.sec%mod);
}
}
signed main()
{
// freopen("food.in", "r", stdin);
// freopen("food.out", "w", stdout);
n=read();
for (int i=1; i<=n; ++i) a[i]=read();
for (int i=1; i<=n; ++i) b[i]=read();
// force::solve();
// task1::solve();
task2::solve();
return 0;
}