C++ 模板实现败者树,进行多路归并
项目需要实现一个败者树,今天研究了一下,附上实现代码。
几点说明:
1. 败者树思想及实现参考这里:http://www.cnblogs.com/benjamin-t/p/3325401.html
2. 多路归并中“多路”的容器使用的是 C 语言风格的“数组指针 + 数组长度”(即 `const ContainerType* ways, size_t num`),而没有用 STL 中的容器,这是因为项目需要如此,日后再改成 STL 容器;
3. _losers 存储下标,用的是 int 类型,还需要修改。程序中其他下标类型都是 size_t,但是这个 _losers 存的下标需要使用 -1 表示无效。
4. Foo 还不能用在 std::copy上,待修正;
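5. 使用了 FooContainer<Foo> 类型以后,输出的时候不能直接输出,必须先定义一个变量再输出。原因是:如果 operator[] 按值返回,data[i][j] 就是一个临时对象,而临时对象不能绑定到非 const 引用形参(Foo&)上;把 operator<< 的形参声明为 const Foo& 即可直接输出。不能直接输出的写法与绕过的写法如下: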
//std::cout << data[i][j] << ", ";
Foo foo = data[i][j];
std::cout << foo << ", ";
6. 把 const 变量作为 non-type template parameter 时,必须把该 const 变量定义在全局,并且加 extern,原因见这里:http://stackoverflow.com/questions/9183485/const-variable-as-non-type-template-parameter-variable-cannot-appear-in-a-const
7. 代码放在Github上:https://github.com/qinpeixi/code-pieces/blob/master/loser_tree.cpp
8. 一个专业级的实现在这里:https://github.com/MITRECND/snugglefish/blob/master/include/loserTree.hpp
代码如下:
#include <algorithm>
#include <cassert>
#include <climits>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <iterator>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Small value type used to exercise LoserTree with a non-builtin element.
class Foo {
public:
    // Value-initialises the payload so a default-constructed Foo is never
    // read in an indeterminate state.
    Foo() : _v(0) {}
    explicit Foo(int v) : _v(v) {}

    int value() const { return _v; }

    // const-qualified so it can be called on const objects and temporaries.
    bool operator==(const Foo& other) const { return _v == other._v; }

private:
    int _v;
};

// Takes the Foo by const reference so temporaries (e.g. the result of a
// by-value operator[]) and std::ostream_iterator<Foo> can be streamed.
std::ostream& operator<<(std::ostream& os, const Foo& foo)
{
    return os << foo.value();
}

// Declared extern first so the object has external linkage, which pre-C++17
// compilers require for a reference non-type template argument.
extern const Foo FOO_MAX;
const Foo FOO_MAX(INT_MAX);

namespace std {
// Orders Foo by its integer payload. std::binary_function is not used as a
// base class because it was removed from the standard library in C++17.
template<>
struct less<Foo> {
    bool operator()(const Foo& x, const Foo& y) const
    {
        return x.value() < y.value();
    }
};
} // namespace std

// Thin wrapper around std::vector exposing only what LoserTree needs from a
// "way": operator[], size() and push_back().
template<class ValueType>
class FooContainer {
public:
    // Returns whatever the underlying vector's const operator[] returns, so
    // no copy is made for ordinary element types.
    typename std::vector<ValueType>::const_reference
    operator[](std::size_t idx) const { return _container[idx]; }

    std::size_t size() const { return _container.size(); }

    void push_back(const ValueType& value) { _container.push_back(value); }

private:
    std::vector<ValueType> _container;
};

// K-way merge over `num` already-sorted input sequences ("ways") using a
// loser tree (tournament tree).
//
// Template parameters:
//   ValueType          - element type; must be copy-constructible and
//                        copy-assignable.
//   ContainerType      - type of one way; must provide size() and a const
//                        operator[](size_t).
//   forever_lose_value - placeholder stored in the slot of an exhausted way.
//                        Exhaustion itself is tracked by position, so input
//                        elements that happen to equal this value are still
//                        merged correctly.
//   Compare            - strict weak ordering; every way must already be
//                        sorted according to it.
//
// The tree does not own the ways: the array passed to the constructor must
// outlive the LoserTree and must not be modified while merging.
template<
    class ValueType,
    class ContainerType,
    const ValueType& forever_lose_value,
    class Compare = std::less<ValueType> >
class LoserTree {
public:
    // ways: pointer to the first of `num` input sequences.
    // num:  number of input sequences.
    // Throws std::invalid_argument if `ways` is null or `num` is zero; the
    // check runs before any storage is allocated.
    LoserTree(const ContainerType* ways, std::size_t num)
        : _num(checked_count(ways, num)),
          _ways(ways),
          _indexes(num),
          _data(num, forever_lose_value),
          _losers(num, no_way())
    {
        // Feed the ways into the tree one at a time, last to first. Each
        // call either parks the way in the first empty node on its path or
        // plays the matches up to the root; after the final call _losers[0]
        // holds the overall winner.
        for (std::size_t way = _num; way-- > 0; ) {
            load_head(way);
            adjust(way);
        }
    }

    // Copies the next element of the merged sequence into `v`.
    // Returns false, leaving `v` untouched, once every way is exhausted.
    bool extract_one(ValueType& v)
    {
        const std::size_t way = _losers[0];
        // An exhausted way only reaches the top when every way is exhausted.
        if (exhausted(way))
            return false;

        v = _data[way];
        ++_indexes[way];
        load_head(way);
        adjust(way);
        return true;
    }

private:
    std::size_t _num;                  // number of ways
    const ContainerType* _ways;        // non-owning pointer to the ways
    std::vector<std::size_t> _indexes; // next unread position in each way
    std::vector<ValueType> _data;      // current head element of each way
    // _losers[0] is the overall winner; _losers[i] for i >= 1 is the way
    // that lost the match played at internal node i.
    std::vector<std::size_t> _losers;

    // Marker for a tree node that has not received a contestant yet.
    static std::size_t no_way() { return static_cast<std::size_t>(-1); }

    // Validates the constructor arguments and returns `num` unchanged.
    static std::size_t checked_count(const ContainerType* ways, std::size_t num)
    {
        if (ways == nullptr || num == 0)
            throw std::invalid_argument("invalid ways or number of ways");
        return num;
    }

    // True when every element of `way` has been consumed.
    bool exhausted(std::size_t way) const
    {
        return _indexes[way] >= _ways[way].size();
    }

    // Refreshes _data[way] with the way's current head element, or with the
    // placeholder when the way has run out.
    void load_head(std::size_t way)
    {
        if (exhausted(way))
            _data[way] = forever_lose_value;
        else
            _data[way] = _ways[way][_indexes[way]];
    }

    // True when `challenger` must stay at the node as the loser. An
    // exhausted way loses to everything; otherwise the challenger loses
    // unless it compares strictly before the incumbent.
    bool loses(std::size_t challenger, std::size_t incumbent) const
    {
        if (exhausted(challenger))
            return true;
        if (exhausted(incumbent))
            return false;
        return !Compare()(_data[challenger], _data[incumbent]);
    }

    // Replays the matches on the path from leaf `way` up to the root. The
    // parent of leaf i is node (i + _num) / 2 and the parent of node j is
    // j / 2.
    void adjust(std::size_t way)
    {
        std::size_t winner = way;
        for (std::size_t node = (way + _num) / 2;
             node != 0 && winner != no_way();
             node /= 2) {
            std::size_t& stored = _losers[node];
            // An empty node takes the contestant; the "empty" marker then
            // becomes the winner and ends the walk (only during the build).
            if (stored == no_way() || loses(winner, stored))
                std::swap(winner, stored);
        }
        _losers[0] = winner;
    }
};

/*
 * Reads one way per line from standard input, echoes what was read, and
 * returns it.
 *
 * input format:
 * 1 10 100 1000
 * 2 20 200 2000
 * 3 30 300
 * 4 40 400 4000 40000
 */
std::vector<std::vector<int> > get_input()
{
    std::vector<std::vector<int> > data;
    std::string line;
    while (std::getline(std::cin, line)) {
        std::vector<int> tmp_data;
        std::istringstream iss(line);
        std::copy(std::istream_iterator<int>(iss),
                  std::istream_iterator<int>(),
                  std::back_inserter(tmp_data));
        data.push_back(tmp_data);
    }

    for (std::size_t i = 0; i < data.size(); ++i) {
        std::copy(data[i].begin(), data[i].end(),
                  std::ostream_iterator<int>(std::cout, ", "));
        std::cout << std::endl;
    }

    return data;
}

// Distributes the values 0..9 in increasing order over 20 ways chosen with
// rand(), so each way is sorted and the merged result is 0, 1, ..., 9.
template<class ValueType, class ContainerType>
std::vector<ContainerType> generate_data()
{
    const int way_num = 20;
    std::vector<ContainerType> data(way_num);

    for (int num = 0; num < 10/*100000*/; ++num) {
        data[std::rand() % way_num].push_back(ValueType(num));
    }

    return data;
}

// Merges Foo values held in FooContainer ways and prints the result.
void test_foo()
{
    std::vector<FooContainer<Foo> > data =
        generate_data<Foo, FooContainer<Foo> >();

    LoserTree<Foo, FooContainer<Foo>, FOO_MAX> lt(data.data(), data.size());

    Foo v;
    int expected = 0;
    while (lt.extract_one(v)) {
        assert(v == Foo(expected));
        ++expected;
        std::cout << v << ", ";
    }
    (void)expected; // only read by the assert; silences NDEBUG builds
    std::cout << std::endl;
}

// Declared extern first for the same linkage reason as FOO_MAX.
extern const int int_max;
const int int_max = INT_MAX;

// Merges plain ints held in std::vector ways and prints the result.
void test()
{
    std::vector<std::vector<int> > data =
        generate_data<int, std::vector<int> >();

    LoserTree<int, std::vector<int>, int_max> lt(data.data(), data.size());

    int v;
    int expected = 0;
    while (lt.extract_one(v)) {
        // The increment is kept outside assert() so that the expected value
        // still advances when NDEBUG disables assertions.
        assert(v == expected);
        ++expected;
        std::cout << v << ", ";
    }
    (void)expected; // only read by the assert; silences NDEBUG builds
    std::cout << std::endl;
}

int main()
{
    try {
        // Constructing with a null `ways` pointer throws
        // std::invalid_argument:
        //LoserTree<int, std::vector<int>, int_max> lt(NULL, 3);
        test_foo();
        test();
    } catch (const std::exception& exc) {
        std::cerr << exc.what() << std::endl;
    }

    return 0;
}