带中位数写法的快速排序再讨论 & leetcode 215. Kth Largest Element in an Array题解

带中位数写法的快速排序再讨论 & leetcode 215. Kth Largest Element in an Array题解

 

探讨带中位数的写法本身

class Solution {
public:
    int findKthLargest(std::vector<int>& nums, int k) {
        return fakeQuickSort(nums, k, 0, nums.size() - 1);
    }
private:
	int fakeQuickSort(std::vector<int>& nums, int k, int l, int r) {
		int x = find_pivot_down(nums, l, r);
		int i = l, j = r;
		while(i < j) {
			do i++; while(nums[i] > x);
			do j--; while(nums[j] < x);
			if(i < j) std::swap(nums[i], nums[j]);
		}
		std::swap(nums[i], nums[r]);
		if(r - l <= 2) {
			return nums[l + k - 1];
		}
		int cnt = i - l + 1;
		if(k < cnt) {
			return fakeQuickSort(nums, k, l, i);
		}
		else {
			return fakeQuickSort(nums, k - cnt, i + 1, r);
		}
	}
	int find_pivot_up(std::vector<int>& nums, int l, int r) {
		int mid = l + r + 1 >> 1;
		if(nums[mid] < nums[l])
			std::swap(nums[mid], nums[l]);	//	make sure nums[mid] >= nums[l]
		if(nums[l] > nums[r]) 
			std::swap(nums[l], nums[r]); //	make sure nums[l] <= nums[r]
		if(nums[mid] < nums[r])
			std::swap(nums[mid], nums[r]);	//	make sure nums[mid] >= nums[r]
		//	nums[l] <= nums[r] <= nums[mid]
		return nums[r];
	}
	int find_pivot_down(std::vector<int>& nums, int l, int r) {
		int mid = l + r + 1 >> 1;
		if(nums[l] < nums[mid])
			std::swap(nums[l], nums[mid]);	// make sure nums[l] >= nums[mid]
		if(nums[l] < nums[r])
			std::swap(nums[l], nums[r]);	// make sure nums[l] >= nums[r]
		if(nums[r] < nums[mid])
			std::swap(nums[r], nums[mid]);	// make sure nums[r] >= nums[mid]
		//	nums[mid] <= nums[r] <= nums[l]
		return nums[r];
	}
};

这里展示的是这道题的通过代码,不过本质上和放排序代码一样,偷个懒

 

当前选取中位数写法的问题

从前刚认识这种写法的时候,觉得这种写法非常厉害,能够提高效率,但是经过一段时间后再回过头来看,发现这种写法有这样几个不妥之处:

1.升序和降序排列的find_pivot需要两种写法,非常麻烦

2.如果面试的时候让你手搓,那么越长的代码越容易出错

 

边界情况分析

为什么在find_pivot函数当中,取中点需要写mid = l + r + 1 >> 1

以升序排序的情况为例(降序同理):

首先,为了回答上面的问题,我们要先搞清楚我们为什么要保证最后的不等关系是nums[l] <= nums[r] <= nums[mid]

在执行这个函数之后,我们把中位数放在了nums[r],然后在快排结束之后,我们直接把i位置的数换到r过来,因为nums[i]<=pivot,所以这一步是合法的;

而我们一开始就跳过了l和r,所以我们需要保证l位置上的数,即使不动,也是满足快排终止条件要求的,所以l位置必须是小于等于pivot的数字,因此nums[l]只能是最小的。

确认了不等关系以后,我们就能根据边界情况来确认mid的取法里要不要+1.

当区间长度等于2的时候,如果不+1,这个时候mid就是l,导致我们的不等式不能成立了,所以我们要让mid取到r,因此需要+1

 

结合这道题目的分析

为什么在这种写法的框架下,我们需要这样丑陋的特判?

if(r - l <= 2) {
			return nums[l + k - 1];
		}

因为我们的通解情况

if(k < cnt) {
			return fakeQuickSort(nums, k, l, i);
		}
		else {
			return fakeQuickSort(nums, k - cnt, i + 1, r);
		}

如果k < cnt,并且这个时候区间长度只剩下了2,这时候l + 1 = r = i,区间长度不会再缩小了,会陷入死循环,我们需要手动加入特判来跳出循环

我们这个特判可以写的更好看一点点

if(k < cnt) {
			return fakeQuickSort(nums, k, l, i - 1);
		}
		else if (k > cnt){
			return fakeQuickSort(nums, k - cnt, i + 1, r);
		}
        else {
            return x;
        }

发散思维

255. 第K小数 - AcWing题库

这道题y总写出了非常漂亮的特判,看起来只有两种分支;但实际上代码开头的l==r也是一种分支,所以也是三种;

而在acwing这道题当中我们的mid是手算的,我们可以确定,[l,mid]和[mid+1,r]都不会导致死循环

但是这道题当中,在区间长度只剩下2的情况下,i最后会落在l,我们如果写[l,i]的递归,就会陷入死循环;

所以我们不得不采用现在的写法。

而且,这道题对快排的写法还有一个要求:

我们必须明确地知道,最后循环结束时,i位置上数字和中位数的大小关系,否则我们不知道它能不能被舍去;

因此这里要么用固定中位数写法,然后手动把它放到i上;

要么就用选取中位数写法,选择出中位数,并且手动把中位数归位;

但是又因为前一种写法更容易被卡,所以我们选择后面一种。

再审视

我们可以发现,我们这道题里用到的写法只能舍去非法数字,不能正确实现排序,只能用来解决这里的问题。

我们能不能设计出一个正确的、并且知道pivot具体位置的快速排序写法呢?同时,它还不能选取固定pivot(因为容易被卡)。

即它既能通过快速排序原题,又可以解决第k大(小)数问题?

我们分析一下发现,三路快速排序很好地满足了这个性质,但是三路快速排序会做很多没有必要的交换操作,导致在一般情况下性能不如普通快排.

而因为我们快排的时候,可能这个序列是我们在这一轮当中做了交换才变成现在的样子,如果不确定我们的中位数具体是哪一个,我们就没有办法确认,j或者i上的数字和pivot的大小关系,也不知道pivot现在具体在哪,就不知道怎么处理边界问题。

因此,这道题要么写丑陋的带特判写法,要么就写没有那么快的三路快速排序

快排性质

同时,在分析过程中我们发现,虽然我们的代码当中

while(i < j) {
			do i++;
			while(nums[i] <x);
			do j--;
			while(nums[j] >x);
            if (i < j) std::swap(nums[i], nums[j]);
		}

这里都是大于号、小于号,但是由于我们循环的过程中可能会触发交换,所以j右边的数字并不是全都大于x,而是大于等于x;对i同理

这一点对所有的快速排序模板都是适用的;

我们也发现,yxc的写法采用j作为分界线是因为用i可能会陷入死循环。

总结

综上所述,我所熟悉的ycx快排模板只适合单纯地快排,因为我们捕捉不到pivot最终具体在哪,改造成固定pivot之后才行:

但是固定pivot在一般的题目里面很容易被卡TLE;

这道题三路快排最后一个点也会被卡TLE

而繁琐的版本适当改写之后,也能通过原本的快排测试,因此,这个版本综合来看是最优的

class Solution {
   public:
    int findKthLargest(std::vector<int>& nums, int k) {
        return fakeQuickSort(nums, k, 0, nums.size() - 1);
    }

    void quickSort(std::vector<int>& nums, int l, int r) {
        if (l >= r) {
            return;
        }
        int x = find_pivot_up(nums, l, r);
        int i = l, j = r;
        while (i < j) {
            do i++;
            while (nums[i] < x);
            do j--;
            while (nums[j] > x);
            if (i < j) std::swap(nums[i], nums[j]);
        }
        std::swap(nums[i], nums[r]);
        DEBUG_LOG("l:%d, r:%d, i:%d, j:%d, x:%d\n", l, r, i, j, x);
        for (int i = l; i <= r; i++) DEBUG_LOG("%d ", nums[i]);
        DEBUG_LOG("\n");
        quickSort(nums, l, i - 1);
        quickSort(nums, i + 1, r);
    }

   private:
    int fakeQuickSort(std::vector<int>& nums, int k, int l, int r) {
		int x = find_pivot_down(nums, l, r);
		int i = l, j = r;
		while(i < j) {
			do i++; while(nums[i] > x);
			do j--; while(nums[j] < x);
			if(i < j) std::swap(nums[i], nums[j]);
		}
		std::swap(nums[i], nums[r]);
		if(r - l <= 2) {
			return nums[l + k - 1];
		}
		DEBUG_LOG("l:%d r:%d k:%d i:%d x:%d\n", l, r, k, i, x);
		for(int i = l; i <= r; i++) {
			DEBUG_LOG("%d ", nums[i]);
		}
		DEBUG_LOG("\n");
		int cnt = i - l + 1;
		if(k < cnt) {
			return fakeQuickSort(nums, k, l, i - 1);
		}
		else {
			return fakeQuickSort(nums, k - cnt, i + 1, r);
		}
	}
	int tripleSortFindKth(std::vector<int>& nums, int k, int l, int r) {
		int i = l, j = l, s = r;
		int x = nums[(l + r) >> 1];
		while(i <= s) {
			if(nums[i] > x) {
				std::swap(nums[i++], nums[j++]);
			}
			else if(nums[i] < x) {
				std::swap(nums[i], nums[s--]);
			}
			else {
				i++;
			}
		}
		int cnt = j - l + 1;
		if(cnt == k) {
			return nums[j];
		}
		else if(cnt > k) {
			return tripleSortFindKth(nums, k, l, j - 1);
		}
		else {
			return tripleSortFindKth(nums, k - cnt, j + 1, r);
		}
	}
    int find_pivot(std::vector<int>& nums, int l, int r) {
        int &a = nums[l], &b = nums[r], &c = nums[(l + r) >> 1];
        if ((a >= b && a <= c) || (a >= c && a <= b)) {
            return a;  // a 是中位数
        } else if ((b >= a && b <= c) || (b >= c && b <= a)) {
            return b;  // b 是中位数
        } else {
            return c;  // c 是中位数
        }
    }
    int find_pivot_up(std::vector<int>& nums, int l, int r) {
		int mid = l + r + 1 >> 1;
		if(nums[mid] < nums[l])
			std::swap(nums[mid], nums[l]);	//	make sure nums[mid] >= nums[l]
		if(nums[l] > nums[r]) 
			std::swap(nums[l], nums[r]); //	make sure nums[l] <= nums[r]
		if(nums[mid] < nums[r])
			std::swap(nums[mid], nums[r]);	//	make sure nums[mid] >= nums[r]
		//	nums[l] <= nums[r] <= nums[mid]
		return nums[r];
	}
    int find_pivot_down(std::vector<int>& nums, int l, int r) {
        int mid = (l + r + 1) >> 1;
        if (nums[l] < nums[mid])
            std::swap(nums[l], nums[mid]);  // make sure nums[l] >= nums[mid]
        if (nums[l] < nums[r])
            std::swap(nums[l], nums[r]);  // make sure nums[l] >= nums[r]
        if (nums[r] < nums[mid])
            std::swap(nums[r], nums[mid]);  // make sure nums[r] >= nums[mid]
        //	nums[mid] <= nums[r] <= nums[l]
        return nums[r];
    }
};

当然,那个r-l>=2的区间式特判太丑陋了,换成这样的更美观

if(k < cnt) {
			return fakeQuickSort(nums, k, l, i - 1);
		}
		else if(k > cnt){
			return fakeQuickSort(nums, k - cnt, i + 1, r);
		}
		else {
			return x;
		}

这种快排写法的性质:

当到达边界情况len==2时,i一定会停在r处

 

整理好的代码

namespace templateQuickSort {

template <typename T>
using Comparatorfunc = std::function<bool(const T&, const T&)>;

template <typename T>
class Comparator {
   public:
    // 静态成员变量声明
    static std::map<std::string, Comparatorfunc<T>> sToC;

    // 定义比较函数
    static bool AscComparator(const T& a, const T& b) { return a < b; }

    static bool DesComparator(const T& a, const T& b) { return a > b; }
};

// 静态成员变量定义和初始化
template <typename T>
std::map<std::string, Comparatorfunc<T>> Comparator<T>::sToC = {
    {"Asc", Comparator<T>::AscComparator},
    {"Des", Comparator<T>::DesComparator}};

// 先引入模板函数,然后再加入自定义比较功能
// ascend递增,descend递减
template <typename T>
class Solution {
   public:
    int findKthLargest(std::vector<T>& nums, int k, std::string mode) {
		auto it = Comparator<T>::sToC.find(mode);
        
        // 如果找不到对应的比较模式,抛出异常
        if (it == Comparator<T>::sToC.end()) {
            throw std::invalid_argument("Invalid comparison mode: " + mode);
        }

        Comparatorfunc<T> cmp = Comparator<T>::sToC[mode];
        return fakeQuickSort(nums, k, 0, nums.size() - 1, cmp);
    }

    void quickSort(std::vector<T>& nums, int l, int r, std::string mode) {
		auto it = Comparator<T>::sToC.find(mode);
        
std::cout << typeid(it).name() << std::endl;
		exit(0);

        // 如果找不到对应的比较模式,抛出异常
        if (it == Comparator<T>::sToC.end()) {
            throw std::invalid_argument("Invalid comparison mode: " + mode);
        }

        Comparatorfunc<T> cmp = Comparator<T>::sToC[mode];
        realQuickSort(nums, l, r, cmp);
    }

   private:
    void realQuickSort(std::vector<T>& nums, int l, int r,
                       Comparatorfunc<T> cmp) {
        if (l >= r) {
            return;
        }
        int x = find_pivot(nums, l, r, cmp);
        int i = l - 1, j = r + 1;
        while (i < j) {
            do i++;
            while (cmp(nums[i], x));
            do j--;
            while (!cmp(nums[j], x));
            if (i < j) std::swap(nums[i], nums[j]);
        }
        std::swap(nums[i], nums[r]);
        realQuickSort(nums, l, i - 1, cmp);
        realQuickSort(nums, i + 1, r, cmp);
    }

    int fakeQuickSort(std::vector<T>& nums, int k, int l, int r,
                      Comparatorfunc<T> cmp) {
        int x = find_pivot(nums, l, r, cmp);
        int i = l, j = r;
        while (i < j) {
            do i++;
            while (cmp(nums[i], nums[j]));
            do j--;
            while (!cmp(nums[i], nums[j]));
            if (i < j) std::swap(nums[i], nums[j]);
        }
        std::swap(nums[i], nums[r]);
        int cnt = i - l + 1;
        if (k < cnt) {
            return fakeQuickSort(nums, k, l, i - 1, cmp);
        } else if (k > cnt) {
            return fakeQuickSort(nums, k - cnt, i + 1, r, cmp);
        } else {
            return x;
        }
    }

    int find_pivot(std::vector<T>& nums, int l, int r, Comparatorfunc<T> cmp) {
        int mid = l + r + 1 >> 1;
        if (cmp(nums[mid], nums[l]))
            std::swap(nums[mid], nums[l]);  //	make sure nums[mid] >= nums[l]
        if (!cmp(nums[l], nums[r]))
            std::swap(nums[l], nums[r]);  //	make sure nums[l] <= nums[r]
        if (cmp(nums[mid], nums[r]))
            std::swap(nums[mid], nums[r]);  //	make sure nums[mid] >= nums[r]
        //	nums[l] <= nums[r] <= nums[mid]
        return nums[r];
    }
};

}  // namespace templateQuickSort

另外,如果在普通的快排写法中,我们要以i为分界点,那么两个子区间就是(l,i-1)和(i,r)

因为前面已经分析过,在区间长度为2的边界情况下,i一定会停在r处,如果用(l,i)和(i+1,r)就会导致死循环

posted @ 2024-10-13 22:22  Gold_stein  阅读(2)  评论(0编辑  收藏  举报