能解决什么问题
一般是给出 n 个递减的等差数列,要求对于所有等差数列中前 m 个大的数的和
时间复杂度
O(m * logn)
[acwing]1262. 鱼塘钓鱼
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;
typedef pair<int, int> PII;
const int N = 110;
int n;
int a[N], d[N], t[N]; // t[i] 表示从第 1 个池塘走到第 i 个池塘所需时间
int T;
int spend[N]; // spend[i] 表示在第 i 个池塘钓了多长时间
int res;
// 当前能钓得的鱼数
int get(int i)
{
return max(0, a[i] - d[i] * spend[i]);
}
int solve(int n, int t)
{
if (t <= 0) return 0;
memset(spend, 0, sizeof(spend));
int res = 0;
priority_queue<PII> h;
for (int i = 1; i <= n; i++) h.push({get(i), i});
while (t-- > 0) {
auto tmp = h.top();
if (tmp.first == 0) break;
h.pop();
res += tmp.first;
spend[tmp.second]++;
h.push({get(tmp.second), tmp.second});
}
return res;
}
int main()
{
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
for (int i = 1; i <= n; i++) scanf("%d", &d[i]);
for (int i = 2; i <= n; i++) {
scanf("%d", &t[i]);
t[i] += t[i - 1];
}
scanf("%d", &T);
for (int i = 1; i <= n; i++)
res = max(res, solve(i, T - t[i]));
printf("%d", res);
return 0;
}
[acwing]4656. 技能升级
/*
此题数据范围过大,多路归并做法只能过 6/12
*/
#include <cstdio>
#include <queue>
#include <cstring>
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int N = 100010;
int n, m;
int a[N], d[N], cnt[N];
int get(int i)
{
return max(0, a[i] - d[i] * cnt[i]);
}
LL solve(int m)
{
memset(cnt, 0, sizeof(cnt));
LL res = 0;
priority_queue<PII> h;
for (int i = 1; i <= n; i++) h.push({get(i), i});
while (m--) {
auto t = h.top();
h.pop();
if (t.first == 0) break;
res += t.first;
cnt[t.second]++;
h.push({get(t.second), t.second});
}
return res;
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d%d", &a[i], &d[i]);
printf("%lld", solve(m));
return 0;
}
/*
优化做法 O(n * log 1e6)
首先,对于所有等差数列从大到小排序,找排在第 m 个的数其值为多少,记为 x
枚举每个等差数列,只加上大于 x 的项,最后补上等于 x 的项
项数:c = (a[i] - x) / d[i] + 1
*/
#include <cstdio>
using namespace std;
typedef long long LL;
const int N = 100010;
int n, m;
int a[N], d[N];
bool check(int mid)
{
LL res = 0;
for (int i = 0; i < n; i++) {
if (a[i] >= mid) {
res += (a[i] - mid) / d[i] + 1;
}
}
return res >= m;
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 0; i < n; i++) scanf("%d%d", &a[i], &d[i]);
// 左边界要从 0 开始,因为第 m 个数的值可能为 0
int l = 0, r = 1e6;
while (l < r) {
int mid = l + r + 1 >> 1;
if (check(mid)) l = mid;
else r = mid - 1;
}
LL res = 0, cnt = 0;
for (int i = 0; i < n; i++) {
if (a[i] > r) {
int c = (a[i] - r) / d[i] + 1;
int end = a[i] - d[i] * (c - 1);
res += ((LL)a[i] + end) * c / 2;
cnt += c;
}
}
printf("%lld", res + (m - cnt) * r);
return 0;
}