堆(优先队列)

堆是一种树形结构,树的根是堆顶,堆顶始终保持为所有元素的最优值。有大根堆和小根堆,大根堆的根节点是最大值,小根堆的根节点是最小值。堆一般用二叉树实现,称为二叉堆。

堆的存储方式

image

堆的操作

empty
返回堆是否为空

top
直接返回根节点的值,时间复杂度 \(O(1)\)

push
将新元素添加在数组最后面,若它比父节点小则不断与其父节点交换,使得堆重新满足父节点比子节点存储的数都要小(自下而上),时间复杂度 \(O(\log n)\)
image

pop
弹出根节点,并让堆依然符合原来的性质。首先交换根节点和数组中最后一个元素,再去掉最后一个元素。若新根节点比子节点大,则不断与较小子节点交换,直到重新满足条件(自上而下),时间复杂度 \(O(\log n)\)
image

例:P3378 【模板】堆

由此,给出二叉堆的模板实现:

参考代码
#include <cstdio>
#include <algorithm>
using namespace std;
const int MAXN = 1e6 + 5;
int heap[MAXN], len;
void push(int x) {
    heap[++len] = x;
    int i = len;
    while (i > 1 && heap[i] < heap[i / 2]) {
        swap(heap[i], heap[i / 2]);
        i /= 2;
    }
}
void pop() {
    heap[1] = heap[len--];
    int i = 1;
    while (i * 2 <= len) {
        int son = i * 2;
        if (son < len && heap[son + 1] < heap[son]) son++;
        if (heap[son] < heap[i]) {
            swap(heap[son], heap[i]);
            i = son;
        } else break;
    }
}
int main()
{
    int n;
    scanf("%d", &n);
    while (n--) {
        int op;
        scanf("%d", &op);
        if (op == 1) {
            int x;
            scanf("%d", &x);
            push(x);
        } else if (op == 2) printf("%d\n", heap[1]);
        else pop();
    }
    return 0;
}

例:P1177 【模板】排序

输入 \(n (n < 10^5)\) 个数字 \(a_i \ (a_i < 10^9)\),将其从小到大排序后输出。

分析:利用堆也是可以做排序的,先把所有的元素 push 进去,然后每次取出堆顶(最小值)输出并弹出堆顶,直到堆空为止,这种排序方法称为堆排序

参考代码
#include <cstdio>
#include <algorithm>
using namespace std;
const int MAXN=100005;
struct Heap {
    int a[MAXN],cnt;
    void push(int x) { // 压入
        a[++cnt]=x;
        int i=cnt;
        while (i>1 && a[i]<a[i/2]) {
            swap(a[i/2],a[i]);
            i/=2;
        }
    }
    void pop() { // 删除
        a[1]=a[cnt--];
        int i=1;
        while (i*2<=cnt) {
            int son=i*2;
            if (son<cnt && a[son+1]<a[son]) son++;
            if (a[son]<a[i]) {
                swap(a[son],a[i]);
                i=son;
            } else break;
        }
    }
    int top() {
        return a[1];
    }
};
Heap h;
int main()
{
    int n,x;
    scanf("%d",&n);
    for (int i=1;i<=n;i++) {
        scanf("%d",&x);
        h.push(x);
    }
    for (int i=1;i<=n;i++) {
        printf("%d ",h.top());
        h.pop();
    }
    return 0;
}

堆排序整体的时间复杂度是 \(O(n \log n)\),空间复杂度为 \(O(n)\)

优先队列

C++ 提供了优先队列这个数据结构,也就是 STL 中的 priority_queue,底层就是由堆实现的。要使用优先队列,需要包含 queue 头文件,优先队列支持的基础操作如下:

  1. priority_queue<int> q 新建一个保存 int 型变量的优先队列 q,默认是大根堆
  2. priority_queue<int, vector<int>, greater<int>> q 新建一个小根堆
  3. q.top() 优先队列查询最大值(或者是最小值)
  4. q.pop() 将最大值(最小值)弹出队列
  5. q.push(x)x 加入优先队列

和大多数 STL 容器一样,可以使用 q.empty() 判断它是否为空,用 q.size() 获取它的大小。

例:P3378 【模板】堆

用 STL 的优先队列来写这道题代码更加简洁。

// STL 优先队列
#include <cstdio>
#include <algorithm>
#include <queue>
using namespace std;
priority_queue<int, vector<int>, greater<int>> q; // 小根堆
int main()
{
    int n; scanf("%d", &n); // 操作次数
    while (n--) {
        int op, x; scanf("%d", &op);
        if (op == 1) { scanf("%d", &x); q.push(x); }
        else if (op == 2) printf("%d\n", q.top());
        else q.pop();
    }
    return 0;
}

例:P2168 [NOI2015] 荷马史诗

一部《荷马史诗》中有 \(n(n \le 10^6)\) 种不同的单词,从 \(1\)\(n\) 进行编号。其中第 \(i\) 种单词出现的总次数为 \(w_i(w_i \le 10^11)\)。现在要用 \(k\) 进制串 \(s_i\) 来替换第 \(i\) 种单词,使得其满足对于任意的 \(1 \le i,j \le n, i \ne j\),都有 \(s_i\) 不是 \(s_j\) 的前缀。请问如何选择 \(s_i\),才能使替换以后得到的新的《荷马史诗》长度最小。在确保总长度最小的情况下,还想知道最长的 \(s_i\) 的最短长度是多少?

解题思路

哈夫曼编码的变形。每次从堆中选出权重最小的 \(k\) 个结点,将其合并建边,然后放回堆中,直到建完哈夫曼树。例如,当各结点权重分别为 1、1、3、3、9、9,需要编码为三进制时,生成的哈夫曼树如下:

image

需要注意的是,每次合并都会减少 \(k-1\) 个结点,在合并最后一次的时候,如果可以合并的点的数量不足 \(k\) 个,靠近根结点的位置(短编码)反而没有被利用,所以需要在一开始补上 k-1-(n-1)%(k-1) 个权重为 \(0\) 的结点,把权重大的结点“推”到离根结点更近的位置。根据题目数据范围,答案需要 long long 类型。

参考代码
#include <cstdio>
#include <queue>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 100005;
LL w[N];
struct Node {
    LL val;
    int depth;
};
struct NodeCompare { // 定义Node比较类
    bool operator()(const Node &a, const Node &b) {
        // 权重相同时,高度小的优先出队
        return a.val != b.val ? a.val > b.val : a.depth > b.depth;
    }
};
int main()
{
    int n, k;
    scanf("%d%d", &n, &k);
    priority_queue<Node, vector<Node>, NodeCompare> q;
    for (int i = 1; i <= n; i++) {
        scanf("%lld", &w[i]);
        q.push({w[i], 1}); // 读入结点(叶节点)
    }
    if ((n - 1) % (k - 1) != 0) { // 有一次合并结点数量不足k个
        for (int i = 1; i <= k - 1 - (n - 1) % (k - 1); i++) 
            q.push({0, 1}); // 需要补若干个权重为0的结点 
    }
    LL ans = 0;
    while (q.size() != 1) {
        LL sum = 0; int maxh = 0;
        for (int i = 1; i <= k; i++) { // 从堆中取k个最小的
            Node tmp = q.top(); q.pop();
            sum += tmp.val; // 新结点加上子结点权重
            maxh = max(maxh, tmp.depth); // 最大深度
        }
        ans += sum; // 更新总长度
        q.push({sum, maxh + 1}); // 合并后的结点放回堆中
    }
    printf("%lld\n%lld\n", ans, q.top().depth - 1); // 编码长度是哈夫曼树的高度减1
    return 0;
}

例:P2085 最小函数值

题目给定了若干个二次函数,由于 \(x\) 取的都是正整数,并且三个系数都为正整数,因此函数的取值单调递增且肯定大于 \(0\),要求这些函数生成的所有函数值中最小的 \(m\) 个。

朴素想法

暴力计算每个函数值

朴素的想法是对于每个函数都计算前 \(m\) 个取值,这样会得到 \(n \times m\) 个函数值,最小的 \(m\) 个函数值一定在这个范围内,用一个最大容量限定为 \(m\) 的小根堆始终维护最小的 \(m\) 个函数值,时间复杂度 \(O(nm \log m)\)

优化思路

注意函数的取值是单调递增的,因此实际上可以看作是给定 \(n\) 个排好序的数组,只不过数组并没有真正地存下来,而是给出了下标和值的对应关系。对于每个数组,它们的最小值所在的下标都是 \(1\),假设每个数组都有一个箭头指向 \(1\),需要在所有箭头指向的函数值中找到最小的那个,接下来最小的那个所处的数组的箭头向后移动,指向 \(2\),然后再和其他箭头关联的函数值比较,以此类推。这样一来箭头的后移只需要执行 \(m\) 次即可,而找最小函数值这个过程可以利用一个小根堆来提高效率,总体时间复杂度 \(O(m \log n)\)

参考代码
#include <cstdio>
#include <queue>
#include <vector>
using namespace std;
const int N = 10005;
int a[N], b[N], c[N];
struct Node {
    int idx, x, f;
};
struct NodeCompare {
    bool operator()(const Node &lhs, const Node &rhs) const {
        return lhs.f > rhs.f;
    }
};
priority_queue<Node, vector<Node>, NodeCompare> q;
int fn(int idx, int x) {
    return a[idx] * x * x + b[idx] * x + c[idx];
}
int main()
{
    int n, m;
    scanf("%d%d", &n, &m);
    for (int i = 0; i < n; i++) {
        scanf("%d%d%d", &a[i], &b[i], &c[i]);
    }
    for (int i = 0; i < n; i++) q.push({i, 1, fn(i, 1)});
    for (int i = 0; i < m; i++) {
        Node t = q.top();
        q.pop();
        printf("%d ", t.f);
        q.push({t.idx, t.x + 1, fn(t.idx, t.x + 1)});
    }
    return 0;
}
#include <cstdio>
#include <queue>
#include <vector>
using namespace std;
const int N = 10005;
int a[N], b[N], c[N];
struct Node {
    int idx, x, f;
    bool operator<(const Node &other) const {
    	return f > other.f;
	}
};

priority_queue<Node> q;
int fn(int idx, int x) {
    return a[idx] * x * x + b[idx] * x + c[idx];
}
int main()
{
    int n, m;
    scanf("%d%d", &n, &m);
    for (int i = 0; i < n; i++) {
        scanf("%d%d%d", &a[i], &b[i], &c[i]);
    }
    for (int i = 0; i < n; i++) q.push({i, 1, fn(i, 1)});
    for (int i = 0; i < m; i++) {
        Node t = q.top();
        q.pop();
        printf("%d ", t.f);
        q.push({t.idx, t.x + 1, fn(t.idx, t.x + 1)});
    }
    return 0;
}

例:P1631 序列合并

解题思路

可以发现,最小和一定是 \(A[1]+B[1]\),次小和是 \(\min (A[1]+B[2],A[2]+B[1])\),假设次小和是 \(A[2]+B[1]\),那么第三小和就是 \(A[1]+B[2],A[2]+B[2],A[3]+B[1]\) 三者之一。也就是说,当确定 \(A[i]+B[j]\) 为第 \(k\) 小和后,\(A[i+1]+B[j]\)\(A[i]+B[j+1]\) 就加入了第 \(k+1\) 小和的备选答案集合。需要注意的是,\(A[1]+B[2]\)\(A[2]+B[1]\) 都能产生 \(A[2]+B[2]\) 这个备选答案。

考虑到这一点,我们不妨把 \(A\)\(B\) 两个序列的和看成 \(N\) 个有序数组,其中第一个数组为 \(A[1]+B[...]\),第二个数组为 \(A[2]+B[...]\),以此类推。这样一来,就相当于将这 \(N\) 个有序数组合并取出前 \(N\) 小的。因此可以先将 \(A[1]+B[1], A[2]+B[1], ..., A[N]+B[1]\)\(N\) 种情况先加入堆中,若取出的堆顶元素来自于第 \(K\) 个数组,则将 \(A[K]+B[2]\) 这种情况继续放入堆中,直到取够前 \(N\) 种情况。时间复杂度 \(O(N \log N)\)

参考代码
#include <cstdio>
#include <algorithm>
#include <queue>
using namespace std;
typedef long long LL;
const int N = 100005;
int a[N], b[N], ans[N];
struct Index {
    int x, y;
};
struct IndexCompare {
    bool operator()(const Index& idx1, const Index& idx2) const {
        return a[idx1.x] + b[idx1.y] > a[idx2.x] + b[idx2.y];
    }
};
int main()
{
    int n; scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    for (int i = 1; i <= n; i++) scanf("%d", &b[i]);
    priority_queue<Index, vector<Index>, IndexCompare> q;
    for (int i = 1; i <= n; i++) q.push({i, 1});
    for (int i = 1; i <= n; i++) {
        Index tmp = q.top(); q.pop();
        ans[i] = a[tmp.x] + b[tmp.y];
        q.push({tmp.x, tmp.y + 1});
    }
    for (int i = 1; i <= n; i++) printf("%d%c", ans[i], i == n ? '\n' : ' ');
    return 0;
}

对顶堆

如果把大根堆想成一个上宽下窄的三角形,把小根堆想成一个上窄下宽的三角形,那么对顶堆就可以具体地被想象成一个“陀螺”或者一个“沙漏”,通过这两个堆的上下组合,我们可以把一组数据分别加入到对顶堆中的大根堆和小根堆,以维护我们不同的需要。

根据数学中不等式的传递原理,假如一个集合 A 中的最小元素比另一个集合 B 中的最大元素还要大,那么就可以断定: A 中的所有元素都比 B 中元素大。所以,我们把小根堆“放在”大根堆“上面”,如果小根堆的堆顶元素比大根堆的堆顶元素大,那么小根堆的所有元素要比大根堆的所有元素大。

例如给定 \(N\) 个数字,求其前 \(i\) 个元素中第 \(K\) 小的那个元素(\(K\) 值可变)。

我们可以这样解决问题:把大根堆的元素个数限制成 \(K\) 个,由大根堆维护前 \(K\) 小的元素(包含第 \(K\) 个),小根堆维护比第 \(K\) 小的元素还要大的元素。

  1. 插入:若插入的元素小于大根堆堆顶元素,则将其插入大根堆,否则将其插入小根堆
  2. 维护:当大根堆的大小大于 \(K\) 时,不断将大根堆堆顶元素取出并插入小根堆,直到大根堆的大小等于 \(K\);当大根堆的大小小于 \(K\) 时,不断将小根堆堆顶元素取出并插入大根堆,直到大根堆的大小等于 \(K\)
  3. 查询第 \(K\) 小的元素:大根堆堆顶元素
  4. 删除第 \(K\) 小的元素:删除大根堆堆顶元素

同理,对顶堆还可以用于解决其他“第 \(K\) 小”的变形问题:比如求前 \(i\) 个元素的中位数等。

例:P1168 中位数

解题思路

使用两个堆,大根堆维护较小的数,小根堆维护较大的数。这样一来,小根堆的堆顶是较大的数中最小的,大根堆的堆顶是较小的数中最大的。

而求中位数只需要在保证两个堆中元素大小关系的同时,控制两个堆的大小尽可能平衡,这样其中一个堆的堆顶元素即为中位数。

参考代码
#include <cstdio>
#include <queue>
#include <vector>
using namespace std;
int main()
{
    int n;
    scanf("%d", &n);
    priority_queue<int> big;
    priority_queue<int, vector<int>, greater<int>> small;
    for (int i = 1; i <= n; i++) {
        int x;
        scanf("%d", &x);
        small.push(x);
        if (i % 2 == 1) {
            while (!big.empty() && small.top() < big.top()) {
                int st = small.top();
                small.pop();
                int bt = big.top();
                big.pop();
                small.push(bt);
                big.push(st);
            }
            int st = small.top();
            small.pop();
            big.push(st);
            printf("%d\n", big.top());
        }
    }
    return 0;
}

例:P1801 黑匣子

解题思路

控制对顶堆中的大根堆的元素数目伴随着 \(i\) 的增长而增长。

参考代码
#include <cstdio>
#include <queue>
#include <iostream>
using namespace std;
const int N = 200005;
int a[N];
priority_queue<int, vector<int>, greater<int>> h;
priority_queue<int> ans;
int main()
{
    int m, n;
    scanf("%d%d", &m, &n);
    for (int i = 1; i <= m; i++) scanf("%d", &a[i]);
    int pre = 0;
    int idx = 0;
    while (n--) {
        int u;
        scanf("%d", &u);
        for (int i = pre + 1; i <= u; i++) 
            ans.push(a[i]);
        while (ans.size() > idx) {
            h.push(ans.top());
            ans.pop();
        }
        ans.push(h.top());
        h.pop();
        printf("%d\n", ans.top());
        pre = u;
        idx++;
    }
    return 0;
}
posted @ 2024-02-03 19:58  RonChen  阅读(71)  评论(0编辑  收藏  举报