[补充]归并排序(非递归)以及归并排序的更高效算法——自然归并排序
递归版归并排序
我们在 CLRS 中已经学会了归并排序的递归写法:
merge函数:
def merge(left, right): # prerequisite: both left and right is sorted list ret = [] i = 0 j = 0 while i < len(left) and j < len(right): if left[i] < right[j]: ret.append(left[i]) i += 1 else: ret.append(right[j]) j += 1 if i == len(left): for it in right[j:]: ret.append(it) else: for it in left[i:]: ret.append(it) print("after sort, left is {}, right is {}, ret is {}".format(left, right, ret)) return ret
mergeSort函数:
def mergeSort(arr): print("current arr is:{}".format(arr)) if len(arr) <= 1: return arr mid = len(arr) // 2 left = mergeSort(arr[:mid]) right = mergeSort(arr[mid:]) return merge(left, right)
但是,递归算法的常数因子很影响时间,转化成非递归版算法通常都是更优的解法,所以我们来实现一下非递归版
非递归版归并排序
实现的原理和递归版刚好相反,递归解法是将有序串一分为二直到每个串只有一个元素,然后再排序合并。而非递归版是默认有 n 个长度为 1 子串,然后将相邻的两个串两两排序并合并,直到合并成一个长度为 n 的子串。比如刚开始有 n 个子串,下一步是相邻的两个串两两排序并合并,构成 n/2 个长度为 2 的子串,然后再排序合并,形成 n/4 个长度为 4 的子串....直到生成一个长度为 n 的子串。
void mergeSort2(int n){ int s=2,i; while(s<=n){ i=0; while(i+s<=n){ merge(i,i+s-1,i+s/2-1); i+=s; } //处理末尾残余部分 merge(i,n-1,i+s/2-1); s*=2; } //最后再从头到尾处理一遍 merge(0,n-1,s/2-1); }
自然归并排序
通常,子串在排序之前就已经有序,我们需要记录下已经有序的子串的数量以及其各自的头尾位置,在排序时无需再对这些子串进行排序,直接合并即可。
实现记录有序子串函数是 pass() 函数,其中的 rec 用于记录有序子串的头尾位置,pass 函数返回有序串的个数。
#include<iostream> #include<algorithm> #include<vector> #include<cstdio> #include<cstring> std::vector<int> arr; //排序数组 std::vector<int> c; //记录排序子串的索引 int len; //当前串的长度 //扫描记录每个自然串的起始下标与自然串总个数+1 int pass () { int max = arr[0]; int num = 0; c[num++] = 0; for (int i = 1; i < len; i++){ if (max <= arr[i]) { max = arr[i]; } else { c[num++] = i; max = arr[i]; } } c[num++] = len; //让c多个尾巴且num+1,方便mergeSort //for (int i = 0; i <= num; i++) printf("%d ",c[i]); //printf("\n"); return num; } void merge (const int& s, const int& end1, const int& end2) { //s,end1, end2 分别对应第一个串的开头、第一个串的 //结尾、第二个串的结尾 int tmpArr[len]; int s1 = s, s2 = end1 + 1; //printf ("s1 = %d, s2 = %d\n", s1, s2); for (int i = s; i <= end2; i++) { if (s1 > end1) tmpArr[i] = arr[s2++]; else if (s2 > end2) tmpArr[i] = arr[s1++]; else if (arr[s1] < arr[s2]) tmpArr[i] = arr[s1++]; else tmpArr[i] = arr[s2++]; } for (int i = s; i <= end2; i++) arr[i] = tmpArr[i]; } void mergeSort () { int cnt = pass(); //for (int i = 0; i < 2 + cnt; i++) printf("%d", c[i]); //printf("cnt = %d\n", cnt); while (cnt != 2) { for (int i = 0; i < cnt; i = i + 2) { merge(c[i], c[i+1]-1, c[i+2]-1); } cnt = pass(); printf("cnt = %d\n", cnt); } } int main (void) { while(scanf("%d", &len) != EOF && len != 0) { arr.clear(); arr.resize(len); c.clear(); c.resize(len+1); for (int i = 0; i < len; i++) { scanf("%d", &arr[i]); } mergeSort(); for (int i = 0; i < len; i++) printf("%d ", arr[i]); printf("\n"); } return 0; }
python版本:
class Merge(): def __init__(self): self.tmp_arr = [0,0,0,0,0,0,0,0,0,0,0,0] def merge(self, arr, start, mid, end): for i in range(start, end+1): self.tmp_arr[i] = arr[i] j = mid + 1 k = start for i in range(start, end+1): if k > mid: arr[i] = self.tmp_arr[j] j += 1 elif j > end: arr[i] = self.tmp_arr[k] k += 1 elif self.tmp_arr[j] < self.tmp_arr[k]: arr[i] = self.tmp_arr[j] j += 1 else: arr[i] = self.tmp_arr[k] k += 1 def sort(self, arr, start, end): if end <= start: return mid = start + (end - start) / 2 self.sort(arr, start, mid) self.sort(arr, mid + 1, end) self.merge(arr, start, mid ,end) mobj = Merge() arr = [5,3,4,7,1,9,0,4,2,6,8] mobj.sort(arr, 0, len(arr)-1) for i in range(len(arr)): print arr[i]
参考
————全心全意投入,拒绝画地为牢