多路归并

Posted on 2023-04-01 17:19  lyc2002  阅读(16)  评论(0编辑  收藏  举报

能解决什么问题

一般是给出 n 个递减的等差数列,要求对于所有等差数列中前 m 个大的数的和

时间复杂度

O(m * logn)

[acwing]1262. 鱼塘钓鱼

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>

using namespace std;

typedef pair<int, int> PII;

const int N = 110;

int n;
int a[N], d[N], t[N]; // t[i] 表示从第 1 个池塘走到第 i 个池塘所需时间
int T;
int spend[N]; // spend[i] 表示在第 i 个池塘钓了多长时间
int res;

// 当前能钓得的鱼数
int get(int i)
{
    return max(0, a[i] - d[i] * spend[i]);
}

int solve(int n, int t)
{
    if (t <= 0) return 0;
    
    memset(spend, 0, sizeof(spend));
    
    int res = 0;
    priority_queue<PII> h;
    for (int i = 1; i <= n; i++) h.push({get(i), i});
    while (t-- > 0) {
        auto tmp = h.top();
        if (tmp.first == 0) break;
        h.pop();
        res += tmp.first;
        spend[tmp.second]++;
        h.push({get(tmp.second), tmp.second});
    }
    
    return res;
}

int main()
{
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    for (int i = 1; i <= n; i++) scanf("%d", &d[i]);
    for (int i = 2; i <= n; i++) {
        scanf("%d", &t[i]);
        t[i] += t[i - 1];
    }
    scanf("%d", &T);
    
    for (int i = 1; i <= n; i++) 
        res = max(res, solve(i, T - t[i]));
    
    printf("%d", res);
    
    return 0;
}

[acwing]4656. 技能升级

/*
    此题数据范围过大,多路归并做法只能过 6/12
*/
#include <cstdio>
#include <queue>
#include <cstring>

using namespace std;

typedef long long LL;
typedef pair<int, int> PII;

const int N = 100010;

int n, m;
int a[N], d[N], cnt[N];

int get(int i)
{
    return max(0, a[i] - d[i] * cnt[i]);
}

LL solve(int m)
{
    memset(cnt, 0, sizeof(cnt));
    
    LL res = 0;
    priority_queue<PII> h;
    
    for (int i = 1; i <= n; i++) h.push({get(i), i});
    
    while (m--) {
        auto t = h.top();
        h.pop();
        if (t.first == 0) break;
        res += t.first;
        cnt[t.second]++;
        h.push({get(t.second), t.second});
    }
    
    return res;
}

int main()
{
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) scanf("%d%d", &a[i], &d[i]);
    
    printf("%lld", solve(m));
    
    return 0;
}
/*
    优化做法 O(n * log 1e6)
    首先,对于所有等差数列从大到小排序,找排在第 m 个的数其值为多少,记为 x
    枚举每个等差数列,只加上大于 x 的项,最后补上等于 x 的项
    项数:c = (a[i] - x) / d[i] + 1
*/
#include <cstdio>

using namespace std;

typedef long long LL;

const int N = 100010;

int n, m;
int a[N], d[N];

bool check(int mid)
{
    LL res = 0;
    for (int i = 0; i < n; i++) {
        if (a[i] >= mid) {
            res += (a[i] - mid) / d[i] + 1;
        }
    }
    return res >= m;
}

int main()
{
    scanf("%d%d", &n, &m);
    for (int i = 0; i < n; i++) scanf("%d%d", &a[i], &d[i]);
    
    // 左边界要从 0 开始,因为第 m 个数的值可能为 0
    int l = 0, r = 1e6;
    while (l < r) {
        int mid = l + r + 1 >> 1;
        if (check(mid)) l = mid;
        else r = mid - 1;
    }
    
    LL res = 0, cnt = 0;
    for (int i = 0; i < n; i++) {
        if (a[i] > r) {
            int c = (a[i] - r) / d[i] + 1;
            int end = a[i] - d[i] * (c - 1);
            res += ((LL)a[i] + end) * c / 2;
            cnt += c;
        }
    }
    
    printf("%lld", res + (m - cnt) * r);
    
    return 0;
}