C++索引从0开始的堆排序算法实现
更新2019年11月4日 04:26:35
睡不着觉起来寻思寻思干点啥吧,好像好久没写堆排了。于是写了个索引从0开始的堆排,这次把建堆函数略了并在heapsort主函数里,索引从0开始到size-1结束,长度size。
这个堆排和索引从1开始的堆排区别就是对于节点i,两个子节点分别为2i+1和2i+2。另外建堆时从索引size/2-1开始倒序维护大顶堆。下面证明下这个起始索引的节点一定对应着二叉树的最后的一个或两个叶子节点。
1.siz是偶数,那么最后一个内部节点只有左子树。siz/2-1乘2等于siz-2,siz-2加1为最后一个节点的索引,是siz-1末尾没问题。
2.siz是奇数,最后一个内部节点有左右子树。siz/2-1乘2等于((siz-1)/2-1)乘2等于siz-1-2等于siz-3,那么左子节点siz-3+1=siz-2,右子节点siz-3+2=siz-1是末尾也没问题。
void max_heap(vector<int>& nums, int i, int limit);
void heap_sort(vector<int>& nums)
{
for (int i = nums.size() / 2 - 1; i >= 0; --i) //建最大堆
{
max_heap(nums, i, nums.size());
}
for (int i = 1; i < nums.size(); ++i)
{
int temp = nums[0];
nums[0] = nums[nums.size() - i];
nums[nums.size() - i] = temp;
max_heap(nums, 0, nums.size() - i);
}
}
void max_heap(vector<int>& nums, int i, int limit)
{
int max_index = i;
if (2 * i + 1 < limit and nums[2 * i + 1] > nums[max_index])
{
max_index = 2 * i + 1;
}
if (2 * i + 2 < limit and nums[2 * i + 2] > nums[max_index])
{
max_index = 2 * i + 2;
}
if (i != max_index)
{
int temp = nums[i];
nums[i] = nums[max_index];
nums[max_index] = temp;
max_heap(nums, max_index, limit);
}
}
分割线
最近算法群里有人问堆排序的问题,我一想没想出来,就又看看算法导论堆排那章,跟着敲了一遍,感觉印象算是深刻了一些。
堆排序的几个函数:
1.maxheap(nums[ ],int i)
首先我们数组从1开始编号,为了方便。即i为父节点,nums[i]的左右孩子分别为2i和2i+1,算算对不?这个函数假设i的左右子树都已经为最大堆,只有nums[i]这个根节点可能有问题,下面上代码:
void maxheap(vector<int>& nums,int i,int limit)
//把以ri为根的子树调整为最大堆,假设其左右子树都已是最大堆
{
int largest=i;
if(2*i<=limit&&nums[i]<nums[2*i])
{
largest=2*i;
}
if(2*i+1<=limit&&nums[largest]<nums[2*i+1])
{
largest=2*i+1;
}
if(i!=largest)
{
swap(nums[i],nums[largest]);
maxheap(nums,largest,limit);
}
}
limit是数组最后一个元素的编号,看名字也看得出。函数做的事情就是先找i和他的俩孩子里最大的是哪个,如果最大的是i的其中一个孩子,那么就将其和i的值交换,这样大的就上去了,原来的nums[i]下来了,但是这个新下来的值可能与下面的左右子树构不成最大堆,此时下来的nums[i]和其新左右子树依然符合我们上面假设的条件这个函数假设i的左右子树都已经为最大堆,只有nums[i]这个根节点可能有问题,所以对其递归调用maxheap
2.createmaxheap(nums[ ]):
排序刚开始数组为无序,把它变成最大堆的函数。
代码:
void createmaxheap(vector<int>& nums)
{
int siz=nums.size();
nums.insert(nums.begin(),0);
//排序1~n
for(int i=siz/2;i>1;--i)
{
maxheap(nums,i,siz);
}
}
C的数组下标0开始的,那么我们前面插个0(插几都行),然后对1~n进行堆排序。
siz/2是最后一个非叶节点(就是编号最大的非叶节点),画个图好理解的,这就不提了。然后从这个节点开始依次往1靠,每次都把这个节点作为根节点的树变成最大堆,最后整个1~n就是最大堆了。
3.heapsort(nums[ ]):
这个更简单了:
void heapsort(vector<int>& nums)
{
createmaxheap(nums);
for(int i=nums.size()-1;i>1;--i)
{
swap(nums[1],nums[i]);
maxheap(nums,1,i-1);
}
}
先建最大堆,然后把nums[1]和nums[n]交换,然后对nums[1,n-1]继续maxheap,就是每次调整为最大堆,最大的就是第一个元素,把它和当前堆排序区间的末尾元素交换,这样循环n-1次就排好了(剩一个最小的就不用排了)。
全部代码:
C++:
#include<vector>
#include<time.h>
#include<iostream>
using namespace std;
void maxheap(vector<int>& nums,int i,int limit)
//把以ri为根的子树调整为最大堆,假设其左右子树都已是最大堆
{
int largest=i;
if(2*i<=limit&&nums[i]<nums[2*i])
{
largest=2*i;
}
if(2*i+1<=limit&&nums[largest]<nums[2*i+1])
{
largest=2*i+1;
}
if(i!=largest)
{
swap(nums[i],nums[largest]);
maxheap(nums,largest,limit);
}
}
void createmaxheap(vector<int>& nums)
{
int siz=nums.size();
nums.insert(nums.begin(),0);
//排序1~n
for(int i=siz/2;i>1;--i)
{
maxheap(nums,i,siz);
}
}
void heapsort(vector<int>& nums)
{
createmaxheap(nums);
for(int i=nums.size()-1;i>1;--i)
{
swap(nums[1],nums[i]);
maxheap(nums,1,i-1);
}
}
int main()
{
int l=100;
srand(time(NULL));
vector<int> p(l,0);
for(int i=0;i<l;++i)
{
p[i]=rand();
}
for(int i=0;i<l;++i)
{
cout<<p[i]<<" ";
}
cout<<endl<<endl;
heapsort(p);
for(int i=0;i<l;++i)
{
cout<<p[i]<<" ";
}
getchar();
}
Python:
import random
data_list=[]
for i in range(100):
data_list.append(random.uniform(1,200))
def max_heapify(data_list,i,limit):
temp=i
if 2*i<=limit and data_list[temp]>data_list[2*i]:
temp=2*i
if 2*i+1<=limit and data_list[temp]>data_list[2*i+1]:
temp=2*i+1
if temp!=i:
data_list[temp],data_list[i]=data_list[i],data_list[temp]
max_heapify(data_list,temp,limit)
def create_heap(data_list):
list_len=len(data_list)
data_list.insert(0,0)
for i in range((list_len-1)//2,0,-1):
max_heapify(data_list,i,list_len-1)
def heap_sort(data_list):
create_heap(data_list)
list_len=len(data_list)
for i in range(1,list_len-1):
data_list[1],data_list[list_len-i]=data_list[list_len-i],data_list[1]
max_heapify(data_list,1,list_len-i-1)
heap_sort(data_list)
for i in data_list:
print(i)