Solution -「AGC 034C」Tests
\(\mathcal{Description}\)
Link.
给定非负整数序列 \(\{l_n\},\{r_n\},\{b_n\},X\),求最小的 \(s\),使得存在非负整数序列 \(\{a_n\},\{c_n\}\),满足 \(a_i\le X\),\(\sum_{i=1}^na_i=s\),\(c_i\in[l_i,r_i]\),且
所有输入均 \(\le10^5\)。
\(\mathcal{Solution}\)
显然二分 \(s\),仅需做到检测某个 \(s\) 是否合法。下令 \(w=\sum_{i=1}^nc_ia_i\)。
假设 \(\{a_n\}\) 已经确定,那么所有满足 \(a_i\ge b_i\) 的 \(c_i=r_i\),其余 \(c_i=l_i\)。考虑初始时所有 \(a_i=0,c_i=l_i\),现在把 \(s\) 个 \(1\) 挨个加到一些 \(a_i\) 上。当 \(a_i<b_i\) 时,对 \(w\) 贡献 \(l_i\)(此时 \(c_i\) 仍取 \(l_i\));当 \(a_i=b_i\) 时,对 \(w\) 贡献由 \(l_i\) 转为 \(r_i\)(\(c_i\) 变成 \(r_i\));继续增加,对 \(w\) 贡献 \(r_i\)。最终仅需比较 \(w\) 和 \(\sum_{i=1}^nl_ib_i\) 的大小。
所以问题抽象为:有 \(n\) 个分段函数 \(f_{1..n}(x)\),满足
仅需钦定 \(\{x_n\}\),使得 \(\sum_{i=1}^nf_i(x_i)\) 取最大。
考虑贪心,不难证明:至多有一个 \(0<x_i<X\)。直接枚举哪一个 \(0<x_i<X\),贪心地选取最大的 \(f_i(X)\),即可 \(\mathcal O(n)\) 检测。最终复杂度 \(\mathcal O(n\log\sum_{i=1}^nb_i)\)。
\(\mathcal{Code}\)
/* Clearink */
#include <cstdio>
#include <algorithm>
#define rep( i, l, r ) for ( int i = l, repEnd##i = r; i <= repEnd##i; ++i )
#define per( i, r, l ) for ( int i = r, repEnd##i = l; i >= repEnd##i; --i )
inline int rint () {
int x = 0, f = 1; char s = getchar ();
for ( ; s < '0' || '9' < s; s = getchar () ) f = s == '-' ? -f : f;
for ( ; '0' <= s && s <= '9'; s = getchar () ) x = x * 10 + ( s ^ '0' );
return x * f;
}
typedef long long LL;
const int MAXN = 1e5;
int n, x, b[MAXN + 5], l[MAXN + 5], r[MAXN + 5], ord[MAXN + 5];
LL full[MAXN + 5];
inline LL contr ( const int i, const int s ) {
return s <= b[i] ? 1ll * s * l[i]
: 1ll * l[i] * b[i] + 1ll * ( s - b[i] ) * r[i];
}
inline void init () {
std::sort ( ord + 1, ord + n + 1, []( const int i, const int j ) {
return contr ( i, x ) > contr ( j, x );
} );
rep ( i, 1, n ) full[i] = full[i - 1] + contr ( ord[i], x );
}
inline LL calc ( const LL scr ) {
/*
* let's come up with a greedy algorithm!
* */
int fcnt = scr / x, rest = scr % x;
LL ret = 0;
rep ( i, 1, n ) { // score on exam <ord[i]> is <rest>.
LL cur = contr ( ord[i], rest );
if ( i > fcnt ) cur += full[fcnt];
else cur += full[fcnt + 1] - contr ( ord[i], x );
ret = cur > ret ? cur : ret;
}
return ret;
}
int main () {
n = rint (), x = rint ();
LL sum = 0, sb = 0;
rep ( i, 1, n ) {
ord[i] = i;
b[i] = rint (), l[i] = rint (), r[i] = rint ();
sum += 1ll * b[i] * l[i], sb += b[i];
}
init ();
LL lef = 0, rig = sb;
while ( lef < rig ) {
LL mid = lef + rig >> 1;
if ( calc ( mid ) >= sum ) rig = mid;
else lef = mid + 1;
}
printf ( "%lld\n", lef );
return 0;
}