C++ 模板实现败者树,进行多路归并

项目需要实现一个败者树,今天研究了一下,附上实现代码。

几点说明:

1. 败者树思想及实现参考这里:http://www.cnblogs.com/benjamin-t/p/3325401.html

2. 多路归并中的“多路”的容器使用的是C语言数组 + 数组长度的实现(即

const ContainerType* ways, size_t num

),而没有用STL中的容器,这是因为项目需要如此,日后再改成STL容器;

3. _losers 存储下标,用的是 int 类型,还需要修改。程序中其他下标类型都是 size_t,但是这个 _losers 存的下标需要使用 -1 表示无效。

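4. Foo 还不能用在 std::copy 上(配合 std::ostream_iterator<Foo> 输出时无法编译):std::ostream_iterator 以 const 引用传递元素,因此 operator<< 的第二个参数必须声明为 const Foo& 才能使用,待修正;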

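5. 使用了 FooContainer<Foo> 类型以后,输出的时候不能直接输出,必须先定义一个变量再输出。原因是 FooContainer::operator[] 按值返回,data[i][j] 是一个临时对象,而临时对象不能绑定到非 const 引用参数(Foo&);只有当 operator<< 的参数声明为 const Foo& 时才能直接输出: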

            //std::cout << data[i][j] << ", ";  
            Foo foo = data[i][j];
            std::cout << foo << ", ";

6. 把 const 变量作为 non-type template parameter 时,必须把该 const 变量定义在全局,并且加 extern,原因见这里:http://stackoverflow.com/questions/9183485/const-variable-as-non-type-template-parameter-variable-cannot-appear-in-a-const

7. 代码放在Github上:https://github.com/qinpeixi/code-pieces/blob/master/loser_tree.cpp

8. 一个专业级的实现在这里:https://github.com/MITRECND/snugglefish/blob/master/include/loserTree.hpp

 

 

代码如下:

#include <algorithm>
#include <cassert>
#include <climits>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <iterator>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

/**
 * Minimal user-defined value type wrapping a single int.
 *
 * The wrapped value is read back through value(). Objects can be copied,
 * assigned and compared for equality.
 */
class Foo {
public:
    /// Default-constructs with value 0 so the member is never read uninitialized.
    Foo() : _v(0) {}
    /// Wraps `v`; explicit so that an int never converts to Foo silently.
    explicit Foo(int v): _v(v) {}
    Foo(const Foo& foo) : _v(foo._v) {}
    /// Returns the wrapped integer.
    int value() const { return _v; }
    Foo& operator=(const Foo& foo) { _v = foo._v; return *this; }
    /// Equality on the wrapped value; const-qualified so const objects and
    /// temporaries can appear on the left-hand side.
    bool operator==(const Foo& foo) const { return _v == foo._v; }

private:
    int _v;
};
std::ostream& operator<<(std::ostream& os, Foo& foo) {
    return os << foo.value();
}

// Sentinel Foo holding INT_MAX; test_foo() passes it to LoserTree as the
// forever_lose_value template argument.
// A const object at namespace scope has internal linkage by default; `extern`
// gives it external linkage, which C++03 requires for an object bound to a
// reference non-type template parameter.
extern const Foo FOO_MAX(INT_MAX);

namespace  std {
/**
 * Specialization of std::less for Foo: orders two Foo objects by value().
 *
 * The member typedefs are spelled out directly instead of being inherited
 * from std::binary_function, which was deprecated in C++11 and removed in
 * C++17.
 */
template<>
struct less<Foo>
{
    typedef Foo first_argument_type;
    typedef Foo second_argument_type;
    typedef bool result_type;

    /// Returns true when x.value() is strictly smaller than y.value().
    bool operator() (const Foo& x, const Foo& y) const {
        return x.value() < y.value();
    }
};
} // namespace std

/**
 * Small sequence wrapper around std::vector that exposes only three
 * operations: push_back(), size() and read-only indexing.
 *
 * Note that operator[] hands out a copy of the element, not a reference.
 */
template<class ValueType>
class FooContainer
{
    typedef std::vector<ValueType> Storage;
    Storage _items;

public:
    /// Appends a copy of `value` at the end.
    void push_back(const ValueType& value) { _items.push_back(value); }

    /// Number of elements currently stored.
    size_t size() const { return _items.size(); }

    /// Returns a copy of the element at position `idx` (no bounds check).
    ValueType operator[](size_t idx) const {
        const Storage& items = _items;
        return items[idx];
    }
};

/**
 * K-way merge over an array of `num` ways, implemented as a loser tree.
 *
 * Template parameters:
 *  - ValueType          element type; must be copyable and assignable.
 *  - ContainerType      type of one way; must provide size() and
 *                       operator[](size_t) returning something assignable to
 *                       ValueType.
 *  - forever_lose_value placeholder stored in the head slot of a way that has
 *                       no elements left.
 *  - Compare            strict ordering; Compare()(a, b) == true means `a`
 *                       is extracted before `b`.
 *
 * The tree does not own `ways`; the array must stay valid and unmodified for
 * the lifetime of the tree. Each way is presumably already ordered by
 * Compare (the tree only ever looks at the current head of every way).
 *
 * Exhaustion of a way is tracked through its read index rather than by
 * comparing values against forever_lose_value, so real elements that happen
 * to equal the placeholder are still extracted, and ValueType does not need
 * an operator==.
 */
template< class ValueType,
          class ContainerType,
          const ValueType& forever_lose_value,
          class Compare = std::less<ValueType> >
class LoserTree
{
public:
    /**
     * Builds the tree over ways[0 .. num-1].
     * @throws std::invalid_argument when `ways` is null or `num` is 0.
     */
    LoserTree(const ContainerType* ways, size_t num) :
        _num(num), _ways(ways)
    {
        // Validate before allocating anything; the vectors below release
        // their storage automatically if a later step throws.
        if (ways == NULL || num == 0) {
            throw std::invalid_argument("invalid ways or number of ways");
        }
        _indexes.assign(_num, 0);
        _data.assign(_num, forever_lose_value);
        _losers.assign(_num, no_way());
        // Insert the ways from last to first, as the original layout expects.
        for (size_t way_idx = _num; way_idx-- > 0; ) {
            load_head(way_idx);
            adjust(way_idx);
        }
    }

    /**
     * Extracts the next element of the merged sequence into `v`.
     * @return true when an element was written to `v`; false when every way
     *         is exhausted (in which case `v` is left untouched).
     */
    bool extract_one(ValueType& v) {
        const size_t way_idx = _losers[0];
        if (exhausted(way_idx))
            return false;
        v = _data[way_idx];
        ++_indexes[way_idx];
        load_head(way_idx);
        adjust(way_idx);
        return true;
    }

private:
    size_t _num;                    // number of ways
    const ContainerType* _ways;     // non-owning pointer to the ways
    std::vector<size_t> _indexes;   // per way: position of its current head
    std::vector<ValueType> _data;   // per way: copy of its current head
    // _losers[0] is the overall winner; _losers[1 .. _num-1] hold, per
    // internal node, the way that lost the match played at that node.
    std::vector<size_t> _losers;

    /// Marker for "no way stored in this node yet" (used while building).
    static size_t no_way() { return static_cast<size_t>(-1); }

    /// True when every element of way `way_idx` has been consumed.
    bool exhausted(size_t way_idx) const {
        return _indexes[way_idx] >= _ways[way_idx].size();
    }

    /// Refreshes _data[way_idx] from the way's current position.
    void load_head(size_t way_idx) {
        if (exhausted(way_idx)) {
            _data[way_idx] = forever_lose_value;
        } else {
            _data[way_idx] = _ways[way_idx][_indexes[way_idx]];
        }
    }

    /// True when way `challenger` loses the match against way `resident`.
    /// An exhausted way always loses; on a tie the challenger loses.
    bool loses_to(size_t challenger, size_t resident) const {
        if (exhausted(challenger))
            return true;
        if (exhausted(resident))
            return false;
        return !Compare()(_data[challenger], _data[resident]);
    }

    /// Replays the matches on the path from leaf `winner_idx` to the root.
    /// At each node the loser stays and the winner moves up. While the tree
    /// is being built, an empty node absorbs the contender and the walk stops.
    void adjust(size_t winner_idx) {
        using std::swap;
        // Leaf `winner_idx` sits at tree position winner_idx + _num.
        size_t node = (winner_idx + _num) / 2;
        while (node != 0 && winner_idx != no_way()) {
            if (_losers[node] == no_way() ||
                    loses_to(winner_idx, _losers[node]))
                swap(winner_idx, _losers[node]);
            node /= 2;
        }
        _losers[0] = winner_idx;
    }
};

/*
 * Reads standard input line by line, parses each line as whitespace-separated
 * integers (parsing of a line stops at the first token that is not an
 * integer), echoes every parsed line to standard output and returns the
 * parsed lines.
 *
 * Example input:
 * 1 10 100 1000
 * 2 20 200 2000
 * 3 30 300
 * 4 40 400 4000 40000
 */
std::vector<std::vector<int> > get_input()
{
    typedef std::vector<int> Row;
    std::vector<Row> rows;

    for (std::string text; std::getline(std::cin, text); ) {
        std::istringstream fields(text);
        Row row;
        for (int number; fields >> number; ) {
            row.push_back(number);
        }
        rows.push_back(row);
    }

    // Echo what was parsed: one output line per row, every value followed
    // by ", ".
    for (std::vector<Row>::const_iterator row = rows.begin();
         row != rows.end(); ++row) {
        for (Row::const_iterator it = row->begin(); it != row->end(); ++it) {
            std::cout << *it << ", ";
        }
        std::cout << std::endl;
    }

    return rows;
}

/**
 * Builds `way_num` containers and distributes the values 0 .. value_count-1
 * over them, picking the target container with rand() for every value.
 * Because the values are handed out in increasing order, every container
 * ends up in ascending order.
 *
 * @tparam ValueType     element type; must be constructible from int.
 * @tparam ContainerType container type; must be default-constructible and
 *                       provide push_back(const ValueType&).
 * @param way_num     number of containers to create (default 20).
 * @param value_count number of values to distribute (default 10); a value
 *                    of 0 or less leaves every container empty.
 * @return the filled containers.
 * @throws std::invalid_argument when way_num is not positive.
 */
template<class ValueType, class ContainerType>
std::vector<ContainerType> generate_data(int way_num = 20, int value_count = 10)
{
    // rand() % 0 is undefined and a negative count would wrap to a huge size.
    if (way_num <= 0) {
        throw std::invalid_argument("way_num must be positive");
    }
    std::vector<ContainerType> data(static_cast<size_t>(way_num));
    for (int num = 0; num < value_count; ++num) {
        data[rand() % way_num].push_back(ValueType(num));
    }

    return data;
}

/**
 * Merges generated Foo data through LoserTree, prints every extracted value
 * followed by ", ", and checks that the values come out as 0, 1, 2, ...
 * (generate_data hands out consecutive integers starting at 0).
 */
void test_foo()
{
    std::vector<FooContainer<Foo> > data = generate_data<Foo, FooContainer<Foo> >();

    LoserTree<Foo, FooContainer<Foo>, FOO_MAX> lt(data.data(), data.size());
    Foo v;
    int expected = 0;
    while (lt.extract_one(v)) {
        // Same order check as test(); the counter is advanced outside of
        // assert() so it still changes when NDEBUG is defined.
        assert(v.value() == expected);
        ++expected;
        std::cout << v.value() << ", ";
    }
    std::cout << std::endl;
}

// Sentinel for the int instantiation of LoserTree used in test().
// `extern` gives the const object external linkage so that it can be used
// as a reference non-type template argument.
extern const int int_max = INT_MAX;

/**
 * Merges generated int data through LoserTree, prints every extracted value
 * followed by ", ", and checks that the values come out as 0, 1, 2, ...
 */
void test()
{
    std::vector<std::vector<int> > data = generate_data<int ,std::vector<int> >();
    LoserTree<int, std::vector<int>, int_max> lt(data.data(), data.size());
    int v = 0;
    int correct_res(0);
    while (lt.extract_one(v)) {
        // Keep the increment outside of assert(): the macro argument is not
        // evaluated when NDEBUG is defined.
        assert(v == correct_res);
        ++correct_res;
        std::cout << v << ", ";
    }
    std::cout << std::endl;
}

/**
 * Runs both merge tests.
 * @return EXIT_SUCCESS when both tests finished, EXIT_FAILURE when a
 *         std::exception escaped from them (its message goes to stderr).
 */
int main()
{
    try {
        // Manual check of the invalid-argument path; uncomment to try it:
        //LoserTree<int, std::vector<int>, int_max> lt(NULL, 3);
        test_foo();
        test();
    } catch (const std::exception& exc){
        std::cerr << exc.what() << std::endl;
        // Report the failure through the exit status instead of returning 0.
        return EXIT_FAILURE;
    }

    return EXIT_SUCCESS;
}

 

posted on 2014-10-17 21:56  夜花烛  阅读(1381)  评论(0)  编辑  收藏  举报