TopK
面试时遇到了一道TopK题,原理很简单,但是以前很少动手写过。面试时写得有点慢,没有达到行云流水的地步,于是回来再写一遍练练手。其中,堆排序部分采用简明排序代码。完整的TopK代码如下:
#include <algorithm>
#include <cstddef>
#include <cstring>
#include <iostream>
#include <iterator>
using namespace std;
// Sifts the element at index `top` down a binary min-heap until neither child
// is smaller than it.
//
// Indices are 1-based: the children of node i are 2*i and 2*i + 1, and only
// indices strictly below `size` are treated as part of the heap. The function
// is "unguarded" in that it performs no validation of `data` or `top`.
//
//   data  base pointer used for indexing (slot 0 is never read for top >= 1)
//   size  exclusive upper bound on valid indices
//   top   index of the element to sift down
template<typename T>
void unguarded_heapify(T *data, size_t size, size_t top)
{
    using std::swap;  // keep the call below unqualified so ADL still applies

    for (;;)
    {
        const size_t left = top * 2;
        const size_t right = left + 1;

        // Pick the smallest of the node and whichever children exist.
        size_t smallest = top;
        if (left < size && data[left] < data[smallest])
        {
            smallest = left;
        }
        if (right < size && data[right] < data[smallest])
        {
            smallest = right;
        }

        // The node is already no larger than its children: heap order holds.
        if (smallest == top)
        {
            break;
        }

        swap(data[top], data[smallest]);
        top = smallest;
    }
}
// Rearranges the elements of [begin, end) in place into a binary min-heap,
// so that *begin ends up being the smallest element.
//
// Does nothing when either pointer is null or when the range holds fewer
// than two elements.
template<typename T>
void make_min_heap(T *begin, T *end)
{
    if (!begin || !end)
    {
        return;
    }

    const size_t count = end - begin;
    if (count < 2)
    {
        return;
    }

    // unguarded_heapify indexes the heap from 1, so hand it a base pointer
    // shifted one slot to the left and an exclusive index bound of count + 1.
    T *const one_based = begin - 1;
    const size_t bound = count + 1;

    // Sift down every node that has at least one child, last parent first.
    size_t parent = count / 2;
    while (parent != 0)
    {
        unguarded_heapify(one_based, bound, parent);
        --parent;
    }
}
// Collects the *k largest values of [begin, end) into `buffer`.
//
//   begin, end  input range; `end` must be reachable from `begin`
//   buffer      output storage with room for at least *k ints; it must not
//               overlap the input range
//   k           in: number of values requested
//               out: number of values actually written to `buffer`
//
// The first *k slots of `buffer` are zeroed before anything is written.
// When the input holds no more than *k values, they are copied in their
// original order and *k is lowered to the input length. Otherwise `buffer`
// holds the *k largest values arranged as a min-heap (buffer[0] is the
// smallest of them); the result is not sorted.
//
// The call is a no-op when any pointer is null, the input is empty, or
// *k is zero.
void topk(const int *begin, const int *end, int *buffer, size_t *k)
{
    if (begin == NULL || end == NULL || buffer == NULL || k == NULL)
    {
        return;
    }
    if (begin == end || *k == 0)
    {
        return;
    }

    const size_t requested = *k;
    const size_t available = static_cast<size_t>(end - begin);
    // Clamp before doing any pointer arithmetic so that no pointer beyond
    // `end` is ever formed, even when more values are requested than exist.
    const size_t kept = requested < available ? requested : available;

    std::fill_n(buffer, requested, 0);
    std::copy(begin, begin + kept, buffer);

    if (kept == available)
    {
        // Every input value fits into the buffer; nothing to select.
        *k = available;
        return;
    }

    // Keep the best `kept` values seen so far in a min-heap: its root is the
    // weakest candidate, so each remaining value only needs to beat the root.
    make_min_heap(buffer, buffer + kept);
    for (const int *p = begin + kept; p != end; ++p)
    {
        if (*p > *buffer)
        {
            *buffer = *p;
            // The helper indexes the heap from 1, hence the shifted base
            // pointer and the exclusive bound of kept + 1.
            unguarded_heapify(buffer - 1, kept + 1, 1);
        }
    }
}
// Demo driver: runs topk over a fixed sample for several values of k and
// prints each result on its own line, every value followed by one space.
int main(int argc, char **argv)
{
    (void)argc;
    (void)argv;

    const int data[] = {4, 5, 1, 3, 5, 6, 7, 2};
    const size_t data_len = sizeof(data) / sizeof(data[0]);

    // Requested k values, including k larger than the input and k == 0.
    const size_t requests[] = {3, 10, 1, 8, 0};
    const size_t request_count = sizeof(requests) / sizeof(requests[0]);

    // topk zeroes the first k slots of its output, so the buffer must be at
    // least as large as the biggest k requested above. A stack array needs
    // no manual cleanup.
    int result[10] = {0};

    for (size_t i = 0; i < request_count; ++i)
    {
        size_t k = requests[i];
        topk(data, data + data_len, result, &k);

        // topk may have lowered k to the number of values actually produced.
        for (size_t j = 0; j < k; ++j)
        {
            std::cout << result[j] << " ";
        }
        std::cout << '\n';
    }
    return 0;
}