并行但并不十分高效的 for_each 实现

/// Scope guard that joins every joinable thread in a referenced vector when the
/// guard is destroyed, so no std::thread in that vector is destroyed while still
/// joinable (which would call std::terminate).
///
/// The vector is held by reference: it must outlive the guard, and threads may
/// be stored into it after the guard has been constructed.
class ThreadsJoiner
{
    std::vector<std::thread>& threads;
public:
    explicit ThreadsJoiner(std::vector<std::thread>& threads_):
        threads(threads_)
    {}

    // A joiner is a guard bound to exactly one vector; copying it has no
    // meaningful use, so it is disallowed.
    ThreadsJoiner(ThreadsJoiner const&)            = delete;
    ThreadsJoiner& operator=(ThreadsJoiner const&) = delete;

    // The destructor is implicitly noexcept: if join() throws
    // std::system_error, the program terminates.
    ~ThreadsJoiner()
    {
        for (auto& thread : threads)
        {
            // Default-constructed or already-joined threads are skipped.
            if (thread.joinable())
            {
                thread.join();
            }
        }
    }
};

/// Applies `f` to every element of [first, last) using a fixed set of threads.
///
/// The range is cut into `threadNum` contiguous blocks of `length / threadNum`
/// elements. `threadNum - 1` blocks go to worker threads; the calling thread
/// processes the final block, which also absorbs the division remainder.
/// `threadNum` is the smaller of ceil(length / 25) and
/// std::thread::hardware_concurrency() (2 when that reports 0).
///
/// Every task gets its own copy of `f`. An exception thrown by `f` on a worker
/// is rethrown to the caller by future::get(). If the calling thread's own
/// block throws, the workers are still joined by ThreadsJoiner before the
/// exception leaves this function.
template<typename Iterator, typename Func>
void parallelForEach(Iterator first, Iterator last, Func f)
{
    // std::distance returns a signed difference_type; test it before
    // converting to size_t. An empty range must return here, otherwise
    // threadNum would be 0 (division by zero below) and threadNum - 1 would
    // wrap around to SIZE_MAX when sizing the vectors.
    auto const distance = std::distance(first, last);
    if (distance <= 0)
    {
        return;
    }
    size_t const length = static_cast<size_t>(distance);

    size_t constexpr minDataPerThread  = 25;
    size_t const     hardwareThreadNum = std::thread::hardware_concurrency();
    size_t const     maxThreadNum      =
        (length + minDataPerThread - 1) / minDataPerThread;
    // length >= 1 here, so maxThreadNum >= 1 and therefore threadNum >= 1.
    size_t const     threadNum         = std::min(maxThreadNum,
                                                  hardwareThreadNum == 0
                                                  ? size_t{2}
                                                  : hardwareThreadNum);
    size_t const blockSize = length / threadNum;

    std::vector<std::future<void>> futures(threadNum - 1);
    std::vector<std::thread>       threads(threadNum - 1);
    // Declared after `threads`, so it is destroyed before them: every launched
    // worker is joined on all exit paths, including exceptions.
    ThreadsJoiner joiner(threads);

    auto blockBegin = first;
    for (size_t i = 0; i < threadNum - 1; ++i)
    {
        auto blockEnd = blockBegin;
        std::advance(blockEnd, blockSize);
        std::packaged_task<void(void)> task([=]{std::for_each(blockBegin,
                                                              blockEnd, f);});
        futures[i] = task.get_future();
        threads[i] = std::thread(std::move(task));
        blockBegin = blockEnd;
    }

    // The calling thread handles the last (possibly larger) block itself.
    std::for_each(blockBegin, last, f);
    for (auto& future : futures)
    {
        // Waits for each worker and rethrows any exception it stored.
        future.get();
    }
}

/// Recursive divide-and-conquer variant: applies `f` to every element of
/// [first, last).
///
/// Ranges of at most 25 elements are processed sequentially on the calling
/// thread. Larger ranges are split in half: the first half is handed to
/// std::async and the second half is processed recursively on the calling
/// thread, which then waits for the first half.
///
/// Works with any forward iterator, because std::next is used instead of `+`.
/// For non-random-access iterators, std::distance/std::next make each split a
/// linear walk.
/// NOTE(review): the default std::async launch policy lets the implementation
/// choose whether each half runs on a new thread or is deferred until get().
template<typename Iterator, typename Func>
void parallelForEach(Iterator first, Iterator last, Func f)
{
    using Diff = typename std::iterator_traits<Iterator>::difference_type;

    // Keep the signed difference_type; an empty (or invalid) range does nothing.
    Diff const length = std::distance(first, last);
    if (length <= 0)
    {
        return;
    }

    // Ranges no longer than this are processed sequentially.
    constexpr Diff chunkSize = 25;
    if (length <= chunkSize)
    {
        std::for_each(first, last, f);
        return;
    }

    Iterator const midIt = std::next(first, length / 2);
    auto firstHalf = std::async(&parallelForEach<Iterator, Func>,
                                first, midIt, f);
    parallelForEach(midIt, last, f);
    // Waits for the first half and rethrows any exception it raised.
    firstHalf.get();
}

 

posted @ 2015-11-03 17:24  wu_overflow  阅读(373)  评论(0编辑  收藏  举报