优先队列与TopK

一、简介

  前文介绍了《最大堆》的实现,本章节在最大堆的基础上实现一个简单的优先队列。优先队列的实现本身没什么难度,所以本文我们从优先队列的场景出发介绍topK问题。

 

  后面会持续更新数据结构相关的博文。

  数据结构专栏:https://www.cnblogs.com/hello-shf/category/1519192.html

  git传送门:https://github.com/hello-shf/data-structure.git

二、优先队列

  普通的队列是一种先进先出的数据结构,元素在队列尾追加,而从队列头删除。在优先队列中,元素被赋予优先级。当访问元素时,具有最高优先级的元素最先删除。优先队列具有最高级先出 (first in, largest out)的行为特征。通常采用堆数据结构来实现。

  上面是百度百科给出的优先队列的解释。解释的还是很到位的。具体的优先队列的实现可以采用最小堆或者最大堆。因为在我们前文《最大堆》的实现中,该堆存储的元素是要求实现Comparable接口的。所以优先级是掌握在用户手中的,所以最小堆和最大堆都可以作为优先队列的底层数据结构。

  普通的队列Queue,我们都知道是先进先出(FIFO)的,所以元素的出队顺序和入队顺序是保持一致的。但是对于我们的优先队列,出队操作,将不保证先进先出的队列特性,而是根据元素的优先级(或者说权重)决定出队的顺序。

  如果最小元素拥有最高的优先级,那么这种优先队列叫作升序优先队列,即总是优先删除最小的元素。同理,如果最大元素拥有最高的优先级,那么这种优先队列叫作降序优先队列,即总是先删除最大的元素。

  优先队列的使用场景:

 算法场景:
     最短路径算法:Dijkstra算法
     最小生成树算法:Prim算法
     事件驱动仿真:顾客排队算法
     选择问题:查找第k个最小元素
 实现场景:
     游戏中优先攻击最近单位,优先攻击血量最低等
     售票窗口的老幼病残孕和军人优先购票等

 

三、优先队列的实现

 

3.1、队列接口定义

  同普通的队列,我们先定义队列的接口如下

 1 /**
 2  * 描述:队列
 3  *
 4  * @Author shf
 5  * @Date 2019/7/18 15:30
 6  * @Version V1.0
 7  **/
 8 public interface Queue<E> {
 9     /**
10      * 获取当前队列的元素数
11      * @return
12      */
13     int getSize();
14 
15     /**
16      * 判断当前队列是否为空
17      * @return
18      */
19     boolean isEmpty();
20 
21     /**
22      * 入队操作
23      * @param e
24      */
25     void enqueue(E e);
26 
27     /**
28      * 出队操作
29      * @return
30      */
31     E dequeue();
32 
33     /**
34      * 获取队列头元素
35      * @return
36      */
37     E getFront();
38 }

  

3.2、最大堆实现的优先队列

  我们使用前文实现的《最大堆》来实现一个优先队列

 1 /**
 2  * 描述:优先队列
 3  *
 4  * @Author shf
 5  * @Date 2019/7/18 17:31
 6  * @Version V1.0
 7  **/
 8 public class PriorityQueue<E extends Comparable<E>> implements Queue<E> {
 9 
10     private MaxHeap<E> maxHeap;
11 
12     public PriorityQueue(){
13         maxHeap = new MaxHeap<>();
14     }
15 
16     @Override
17     public int getSize(){
18         return maxHeap.size();
19     }
20 
21     @Override
22     public boolean isEmpty(){
23         return maxHeap.isEmpty();
24     }
25 
26     @Override
27     public E getFront(){
28         // 获取队列的头元素,在最大堆中就是获取堆顶元素
29         return maxHeap.findMax();
30     }
31 
32     @Override
33     public void enqueue(E e){
34         // 压栈 直接向最大堆中添加,让最大堆的add方法维护 元素的优先级
35         maxHeap.add(e);
36     }
37 
38     @Override
39     public E dequeue(){
40         // 出栈 将最大堆的堆顶元素取出
41         return maxHeap.extractMax();
42     }
43 }

  需要解释的都在代码注释中了。

  到这里优先队列就实现完了,是不是很简单。

  在java中也有一个类PriorityQueue,其底层是采用的最小堆实现的优先队列。在java PriorityQueue中关于优先级的定义,优先级队列的元素按照其自然顺序进行排序,或者根据构造队列时提供的 Comparator 进行排序,具体取决于所使用的构造方法。底层数据结构最大堆或者最小堆是没有什么区别的。关键在于我们如何定义优先级。

 

四、topK问题

  关于topK问题,leetcode上面有一道典型的题目

  题目最终需要返回的是前 k 个频率最大的元素,可以想到借助堆这种数据结构,对于 k 频率之后的元素不用再去处理,进一步优化时间复杂度。

  具体操作为:

借助 哈希表 来建立数字和其出现次数的映射,遍历一遍数组统计元素的频率
维护一个元素数目为 k 的最小堆
每次都将新的元素与堆顶元素(堆中频率最小的元素)进行比较
如果新的元素的频率比堆顶端的元素大,则弹出堆顶端的元素,将新的元素添加进堆中
最终,堆中的 k 个元素即为前 k 个高频元素

 

   具体实现

 1 class Solution {
 2     public List<Integer> topKFrequent(int[] nums, int k) {
 3         // 使用字典,统计每个元素出现的次数,元素为键,元素出现的次数为值
 4         HashMap<Integer,Integer> map = new HashMap();
 5         for(int num : nums){
 6             if (map.containsKey(num)) {
 7                map.put(num, map.get(num) + 1);
 8              } else {
 9                 map.put(num, 1);
10              }
11         }
12         // 遍历map,用最小堆保存频率最大的k个元素
13         PriorityQueue<Integer> pq = new PriorityQueue<>(new Comparator<Integer>() {
14             @Override
15             public int compare(Integer a, Integer b) {
16                 return map.get(a) - map.get(b);
17             }
18         });
19         for (Integer key : map.keySet()) {
20             if (pq.size() < k) {
21                 pq.add(key);
22             } else if (map.get(key) > map.get(pq.peek())) {
23                 pq.remove();
24                 pq.add(key);
25             }
26         }
27         // 取出最小堆中的元素
28         List<Integer> res = new ArrayList<>();
29         while (!pq.isEmpty()) {
30             res.add(pq.remove());
31         }
32         return res;
33     }
34 }

 

   以上是使用java原生的优先队列实现的。接下来我们用我们自己实现的PriorityQueue试验一下。

   首先因为我们没有提供接收一个Comparator的构造器,所以我们通过定义一个类来完成这个过程比较。

  因为自己定义的优先队列底层使用的是我们自己实现的最大堆,以及最大堆底层数组也是使用自己定义的,所以我们在leetcode提交验证的时候,需要将这些自定义的类以内部类的方式提交上去。整体代码如下

 

  1 /// 347. Top K Frequent Elements
  2 /// https://leetcode.com/problems/top-k-frequent-elements/description/
  3 
  4 import java.util.LinkedList;
  5 import java.util.List;
  6 import java.util.TreeMap;
  7 
  8 class Solution {
  9 
 10     private class Array<E> {
 11 
 12         private E[] data;
 13         private int size;
 14 
 15         // 构造函数,传入数组的容量capacity构造Array
 16         public Array(int capacity){
 17             data = (E[])new Object[capacity];
 18             size = 0;
 19         }
 20 
 21         // 无参数的构造函数,默认数组的容量capacity=10
 22         public Array(){
 23             this(10);
 24         }
 25 
 26         public Array(E[] arr){
 27             data = (E[])new Object[arr.length];
 28             for(int i = 0 ; i < arr.length ; i ++)
 29                 data[i] = arr[i];
 30             size = arr.length;
 31         }
 32 
 33         // 获取数组的容量
 34         public int getCapacity(){
 35             return data.length;
 36         }
 37 
 38         // 获取数组中的元素个数
 39         public int getSize(){
 40             return size;
 41         }
 42 
 43         // 返回数组是否为空
 44         public boolean isEmpty(){
 45             return size == 0;
 46         }
 47 
 48         // 在index索引的位置插入一个新元素e
 49         public void add(int index, E e){
 50 
 51             if(index < 0 || index > size)
 52                 throw new IllegalArgumentException("Add failed. Require index >= 0 and index <= size.");
 53 
 54             if(size == data.length)
 55                 resize(2 * data.length);
 56 
 57             for(int i = size - 1; i >= index ; i --)
 58                 data[i + 1] = data[i];
 59 
 60             data[index] = e;
 61 
 62             size ++;
 63         }
 64 
 65         // 向所有元素后添加一个新元素
 66         public void addLast(E e){
 67             add(size, e);
 68         }
 69 
 70         // 在所有元素前添加一个新元素
 71         public void addFirst(E e){
 72             add(0, e);
 73         }
 74 
 75         // 获取index索引位置的元素
 76         public E get(int index){
 77             if(index < 0 || index >= size)
 78                 throw new IllegalArgumentException("Get failed. Index is illegal.");
 79             return data[index];
 80         }
 81 
 82         // 修改index索引位置的元素为e
 83         public void set(int index, E e){
 84             if(index < 0 || index >= size)
 85                 throw new IllegalArgumentException("Set failed. Index is illegal.");
 86             data[index] = e;
 87         }
 88 
 89         // 查找数组中是否有元素e
 90         public boolean contains(E e){
 91             for(int i = 0 ; i < size ; i ++){
 92                 if(data[i].equals(e))
 93                     return true;
 94             }
 95             return false;
 96         }
 97 
 98         // 查找数组中元素e所在的索引,如果不存在元素e,则返回-1
 99         public int find(E e){
100             for(int i = 0 ; i < size ; i ++){
101                 if(data[i].equals(e))
102                     return i;
103             }
104             return -1;
105         }
106 
107         // 从数组中删除index位置的元素, 返回删除的元素
108         public E remove(int index){
109             if(index < 0 || index >= size)
110                 throw new IllegalArgumentException("Remove failed. Index is illegal.");
111 
112             E ret = data[index];
113             for(int i = index + 1 ; i < size ; i ++)
114                 data[i - 1] = data[i];
115             size --;
116             data[size] = null; // loitering objects != memory leak
117 
118             if(size == data.length / 4 && data.length / 2 != 0)
119                 resize(data.length / 2);
120             return ret;
121         }
122 
123         // 从数组中删除第一个元素, 返回删除的元素
124         public E removeFirst(){
125             return remove(0);
126         }
127 
128         // 从数组中删除最后一个元素, 返回删除的元素
129         public E removeLast(){
130             return remove(size - 1);
131         }
132 
133         // 从数组中删除元素e
134         public void removeElement(E e){
135             int index = find(e);
136             if(index != -1)
137                 remove(index);
138         }
139 
140         public void swap(int i, int j){
141 
142             if(i < 0 || i >= size || j < 0 || j >= size)
143                 throw new IllegalArgumentException("Index is illegal.");
144 
145             E t = data[i];
146             data[i] = data[j];
147             data[j] = t;
148         }
149 
150         @Override
151         public String toString(){
152 
153             StringBuilder res = new StringBuilder();
154             res.append(String.format("Array: size = %d , capacity = %d\n", size, data.length));
155             res.append('[');
156             for(int i = 0 ; i < size ; i ++){
157                 res.append(data[i]);
158                 if(i != size - 1)
159                     res.append(", ");
160             }
161             res.append(']');
162             return res.toString();
163         }
164 
165         // 将数组空间的容量变成newCapacity大小
166         private void resize(int newCapacity){
167 
168             E[] newData = (E[])new Object[newCapacity];
169             for(int i = 0 ; i < size ; i ++)
170                 newData[i] = data[i];
171             data = newData;
172         }
173     }
174 
175     private class MaxHeap<E extends Comparable<E>> {
176 
177         private Array<E> data;
178 
179         public MaxHeap(int capacity){
180             data = new Array<>(capacity);
181         }
182 
183         public MaxHeap(){
184             data = new Array<>();
185         }
186 
187         public MaxHeap(E[] arr){
188             data = new Array<>(arr);
189             for(int i = parent(arr.length - 1) ; i >= 0 ; i --)
190                 siftDown(i);
191         }
192 
193         // 返回堆中的元素个数
194         public int size(){
195             return data.getSize();
196         }
197 
198         // 返回一个布尔值, 表示堆中是否为空
199         public boolean isEmpty(){
200             return data.isEmpty();
201         }
202 
203         // 返回完全二叉树的数组表示中,一个索引所表示的元素的父亲节点的索引
204         private int parent(int index){
205             if(index == 0)
206                 throw new IllegalArgumentException("index-0 doesn't have parent.");
207             return (index - 1) / 2;
208         }
209 
210         // 返回完全二叉树的数组表示中,一个索引所表示的元素的左孩子节点的索引
211         private int leftChild(int index){
212             return index * 2 + 1;
213         }
214 
215         // 返回完全二叉树的数组表示中,一个索引所表示的元素的右孩子节点的索引
216         private int rightChild(int index){
217             return index * 2 + 2;
218         }
219 
220         // 向堆中添加元素
221         public void add(E e){
222             data.addLast(e);
223             siftUp(data.getSize() - 1);
224         }
225 
226         private void siftUp(int k){
227 
228             while(k > 0 && data.get(parent(k)).compareTo(data.get(k)) < 0 ){
229                 data.swap(k, parent(k));
230                 k = parent(k);
231             }
232         }
233 
234         // 看堆中的最大元素
235         public E findMax(){
236             if(data.getSize() == 0)
237                 throw new IllegalArgumentException("Can not findMax when heap is empty.");
238             return data.get(0);
239         }
240 
241         // 取出堆中最大元素
242         public E extractMax(){
243 
244             E ret = findMax();
245 
246             data.swap(0, data.getSize() - 1);
247             data.removeLast();
248             siftDown(0);
249 
250             return ret;
251         }
252 
253         private void siftDown(int k){
254 
255             while(leftChild(k) < data.getSize()){
256                 int j = leftChild(k); // 在此轮循环中,data[k]和data[j]交换位置
257                 if( j + 1 < data.getSize() &&
258                         data.get(j + 1).compareTo(data.get(j)) > 0 )
259                     j ++;
260                 // data[j] 是 leftChild 和 rightChild 中的最大值
261 
262                 if(data.get(k).compareTo(data.get(j)) >= 0 )
263                     break;
264 
265                 data.swap(k, j);
266                 k = j;
267             }
268         }
269 
270         // 取出堆中的最大元素,并且替换成元素e
271         public E replace(E e){
272 
273             E ret = findMax();
274             data.set(0, e);
275             siftDown(0);
276             return ret;
277         }
278     }
279 
280     private interface Queue<E> {
281 
282         int getSize();
283         boolean isEmpty();
284         void enqueue(E e);
285         E dequeue();
286         E getFront();
287     }
288 
289     private class PriorityQueue<E extends Comparable<E>> implements Queue<E> {
290 
291         private MaxHeap<E> maxHeap;
292 
293         public PriorityQueue(){
294             maxHeap = new MaxHeap<>();
295         }
296 
297         @Override
298         public int getSize(){
299             return maxHeap.size();
300         }
301 
302         @Override
303         public boolean isEmpty(){
304             return maxHeap.isEmpty();
305         }
306 
307         @Override
308         public E getFront(){
309             return maxHeap.findMax();
310         }
311 
312         @Override
313         public void enqueue(E e){
314             maxHeap.add(e);
315         }
316 
317         @Override
318         public E dequeue(){
319             return maxHeap.extractMax();
320         }
321     }
322 
323     private class Freq implements Comparable<Freq>{
324 
325         public int e, freq;
326 
327         public Freq(int e, int freq){
328             this.e = e;
329             this.freq = freq;
330         }
331 
332         @Override
333         public int compareTo(Freq another){
334             if(this.freq < another.freq)
335                 return 1;
336             else if(this.freq > another.freq)
337                 return -1;
338             else
339                 return 0;
340         }
341     }
342 
343     public List<Integer> topKFrequent(int[] nums, int k) {
344 
345         TreeMap<Integer, Integer> map = new TreeMap<>();
346         for(int num: nums){
347             if(map.containsKey(num))
348                 map.put(num, map.get(num) + 1);
349             else
350                 map.put(num, 1);
351         }
352 
353         PriorityQueue<Freq> pq = new PriorityQueue<>();
354         for(int key: map.keySet()){
355             if(pq.getSize() < k)
356                 pq.enqueue(new Freq(key, map.get(key)));
357             else if(map.get(key) > pq.getFront().freq){
358                 pq.dequeue();
359                 pq.enqueue(new Freq(key, map.get(key)));
360             }
361         }
362 
363         LinkedList<Integer> res = new LinkedList<>();
364         while(!pq.isEmpty())
365             res.add(pq.dequeue().e);
366         return res;
367     }
368 
369     private static void printList(List<Integer> nums){
370         for(Integer num: nums)
371             System.out.print(num + " ");
372         System.out.println();
373     }
374 
375     public static void main(String[] args) {
376 
377         int[] nums = {1, 1, 1, 2, 2, 3};
378         int k = 2;
379         printList((new Solution()).topKFrequent(nums, k));
380     }
381 }
View Code

 

   在以上代码中我们需要关心的是如下部分

 1   private class Freq implements Comparable<Freq>{
 2 
 3         public int e, freq;
 4 
 5         public Freq(int e, int freq){
 6             this.e = e;
 7             this.freq = freq;
 8         }
 9 
10         @Override
11         public int compareTo(Freq another){
12             if(this.freq < another.freq)
13                 return 1;
14             else if(this.freq > another.freq)
15                 return -1;
16             else
17                 return 0;
18         }
19     }
20 
21     public List<Integer> topKFrequent(int[] nums, int k) {
22 
23         TreeMap<Integer, Integer> map = new TreeMap<>();
24         for(int num: nums){
25             if(map.containsKey(num))
26                 map.put(num, map.get(num) + 1);
27             else
28                 map.put(num, 1);
29         }
30 
31         PriorityQueue<Freq> pq = new PriorityQueue<>();
32         for(int key: map.keySet()){
33             if(pq.getSize() < k)
34                 pq.enqueue(new Freq(key, map.get(key)));
35             else if(map.get(key) > pq.getFront().freq){
36                 pq.dequeue();
37                 pq.enqueue(new Freq(key, map.get(key)));
38             }
39         }
40 
41         LinkedList<Integer> res = new LinkedList<>();
42         while(!pq.isEmpty())
43             res.add(pq.dequeue().e);
44         return res;
45     }

 

   我们将完整代码提交到leetcode

  得到如下结果,表示我们验证自己实现的优先队列成功了。

 

 

 

 

  这盛世,如您所愿。

  如有错误的地方还请留言指正。

  原创不易,转载请注明原文地址:https://www.cnblogs.com/hello-shf/p/11397386.html 

 

posted @ 2019-09-05 12:00  超级小小黑  阅读(1933)  评论(2编辑  收藏  举报