今天看了MIT 算法导论的视频课程,学习top K问题的时候,按leetcode惯例,会用大顶堆来解决,但是现在学到的BFPRT算法,记录一下
PS,此算法的最坏结果是O(n)
PS,好厉害
1 package com.example; 2 3 import java.util.Arrays; 4 5 /** 6 * Hello world! 7 * 8 */ 9 public class App 10 { 11 public static int[] getMinKNumsByBFPRT(int[] arr, int k) { 12 if (k < 1 || k > arr.length) { 13 return arr; 14 } 15 int minKth = getMinKthByBFPRT(arr, k); 16 int[] res = new int[k]; 17 int index = 0; 18 19 for (int i = 0; i != arr.length; i++) { 20 if (arr[i] < minKth) { 21 res[index++] = arr[i]; 22 } 23 } 24 for (; index != res.length; index++) { 25 res[index] = minKth; 26 } 27 return res; 28 } 29 30 31 //找出比k小的前k个数 32 public static int getMinKthByBFPRT(int[] arr, int K) { 33 int[] copyArr = arr.clone(); 34 return select(copyArr, 0, copyArr.length - 1, K - 1); 35 } 36 37 //用划分值与k相比,依次递归排序 38 public static int select(int[] arr, int begin, int end, int i) { 39 if (begin == end) { //begin数组的开始 end数组的结尾 i表示要求的第k个数 40 return arr[begin]; 41 } 42 int pivot = medianOfMedians(arr, begin, end);//找出划分值(中位数组中的中位数) 43 int[] pivotRange = partition(arr, begin, end, pivot); 44 if (i >= pivotRange[0] && i <= pivotRange[1]) {//小于放左边,=放中间,大于放右边 45 return arr[i]; 46 } else if (i < pivotRange[0]) { 47 return select(arr, begin, pivotRange[0] - 1, i); 48 } else { 49 return select(arr, pivotRange[1] + 1, end, i); 50 } 51 } 52 53 54 //找出中位数组中的中位数 55 public static int medianOfMedians(int[] arr, int begin, int end) { 56 int num = end - begin + 1; 57 int offset = num % 5 == 0 ? 0 : 1; //分组:每组5个数,不满5个单独占一组 58 int[] mArr = new int[num / 5 + offset]; //mArr:中位数组成的数组 59 for (int i = 0; i < mArr.length; i++) { //计算分开后各数组的开始位置beginI 结束位置endI 60 int beginI = begin + i * 5; 61 int endI = beginI + 4; 62 mArr[i] = getMedian(arr, beginI, Math.min(end, endI));//对于最后一组(不满5个数),结束位置要选择end 63 } 64 return select(mArr, 0, mArr.length - 1, mArr.length / 2); 65 } 66 67 //划分过程,类似于快排 68 public static int[] partition(int[] arr, int begin, int end, int pivotValue) { 69 int small = begin - 1; 70 int cur = begin; 71 int big = end + 1; 72 while (cur != big) { 73 if (arr[cur] < pivotValue) { 74 swap(arr, ++small, cur++); 75 } else if (arr[cur] > pivotValue) { 76 swap(arr, cur, --big); 77 } else { 78 cur++; 79 } 80 } 81 int[] range = new int[2]; 82 range[0] = small + 1;//比划分值小的范围 83 range[1] = big - 1; //比划分值大的范围 84 85 return range; 86 } 87 88 //计算中位数 89 public static int getMedian(int[] arr, int begin, int end) { 90 insertionSort(arr, begin, end);//将数组中的5个数排序 91 int sum = end + begin; 92 int mid = (sum / 2) + (sum % 2); 93 return arr[mid]; 94 } 95 96 97 //数组中5个数排序(插入排序),这里必须用插入排序 98 public static void insertionSort(int[] arr, int begin, int end) { 99 for (int i = begin + 1; i != end + 1; i++) { 100 for (int j = i; j != begin; j--) { 101 if (arr[j - 1] > arr[j]) { 102 swap(arr, j - 1, j); 103 } else { 104 break; 105 } 106 } 107 } 108 } 109 110 111 //交换元素顺序 112 public static void swap(int[] arr, int index1, int index2) { 113 int tmp = arr[index1]; 114 arr[index1] = arr[index2]; 115 arr[index2] = tmp; 116 } 117 118 119 //打印结果 120 public static void printArray(int[] arr) { 121 for (int i = 0; i != arr.length; i++) { 122 System.out.print(arr[i] + " "); 123 } 124 System.out.println(); 125 } 126 127 public static void main(String[] args) { 128 int[] arr = { 6, 9, 1, 3, 1, 2, 2, 5, 6, 1, 3, 5, 9, 7, 2, 5, 6, 1, 9 }; 129 printArray(getMinKNumsByBFPRT(arr, 4)); 130 131 } 132 }