带中位数写法的快速排序再讨论 & leetcode 215. Kth Largest Element in an Array题解
带中位数写法的快速排序再讨论 & leetcode 215. Kth Largest Element in an Array题解
探讨带中位数的写法本身
class Solution {
public:
int findKthLargest(std::vector<int>& nums, int k) {
return fakeQuickSort(nums, k, 0, nums.size() - 1);
}
private:
int fakeQuickSort(std::vector<int>& nums, int k, int l, int r) {
int x = find_pivot_down(nums, l, r);
int i = l, j = r;
while(i < j) {
do i++; while(nums[i] > x);
do j--; while(nums[j] < x);
if(i < j) std::swap(nums[i], nums[j]);
}
std::swap(nums[i], nums[r]);
if(r - l <= 2) {
return nums[l + k - 1];
}
int cnt = i - l + 1;
if(k < cnt) {
return fakeQuickSort(nums, k, l, i);
}
else {
return fakeQuickSort(nums, k - cnt, i + 1, r);
}
}
int find_pivot_up(std::vector<int>& nums, int l, int r) {
int mid = l + r + 1 >> 1;
if(nums[mid] < nums[l])
std::swap(nums[mid], nums[l]); // make sure nums[mid] >= nums[l]
if(nums[l] > nums[r])
std::swap(nums[l], nums[r]); // make sure nums[l] <= nums[r]
if(nums[mid] < nums[r])
std::swap(nums[mid], nums[r]); // make sure nums[mid] >= nums[r]
// nums[l] <= nums[r] <= nums[mid]
return nums[r];
}
int find_pivot_down(std::vector<int>& nums, int l, int r) {
int mid = l + r + 1 >> 1;
if(nums[l] < nums[mid])
std::swap(nums[l], nums[mid]); // make sure nums[l] >= nums[mid]
if(nums[l] < nums[r])
std::swap(nums[l], nums[r]); // make sure nums[l] >= nums[r]
if(nums[r] < nums[mid])
std::swap(nums[r], nums[mid]); // make sure nums[r] >= nums[mid]
// nums[mid] <= nums[r] <= nums[l]
return nums[r];
}
};
这里展示的是这道题的通过代码,不过本质上和放排序代码一样,偷个懒
当前选取中位数写法的问题
从前刚认识这种写法的时候,觉得这种写法非常厉害,能够提高效率,但是经过一段时间后再回过头来看,发现这种写法有这样几个不妥之处:
1.升序和降序排列的find_pivot需要两种写法,非常麻烦
2.如果面试的时候让你手搓,那么越长的代码越容易出错
边界情况分析
为什么在find_pivot函数当中,取中点需要写mid = l + r + 1 >> 1
以升序排序的情况为例(降序同理):
首先,为了回答上面的问题,我们要先搞清楚我们为什么要保证最后的不等关系是nums[l] <= nums[r] <= nums[mid]
在执行这个函数之后,我们把中位数放在了nums[r],然后在快排结束之后,我们直接把i位置的数换到r过来,因为nums[i]<=pivot,所以这一步是合法的;
而我们一开始就跳过了l和r,所以我们需要保证l位置上的数,即使不动,也是满足快排终止条件要求的,所以l位置必须是小于等于pivot的数字,因此nums[l]只能是最小的。
确认了不等关系以后,我们就能根据边界情况来确认mid的取法里要不要+1.
当区间长度等于2的时候,如果不+1,这个时候mid就是l,导致我们的不等式不能成立了,所以我们要让mid取到r,因此需要+1
结合这道题目的分析
为什么在这种写法的框架下,我们需要这样丑陋的特判?
if(r - l <= 2) {
return nums[l + k - 1];
}
因为我们的通解情况
if(k < cnt) {
return fakeQuickSort(nums, k, l, i);
}
else {
return fakeQuickSort(nums, k - cnt, i + 1, r);
}
如果k < cnt,并且这个时候区间长度只剩下了2,这时候l + 1 = r = i,区间长度不会再缩小了,会陷入死循环,我们需要手动加入特判来跳出循环
我们这个特判可以写的更好看一点点
if(k < cnt) {
return fakeQuickSort(nums, k, l, i - 1);
}
else if (k > cnt){
return fakeQuickSort(nums, k - cnt, i + 1, r);
}
else {
return x;
}
发散思维
这道题y总写出了非常漂亮的特判,看起来只有两种分支;但实际上代码开头的l==r也是一种分支,所以也是三种;
而在acwing这道题当中我们的mid是手算的,我们可以确定,[l,mid]和[mid+1,r]都不会导致死循环
但是这道题当中,在区间长度只剩下2的情况下,i最后会落在l,我们如果写[l,i]的递归,就会陷入死循环;
所以我们不得不采用现在的写法。
而且,这道题对快排的写法还有一个要求:
我们必须明确地知道,最后循环结束时,i位置上数字和中位数的大小关系,否则我们不知道它能不能被舍去;
因此这里要么用固定中位数写法,然后手动把它放到i上;
要么就用选取中位数写法,选择出中位数,并且手动把中位数归位;
但是又因为前一种写法更容易被卡,所以我们选择后面一种。
再审视
我们可以发现,我们这道题里用到的写法只能舍去非法数字,不能正确实现排序,只能用来解决这里的问题。
我们能不能设计出一个正确的、并且知道pivot具体位置的快速排序写法呢?同时,它还不能选取固定pivot(因为容易被卡)。
即它既能通过快速排序原题,又可以解决第k大(小)数问题?
我们分析一下发现,三路快速排序很好地满足了这个性质,但是三路快速排序会做很多没有必要的交换操作,导致在一般情况下性能不如普通快排.
而因为我们快排的时候,可能这个序列是我们在这一轮当中做了交换才变成现在的样子,如果不确定我们的中位数具体是哪一个,我们就没有办法确认,j或者i上的数字和pivot的大小关系,也不知道pivot现在具体在哪,就不知道怎么处理边界问题。
因此,这道题要么写丑陋的带特判写法,要么就写没有那么快的三路快速排序
快排性质
同时,在分析过程中我们发现,虽然我们的代码当中
while(i < j) {
do i++;
while(nums[i] <x);
do j--;
while(nums[j] >x);
if (i < j) std::swap(nums[i], nums[j]);
}
这里都是大于号、小于号,但是由于我们循环的过程中可能会触发交换,所以j右边的数字并不是全都大于x,而是大于等于x;对i同理
这一点对所有的快速排序模板都是适用的;
我们也发现,yxc的写法采用j作为分界线是因为用i可能会陷入死循环。
总结
综上所述,我所熟悉的ycx快排模板只适合单纯地快排,因为我们捕捉不到pivot最终具体在哪,改造成固定pivot之后才行:
但是固定pivot在一般的题目里面很容易被卡TLE;
这道题三路快排最后一个点也会被卡TLE
而繁琐的版本适当改写之后,也能通过原本的快排测试,因此,这个版本综合来看是最优的
class Solution {
public:
int findKthLargest(std::vector<int>& nums, int k) {
return fakeQuickSort(nums, k, 0, nums.size() - 1);
}
void quickSort(std::vector<int>& nums, int l, int r) {
if (l >= r) {
return;
}
int x = find_pivot_up(nums, l, r);
int i = l, j = r;
while (i < j) {
do i++;
while (nums[i] < x);
do j--;
while (nums[j] > x);
if (i < j) std::swap(nums[i], nums[j]);
}
std::swap(nums[i], nums[r]);
DEBUG_LOG("l:%d, r:%d, i:%d, j:%d, x:%d\n", l, r, i, j, x);
for (int i = l; i <= r; i++) DEBUG_LOG("%d ", nums[i]);
DEBUG_LOG("\n");
quickSort(nums, l, i - 1);
quickSort(nums, i + 1, r);
}
private:
int fakeQuickSort(std::vector<int>& nums, int k, int l, int r) {
int x = find_pivot_down(nums, l, r);
int i = l, j = r;
while(i < j) {
do i++; while(nums[i] > x);
do j--; while(nums[j] < x);
if(i < j) std::swap(nums[i], nums[j]);
}
std::swap(nums[i], nums[r]);
if(r - l <= 2) {
return nums[l + k - 1];
}
DEBUG_LOG("l:%d r:%d k:%d i:%d x:%d\n", l, r, k, i, x);
for(int i = l; i <= r; i++) {
DEBUG_LOG("%d ", nums[i]);
}
DEBUG_LOG("\n");
int cnt = i - l + 1;
if(k < cnt) {
return fakeQuickSort(nums, k, l, i - 1);
}
else {
return fakeQuickSort(nums, k - cnt, i + 1, r);
}
}
int tripleSortFindKth(std::vector<int>& nums, int k, int l, int r) {
int i = l, j = l, s = r;
int x = nums[(l + r) >> 1];
while(i <= s) {
if(nums[i] > x) {
std::swap(nums[i++], nums[j++]);
}
else if(nums[i] < x) {
std::swap(nums[i], nums[s--]);
}
else {
i++;
}
}
int cnt = j - l + 1;
if(cnt == k) {
return nums[j];
}
else if(cnt > k) {
return tripleSortFindKth(nums, k, l, j - 1);
}
else {
return tripleSortFindKth(nums, k - cnt, j + 1, r);
}
}
int find_pivot(std::vector<int>& nums, int l, int r) {
int &a = nums[l], &b = nums[r], &c = nums[(l + r) >> 1];
if ((a >= b && a <= c) || (a >= c && a <= b)) {
return a; // a 是中位数
} else if ((b >= a && b <= c) || (b >= c && b <= a)) {
return b; // b 是中位数
} else {
return c; // c 是中位数
}
}
int find_pivot_up(std::vector<int>& nums, int l, int r) {
int mid = l + r + 1 >> 1;
if(nums[mid] < nums[l])
std::swap(nums[mid], nums[l]); // make sure nums[mid] >= nums[l]
if(nums[l] > nums[r])
std::swap(nums[l], nums[r]); // make sure nums[l] <= nums[r]
if(nums[mid] < nums[r])
std::swap(nums[mid], nums[r]); // make sure nums[mid] >= nums[r]
// nums[l] <= nums[r] <= nums[mid]
return nums[r];
}
int find_pivot_down(std::vector<int>& nums, int l, int r) {
int mid = (l + r + 1) >> 1;
if (nums[l] < nums[mid])
std::swap(nums[l], nums[mid]); // make sure nums[l] >= nums[mid]
if (nums[l] < nums[r])
std::swap(nums[l], nums[r]); // make sure nums[l] >= nums[r]
if (nums[r] < nums[mid])
std::swap(nums[r], nums[mid]); // make sure nums[r] >= nums[mid]
// nums[mid] <= nums[r] <= nums[l]
return nums[r];
}
};
当然,那个r-l>=2的区间式特判太丑陋了,换成这样的更美观
if(k < cnt) {
return fakeQuickSort(nums, k, l, i - 1);
}
else if(k > cnt){
return fakeQuickSort(nums, k - cnt, i + 1, r);
}
else {
return x;
}
这种快排写法的性质:
当到达边界情况len==2时,i一定会停在r处
整理好的代码
namespace templateQuickSort {
template <typename T>
using Comparatorfunc = std::function<bool(const T&, const T&)>;
template <typename T>
class Comparator {
public:
// 静态成员变量声明
static std::map<std::string, Comparatorfunc<T>> sToC;
// 定义比较函数
static bool AscComparator(const T& a, const T& b) { return a < b; }
static bool DesComparator(const T& a, const T& b) { return a > b; }
};
// 静态成员变量定义和初始化
template <typename T>
std::map<std::string, Comparatorfunc<T>> Comparator<T>::sToC = {
{"Asc", Comparator<T>::AscComparator},
{"Des", Comparator<T>::DesComparator}};
// 先引入模板函数,然后再加入自定义比较功能
// ascend递增,descend递减
template <typename T>
class Solution {
public:
int findKthLargest(std::vector<T>& nums, int k, std::string mode) {
auto it = Comparator<T>::sToC.find(mode);
// 如果找不到对应的比较模式,抛出异常
if (it == Comparator<T>::sToC.end()) {
throw std::invalid_argument("Invalid comparison mode: " + mode);
}
Comparatorfunc<T> cmp = Comparator<T>::sToC[mode];
return fakeQuickSort(nums, k, 0, nums.size() - 1, cmp);
}
void quickSort(std::vector<T>& nums, int l, int r, std::string mode) {
auto it = Comparator<T>::sToC.find(mode);
std::cout << typeid(it).name() << std::endl;
exit(0);
// 如果找不到对应的比较模式,抛出异常
if (it == Comparator<T>::sToC.end()) {
throw std::invalid_argument("Invalid comparison mode: " + mode);
}
Comparatorfunc<T> cmp = Comparator<T>::sToC[mode];
realQuickSort(nums, l, r, cmp);
}
private:
void realQuickSort(std::vector<T>& nums, int l, int r,
Comparatorfunc<T> cmp) {
if (l >= r) {
return;
}
int x = find_pivot(nums, l, r, cmp);
int i = l - 1, j = r + 1;
while (i < j) {
do i++;
while (cmp(nums[i], x));
do j--;
while (!cmp(nums[j], x));
if (i < j) std::swap(nums[i], nums[j]);
}
std::swap(nums[i], nums[r]);
realQuickSort(nums, l, i - 1, cmp);
realQuickSort(nums, i + 1, r, cmp);
}
int fakeQuickSort(std::vector<T>& nums, int k, int l, int r,
Comparatorfunc<T> cmp) {
int x = find_pivot(nums, l, r, cmp);
int i = l, j = r;
while (i < j) {
do i++;
while (cmp(nums[i], nums[j]));
do j--;
while (!cmp(nums[i], nums[j]));
if (i < j) std::swap(nums[i], nums[j]);
}
std::swap(nums[i], nums[r]);
int cnt = i - l + 1;
if (k < cnt) {
return fakeQuickSort(nums, k, l, i - 1, cmp);
} else if (k > cnt) {
return fakeQuickSort(nums, k - cnt, i + 1, r, cmp);
} else {
return x;
}
}
int find_pivot(std::vector<T>& nums, int l, int r, Comparatorfunc<T> cmp) {
int mid = l + r + 1 >> 1;
if (cmp(nums[mid], nums[l]))
std::swap(nums[mid], nums[l]); // make sure nums[mid] >= nums[l]
if (!cmp(nums[l], nums[r]))
std::swap(nums[l], nums[r]); // make sure nums[l] <= nums[r]
if (cmp(nums[mid], nums[r]))
std::swap(nums[mid], nums[r]); // make sure nums[mid] >= nums[r]
// nums[l] <= nums[r] <= nums[mid]
return nums[r];
}
};
} // namespace templateQuickSort
另外,如果在普通的快排写法中,我们要以i为分界点,那么两个子区间就是(l,i-1)和(i,r)
因为前面已经分析过,在区间长度为2的边界情况下,i一定会停在r处,如果用(l,i)和(i+1,r)就会导致死循环