优先队列与TopK
一、简介
前文介绍了《最大堆》的实现,本章节在最大堆的基础上实现一个简单的优先队列。优先队列的实现本身没什么难度,所以本文我们从优先队列的场景出发介绍topK问题。
后面会持续更新数据结构相关的博文。
数据结构专栏:https://www.cnblogs.com/hello-shf/category/1519192.html
git传送门:https://github.com/hello-shf/data-structure.git
二、优先队列
普通的队列是一种先进先出的数据结构,元素在队列尾追加,而从队列头删除。在优先队列中,元素被赋予优先级。当访问元素时,具有最高优先级的元素最先删除。优先队列具有最高级先出 (first in, largest out)的行为特征。通常采用堆数据结构来实现。
上面是百度百科给出的优先队列的解释。解释的还是很到位的。具体的优先队列的实现可以采用最小堆或者最大堆。因为在我们前文《最大堆》的实现中,该堆存储的元素是要求实现Comparable接口的。所以优先级是掌握在用户手中的,所以最小堆和最大堆都可以作为优先队列的底层数据结构。
普通的队列Queue,我们都知道是先进先出(FIFO)的,所以元素的出队顺序和入队顺序是保持一致的。但是对于我们的优先队列,出队操作,将不保证先进先出的队列特性,而是根据元素的优先级(或者说权重)决定出队的顺序。
如果最小元素拥有最高的优先级,那么这种优先队列叫作升序优先队列,即总是优先删除最小的元素。同理,如果最大元素拥有最高的优先级,那么这种优先队列叫作降序优先队列,即总是先删除最大的元素。
优先队列的使用场景:
算法场景:
最短路径算法:Dijkstra算法
最小生成树算法:Prim算法
事件驱动仿真:顾客排队算法
选择问题:查找第k个最小元素
实现场景:
游戏中优先攻击最近单位,优先攻击血量最低等
售票窗口的老幼病残孕和军人优先购票等
三、优先队列的实现
3.1、队列接口定义
同普通的队列,我们先定义队列的接口如下
1 /** 2 * 描述:队列 3 * 4 * @Author shf 5 * @Date 2019/7/18 15:30 6 * @Version V1.0 7 **/ 8 public interface Queue<E> { 9 /** 10 * 获取当前队列的元素数 11 * @return 12 */ 13 int getSize(); 14 15 /** 16 * 判断当前队列是否为空 17 * @return 18 */ 19 boolean isEmpty(); 20 21 /** 22 * 入队操作 23 * @param e 24 */ 25 void enqueue(E e); 26 27 /** 28 * 出队操作 29 * @return 30 */ 31 E dequeue(); 32 33 /** 34 * 获取队列头元素 35 * @return 36 */ 37 E getFront(); 38 }
3.2、最大堆实现的优先队列
我们使用前文实现的《最大堆》来实现一个优先队列
1 /** 2 * 描述:优先队列 3 * 4 * @Author shf 5 * @Date 2019/7/18 17:31 6 * @Version V1.0 7 **/ 8 public class PriorityQueue<E extends Comparable<E>> implements Queue<E> { 9 10 private MaxHeap<E> maxHeap; 11 12 public PriorityQueue(){ 13 maxHeap = new MaxHeap<>(); 14 } 15 16 @Override 17 public int getSize(){ 18 return maxHeap.size(); 19 } 20 21 @Override 22 public boolean isEmpty(){ 23 return maxHeap.isEmpty(); 24 } 25 26 @Override 27 public E getFront(){ 28 // 获取队列的头元素,在最大堆中就是获取堆顶元素 29 return maxHeap.findMax(); 30 } 31 32 @Override 33 public void enqueue(E e){ 34 // 压栈 直接向最大堆中添加,让最大堆的add方法维护 元素的优先级 35 maxHeap.add(e); 36 } 37 38 @Override 39 public E dequeue(){ 40 // 出栈 将最大堆的堆顶元素取出 41 return maxHeap.extractMax(); 42 } 43 }
需要解释的都在代码注释中了。
到这里优先队列就实现完了,是不是很简单。
在java中也有一个类PriorityQueue,其底层是采用的最小堆实现的优先队列。在java PriorityQueue中关于优先级的定义,优先级队列的元素按照其自然顺序进行排序,或者根据构造队列时提供的 Comparator 进行排序,具体取决于所使用的构造方法。底层数据结构最大堆或者最小堆是没有什么区别的。关键在于我们如何定义优先级。
四、topK问题
关于topK问题,leetcode上面有一道典型的题目
题目最终需要返回的是前 k 个频率最大的元素,可以想到借助堆这种数据结构,对于 k 频率之后的元素不用再去处理,进一步优化时间复杂度。
具体操作为:
借助 哈希表 来建立数字和其出现次数的映射,遍历一遍数组统计元素的频率
维护一个元素数目为 k 的最小堆
每次都将新的元素与堆顶元素(堆中频率最小的元素)进行比较
如果新的元素的频率比堆顶端的元素大,则弹出堆顶端的元素,将新的元素添加进堆中
最终,堆中的 k 个元素即为前 k 个高频元素
具体实现
1 class Solution { 2 public List<Integer> topKFrequent(int[] nums, int k) { 3 // 使用字典,统计每个元素出现的次数,元素为键,元素出现的次数为值 4 HashMap<Integer,Integer> map = new HashMap(); 5 for(int num : nums){ 6 if (map.containsKey(num)) { 7 map.put(num, map.get(num) + 1); 8 } else { 9 map.put(num, 1); 10 } 11 } 12 // 遍历map,用最小堆保存频率最大的k个元素 13 PriorityQueue<Integer> pq = new PriorityQueue<>(new Comparator<Integer>() { 14 @Override 15 public int compare(Integer a, Integer b) { 16 return map.get(a) - map.get(b); 17 } 18 }); 19 for (Integer key : map.keySet()) { 20 if (pq.size() < k) { 21 pq.add(key); 22 } else if (map.get(key) > map.get(pq.peek())) { 23 pq.remove(); 24 pq.add(key); 25 } 26 } 27 // 取出最小堆中的元素 28 List<Integer> res = new ArrayList<>(); 29 while (!pq.isEmpty()) { 30 res.add(pq.remove()); 31 } 32 return res; 33 } 34 }
以上是使用java原生的优先队列实现的。接下来我们用我们自己实现的PriorityQueue试验一下。
首先因为我们没有提供接收一个Comparator的构造器,所以我们通过定义一个类来完成这个过程比较。
因为自己定义的优先队列底层使用的是我们自己实现的最大堆,以及最大堆底层数组也是使用自己定义的,所以我们在leetcode提交验证的时候,需要将这些自定义的类以内部类的方式提交上去。整体代码如下
1 /// 347. Top K Frequent Elements 2 /// https://leetcode.com/problems/top-k-frequent-elements/description/ 3 4 import java.util.LinkedList; 5 import java.util.List; 6 import java.util.TreeMap; 7 8 class Solution { 9 10 private class Array<E> { 11 12 private E[] data; 13 private int size; 14 15 // 构造函数,传入数组的容量capacity构造Array 16 public Array(int capacity){ 17 data = (E[])new Object[capacity]; 18 size = 0; 19 } 20 21 // 无参数的构造函数,默认数组的容量capacity=10 22 public Array(){ 23 this(10); 24 } 25 26 public Array(E[] arr){ 27 data = (E[])new Object[arr.length]; 28 for(int i = 0 ; i < arr.length ; i ++) 29 data[i] = arr[i]; 30 size = arr.length; 31 } 32 33 // 获取数组的容量 34 public int getCapacity(){ 35 return data.length; 36 } 37 38 // 获取数组中的元素个数 39 public int getSize(){ 40 return size; 41 } 42 43 // 返回数组是否为空 44 public boolean isEmpty(){ 45 return size == 0; 46 } 47 48 // 在index索引的位置插入一个新元素e 49 public void add(int index, E e){ 50 51 if(index < 0 || index > size) 52 throw new IllegalArgumentException("Add failed. Require index >= 0 and index <= size."); 53 54 if(size == data.length) 55 resize(2 * data.length); 56 57 for(int i = size - 1; i >= index ; i --) 58 data[i + 1] = data[i]; 59 60 data[index] = e; 61 62 size ++; 63 } 64 65 // 向所有元素后添加一个新元素 66 public void addLast(E e){ 67 add(size, e); 68 } 69 70 // 在所有元素前添加一个新元素 71 public void addFirst(E e){ 72 add(0, e); 73 } 74 75 // 获取index索引位置的元素 76 public E get(int index){ 77 if(index < 0 || index >= size) 78 throw new IllegalArgumentException("Get failed. Index is illegal."); 79 return data[index]; 80 } 81 82 // 修改index索引位置的元素为e 83 public void set(int index, E e){ 84 if(index < 0 || index >= size) 85 throw new IllegalArgumentException("Set failed. Index is illegal."); 86 data[index] = e; 87 } 88 89 // 查找数组中是否有元素e 90 public boolean contains(E e){ 91 for(int i = 0 ; i < size ; i ++){ 92 if(data[i].equals(e)) 93 return true; 94 } 95 return false; 96 } 97 98 // 查找数组中元素e所在的索引,如果不存在元素e,则返回-1 99 public int find(E e){ 100 for(int i = 0 ; i < size ; i ++){ 101 if(data[i].equals(e)) 102 return i; 103 } 104 return -1; 105 } 106 107 // 从数组中删除index位置的元素, 返回删除的元素 108 public E remove(int index){ 109 if(index < 0 || index >= size) 110 throw new IllegalArgumentException("Remove failed. Index is illegal."); 111 112 E ret = data[index]; 113 for(int i = index + 1 ; i < size ; i ++) 114 data[i - 1] = data[i]; 115 size --; 116 data[size] = null; // loitering objects != memory leak 117 118 if(size == data.length / 4 && data.length / 2 != 0) 119 resize(data.length / 2); 120 return ret; 121 } 122 123 // 从数组中删除第一个元素, 返回删除的元素 124 public E removeFirst(){ 125 return remove(0); 126 } 127 128 // 从数组中删除最后一个元素, 返回删除的元素 129 public E removeLast(){ 130 return remove(size - 1); 131 } 132 133 // 从数组中删除元素e 134 public void removeElement(E e){ 135 int index = find(e); 136 if(index != -1) 137 remove(index); 138 } 139 140 public void swap(int i, int j){ 141 142 if(i < 0 || i >= size || j < 0 || j >= size) 143 throw new IllegalArgumentException("Index is illegal."); 144 145 E t = data[i]; 146 data[i] = data[j]; 147 data[j] = t; 148 } 149 150 @Override 151 public String toString(){ 152 153 StringBuilder res = new StringBuilder(); 154 res.append(String.format("Array: size = %d , capacity = %d\n", size, data.length)); 155 res.append('['); 156 for(int i = 0 ; i < size ; i ++){ 157 res.append(data[i]); 158 if(i != size - 1) 159 res.append(", "); 160 } 161 res.append(']'); 162 return res.toString(); 163 } 164 165 // 将数组空间的容量变成newCapacity大小 166 private void resize(int newCapacity){ 167 168 E[] newData = (E[])new Object[newCapacity]; 169 for(int i = 0 ; i < size ; i ++) 170 newData[i] = data[i]; 171 data = newData; 172 } 173 } 174 175 private class MaxHeap<E extends Comparable<E>> { 176 177 private Array<E> data; 178 179 public MaxHeap(int capacity){ 180 data = new Array<>(capacity); 181 } 182 183 public MaxHeap(){ 184 data = new Array<>(); 185 } 186 187 public MaxHeap(E[] arr){ 188 data = new Array<>(arr); 189 for(int i = parent(arr.length - 1) ; i >= 0 ; i --) 190 siftDown(i); 191 } 192 193 // 返回堆中的元素个数 194 public int size(){ 195 return data.getSize(); 196 } 197 198 // 返回一个布尔值, 表示堆中是否为空 199 public boolean isEmpty(){ 200 return data.isEmpty(); 201 } 202 203 // 返回完全二叉树的数组表示中,一个索引所表示的元素的父亲节点的索引 204 private int parent(int index){ 205 if(index == 0) 206 throw new IllegalArgumentException("index-0 doesn't have parent."); 207 return (index - 1) / 2; 208 } 209 210 // 返回完全二叉树的数组表示中,一个索引所表示的元素的左孩子节点的索引 211 private int leftChild(int index){ 212 return index * 2 + 1; 213 } 214 215 // 返回完全二叉树的数组表示中,一个索引所表示的元素的右孩子节点的索引 216 private int rightChild(int index){ 217 return index * 2 + 2; 218 } 219 220 // 向堆中添加元素 221 public void add(E e){ 222 data.addLast(e); 223 siftUp(data.getSize() - 1); 224 } 225 226 private void siftUp(int k){ 227 228 while(k > 0 && data.get(parent(k)).compareTo(data.get(k)) < 0 ){ 229 data.swap(k, parent(k)); 230 k = parent(k); 231 } 232 } 233 234 // 看堆中的最大元素 235 public E findMax(){ 236 if(data.getSize() == 0) 237 throw new IllegalArgumentException("Can not findMax when heap is empty."); 238 return data.get(0); 239 } 240 241 // 取出堆中最大元素 242 public E extractMax(){ 243 244 E ret = findMax(); 245 246 data.swap(0, data.getSize() - 1); 247 data.removeLast(); 248 siftDown(0); 249 250 return ret; 251 } 252 253 private void siftDown(int k){ 254 255 while(leftChild(k) < data.getSize()){ 256 int j = leftChild(k); // 在此轮循环中,data[k]和data[j]交换位置 257 if( j + 1 < data.getSize() && 258 data.get(j + 1).compareTo(data.get(j)) > 0 ) 259 j ++; 260 // data[j] 是 leftChild 和 rightChild 中的最大值 261 262 if(data.get(k).compareTo(data.get(j)) >= 0 ) 263 break; 264 265 data.swap(k, j); 266 k = j; 267 } 268 } 269 270 // 取出堆中的最大元素,并且替换成元素e 271 public E replace(E e){ 272 273 E ret = findMax(); 274 data.set(0, e); 275 siftDown(0); 276 return ret; 277 } 278 } 279 280 private interface Queue<E> { 281 282 int getSize(); 283 boolean isEmpty(); 284 void enqueue(E e); 285 E dequeue(); 286 E getFront(); 287 } 288 289 private class PriorityQueue<E extends Comparable<E>> implements Queue<E> { 290 291 private MaxHeap<E> maxHeap; 292 293 public PriorityQueue(){ 294 maxHeap = new MaxHeap<>(); 295 } 296 297 @Override 298 public int getSize(){ 299 return maxHeap.size(); 300 } 301 302 @Override 303 public boolean isEmpty(){ 304 return maxHeap.isEmpty(); 305 } 306 307 @Override 308 public E getFront(){ 309 return maxHeap.findMax(); 310 } 311 312 @Override 313 public void enqueue(E e){ 314 maxHeap.add(e); 315 } 316 317 @Override 318 public E dequeue(){ 319 return maxHeap.extractMax(); 320 } 321 } 322 323 private class Freq implements Comparable<Freq>{ 324 325 public int e, freq; 326 327 public Freq(int e, int freq){ 328 this.e = e; 329 this.freq = freq; 330 } 331 332 @Override 333 public int compareTo(Freq another){ 334 if(this.freq < another.freq) 335 return 1; 336 else if(this.freq > another.freq) 337 return -1; 338 else 339 return 0; 340 } 341 } 342 343 public List<Integer> topKFrequent(int[] nums, int k) { 344 345 TreeMap<Integer, Integer> map = new TreeMap<>(); 346 for(int num: nums){ 347 if(map.containsKey(num)) 348 map.put(num, map.get(num) + 1); 349 else 350 map.put(num, 1); 351 } 352 353 PriorityQueue<Freq> pq = new PriorityQueue<>(); 354 for(int key: map.keySet()){ 355 if(pq.getSize() < k) 356 pq.enqueue(new Freq(key, map.get(key))); 357 else if(map.get(key) > pq.getFront().freq){ 358 pq.dequeue(); 359 pq.enqueue(new Freq(key, map.get(key))); 360 } 361 } 362 363 LinkedList<Integer> res = new LinkedList<>(); 364 while(!pq.isEmpty()) 365 res.add(pq.dequeue().e); 366 return res; 367 } 368 369 private static void printList(List<Integer> nums){ 370 for(Integer num: nums) 371 System.out.print(num + " "); 372 System.out.println(); 373 } 374 375 public static void main(String[] args) { 376 377 int[] nums = {1, 1, 1, 2, 2, 3}; 378 int k = 2; 379 printList((new Solution()).topKFrequent(nums, k)); 380 } 381 }
在以上代码中我们需要关心的是如下部分
1 private class Freq implements Comparable<Freq>{ 2 3 public int e, freq; 4 5 public Freq(int e, int freq){ 6 this.e = e; 7 this.freq = freq; 8 } 9 10 @Override 11 public int compareTo(Freq another){ 12 if(this.freq < another.freq) 13 return 1; 14 else if(this.freq > another.freq) 15 return -1; 16 else 17 return 0; 18 } 19 } 20 21 public List<Integer> topKFrequent(int[] nums, int k) { 22 23 TreeMap<Integer, Integer> map = new TreeMap<>(); 24 for(int num: nums){ 25 if(map.containsKey(num)) 26 map.put(num, map.get(num) + 1); 27 else 28 map.put(num, 1); 29 } 30 31 PriorityQueue<Freq> pq = new PriorityQueue<>(); 32 for(int key: map.keySet()){ 33 if(pq.getSize() < k) 34 pq.enqueue(new Freq(key, map.get(key))); 35 else if(map.get(key) > pq.getFront().freq){ 36 pq.dequeue(); 37 pq.enqueue(new Freq(key, map.get(key))); 38 } 39 } 40 41 LinkedList<Integer> res = new LinkedList<>(); 42 while(!pq.isEmpty()) 43 res.add(pq.dequeue().e); 44 return res; 45 }
我们将完整代码提交到leetcode
得到如下结果,表示我们验证自己实现的优先队列成功了。
这盛世,如您所愿。
如有错误的地方还请留言指正。
原创不易,转载请注明原文地址:https://www.cnblogs.com/hello-shf/p/11397386.html