C++20协程解糖 - 动手实现协程1 - Future和Promise

std::future和promise在C++20里面没法直接用的唯一原因就是不支持then,虽然MSVC有一个弱智版开线程阻塞实现的future.then,能then了但不保序,而且libstdc++也用不了。folly之类的库有靠谱的实现,但是功能太齐全太复杂,不适合新手学习。因此我们先从弱智版future promise schedular开始,从源头讲解如何实现协程相关设施。

如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢

基本结构

我们要实现的功能很简单:

  1. 单线程模型
  2. promise是入口,future是出口
  3. promise支持set_result, get_future
  4. future支持add_finish_callback,在promise set_result之后按序调用
  5. callback在下次调度时调用而不是立即调用

结构大概是这样

image

状态都存在shared_state里面,future和promise实际上只是个空壳


先搭框架

如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢

首先是shared_state,最直观的,shared_state需要存储最终设置的结果T,以及记录结果有没有设置。这里要求T必须支持默认构造,省事。


template<class T>
class SharedState {
    friend class Future<T>;
    friend class Promise<T>;
public:
    SharedState()
    {}
    SharedState(const SharedState&) = delete;
    SharedState(SharedState&&) = delete;
    SharedState& operator=(const SharedState&) = delete;
    SharedState& operator=(SharedState&&) = delete;

private:
    template<class U>
    void set(U&& v) {
        if (settled) {
            return;
        }
        settled = true;
        value = std::forward<U>(v);
    }

    bool settled = false;
    T value;
};


然后是引用了shared_state的promise和future

template<class T>
class Promise {
public:
    Promise()
        : _state(std::make_shared<SharedState<T>>())
    {}

    Future<T> get_future();

    template<class U>
    void set_result(U&& value) {
        if (_state->settled) {
            throw std::invalid_argument("already set result");
        }
        _state->set(std::forward<U>(value));
    }
private:
    std::shared_ptr<SharedState<T>> _state;
};

template<class T>
class Future {
    friend class Promise<T>;
private:
    Future(std::shared_ptr<SharedState<T>> state)
        : _state(std::move(state))
    {
    }
private:
    std::shared_ptr<SharedState<T>> _state;
};

template<class T>
Future<T> Promise<T>::get_future() {
    return Future<T>(_state);
}


先把调度器的壳写上,以后会用到

class Schedular {
    template<class T>
    friend class SharedState;
public:
    Schedular() = default;
    Schedular(Schedular&&) = delete;
    Schedular(const Schedular&) = delete;
    Schedular& operator=(Schedular&&) = delete;
    Schedular& operator=(const Schedular&) = delete;

    void poll() {
        // TODO
    }
};



再补功能

如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢

首先future要支持add_finish_callback


template<class T>
class Future {
    // ...
    // public
    void add_finish_callback(std::function<void(T&)> callback) {
        _state->add_finish_callback(std::move(callback));
    }
};

为什么要把callback实际加到_state里面去呢,因为之后需要post所有callback给schedular,而schedular要接受各种Future<T>的callback,要做类型擦除太麻烦了,所以索性,把callback存入_state,_state正好是堆对象,把他的类型擦除了,丢给schedular去post,还是因为简单

既然callback实际加给了shared_state,那SharedState也得补充对应的功能


template<class T>
class SharedState {
    // ...
    // private
    void add_finish_callback(std::function<void(T&)> callback) {
        finish_callbacks.push_back(std::move(callback));
        // TODO
    }

    std::vector<std::function<void(T&)>> finish_callbacks;
};


然后,就是在shared_state set的时候,或者shared_state已有结果,但刚刚新增了callback的时候,把shared_state自己发送给Schedular,等待下一帧被调度到时调用callback

为此,SharedState本身要存储一个Schedular的指针,那么Promise就得接受Schedular作为构造函数参数,SharedState还要记录自己是否已经被post给Schedular,不需要重复post


template<class T>
class SharedState {
    // ...
    // public
    // 构造函数增加一个参数
    SharedState(Schedular& schedular)
        : schedular(&schedular)
    {}

    // ...
    // private
    // set增加内容
    template<class U>
    void set(U&& v) {
        if (settled) {
            return;
        }
        settled = true;
        value = std::forward<U>(v);
        post_all_callbacks();
    }

    void add_finish_callback(std::function<void(T&)> callback) {
        finish_callbacks.push_back(std::move(callback));
        post_all_callbacks();
    }

    void post_all_callbacks();

    bool settled = false;
    bool callback_posted = false;
    Schedular* schedular = nullptr;
    T value;
    std::vector<std::function<void(T&)>> finish_callbacks;
};

template<class T>
class Promise {
    // ...
	// public
	// 构造函数增加参数
    Promise(Schedular& schedular)
        : _schedular(&schedular)
        , _state(std::make_shared<SharedState<T>>(*_schedular))
    {}
};

// ...
// 在Schedular定义的后面
template<class T>
void SharedState<T>::post_all_callbacks() {
    if (callback_posted) {
        return;
    }
    callback_posted = true;
    schedular->post_call_state(shared_from_this());
}

可以发现,在post_all_callback时,SharedState把自己shared_from_this()后发送给了schedular,显然这里既要enable_shared_from_this,又要类型擦除,于是

class SharedStateBase : public std::enable_shared_from_this<SharedStateBase> {
    friend class Schedular;
public:
    virtual ~SharedStateBase() = default;
};

template<class T>
class SharedState : public SharedStateBase {
    // ...
};


Schedular里面也得补充存储shared_ptr<SharedStateBase>的东西

class Schedular {
    // ...
    // private
    void post_call_state(std::shared_ptr<SharedStateBase> state) {
        pending_states.push_back(std::move(state));
    }

    std::vector<std::shared_ptr<SharedStateBase>> pending_states;
};


下面,就轮到Schedular的poll函数在每帧调用被post过来的SharedStateBase了

class Schedular {
    // ...
    // public
    void poll() {
        size_t sz = pending_states.size();
        for (size_t i = 0; i != sz; i++) {
            auto state = std::move(pending_states[i]);
            state->invoke_all_callback();
        }
        pending_states.erase(pending_states.begin(), pending_states.begin()+sz);
    }
};


  • 之所以这里使用下标循环,是因为迭代过程中还可能有callback继续往pending_states里面新增元素
  • 之所以不用while !empty()而是预先获取size,是为了避免调度是callback内无限post callback导致无限循环
  • 之所以调用前先move出来,是为了避免调用callback期间callback继续往pending_states里面新增元素导致容器扩容,内容物失效

这里对state调用了invoke_all_callback,显然这是一个虚函数,需要给SharedState补上


class SharedStateBase : public std::enable_shared_from_this<SharedStateBase> {
    // ...
    // private
    virtual void invoke_all_callback() = 0;
};

template<class T>
class SharedState : public SharedStateBase {
    // ...
    // private
    virtual void invoke_all_callback() override {
        callback_posted = false;
        size_t sz = finish_callbacks.size();
        for (size_t i = 0; i != sz; i++) {
            auto v = std::move(finish_callbacks[i]);
            v(value);
        }
        finish_callbacks.erase(finish_callbacks.begin(), finish_callbacks.begin()+sz);
    }
};


这里在invoke_all_callbacks里面,使用了和上面schedular poll里面类似的代码结构,可以让本帧callback里新增的callback post到下一帧调用

好了,全部功能都齐了,下面可以测试了,完整的代码在文章最后贴出


测试

如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢


int main() {
    Schedular schedular;
    Promise<int> promise(schedular);
    Future<int> future = promise.get_future();
    std::cout << "future get\n";
    promise.set_result(10);
    std::cout << "promise result set\n";
    future.add_finish_callback([](int v) {
        std::cout << "callback 1 got result " << v << "\n";
    });
    std::cout << "future callback add\n";
    std::cout << "tick 1\n";
    schedular.poll();
    std::cout << "tick 2\n";
    future.add_finish_callback([](int v) {
        std::cout << "callback 2 got result " << v << "\n";
    });
    std::cout << "future callback 2 add\n";
    schedular.poll();

    std::cout << "\n";

    Promise<double> promise2(schedular);
    promise2.set_result(12.34);
    std::cout << "promise result2 set\n";
    Future<double> future2 = promise2.get_future();
    std::cout << "future2 get\n";
    future2.add_finish_callback([&](double v) {
        std::cout << "future2 callback 1 got result" << v << "\n";
        future2.add_finish_callback([](double v) {
            std::cout << "future2 callback 2 got result" << v << "\n";
        });
        std::cout << "future2 callback 2 add inside callback\n";
    });
    std::cout << "future2 callback add\n";
    std::cout << "tick 3\n";
    schedular.poll();
    std::cout << "tick 4\n";
    schedular.poll();
}


输出

future get
promise result set
future callback add
tick 1
callback 1 got result 10
tick 2
future callback 2 add
callback 2 got result 10

promise result2 set
future2 get
future2 callback add
tick 3
future2 callback 1 got result12.34
future2 callback 2 add inside callback
tick 4
future2 callback 2 got result12.34

怎么样,是不是很简单呢,赶紧自己回家造一个吧!


附录

如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢

完整代码

#include <vector>
#include <memory>
#include <iostream>
#include <functional>

template<class T>
class Future;

template<class T>
class Promise;

class Schedular;

class SharedStateBase : public std::enable_shared_from_this<SharedStateBase> {
    friend class Schedular;
public:
    virtual ~SharedStateBase() = default;
private:
    virtual void invoke_all_callback() = 0;
};

template<class T>
class SharedState : public SharedStateBase {
    friend class Future<T>;
    friend class Promise<T>;
public:
    SharedState(Schedular& schedular)
        : schedular(&schedular)
    {}
    SharedState(const SharedState&) = delete;
    SharedState(SharedState&&) = delete;
    SharedState& operator=(const SharedState&) = delete;
    SharedState& operator=(SharedState&&) = delete;

private:
    template<class U>
    void set(U&& v) {
        if (settled) {
            return;
        }
        settled = true;
        value = std::forward<U>(v);
        post_all_callbacks();
    }

    T& get() { return value; }

    void add_finish_callback(std::function<void(T&)> callback) {
        finish_callbacks.push_back(std::move(callback));
        post_all_callbacks();
    }

    void post_all_callbacks();

    virtual void invoke_all_callback() override {
        callback_posted = false;
        size_t sz = finish_callbacks.size();
        for (size_t i = 0; i != sz; i++) {
            auto v = std::move(finish_callbacks[i]);
            v(value);
        }
        finish_callbacks.erase(finish_callbacks.begin(), finish_callbacks.begin()+sz);
    }

    bool has_owner = false;
    bool settled = false;
    bool callback_posted = false;
    Schedular* schedular = nullptr;
    T value;
    std::vector<std::function<void(T&)>> finish_callbacks;
};

template<class T>
class Promise {
public:
    Promise(Schedular& schedular)
        : _schedular(&schedular)
        , _state(std::make_shared<SharedState<T>>(*_schedular))
    {}

    Future<T> get_future();

    template<class U>
    void set_result(U&& value) {
        if (_state->settled) {
            throw std::invalid_argument("already set result");
        }
        _state->set(std::forward<U>(value));
    }
private:
    Schedular* _schedular;
    std::shared_ptr<SharedState<T>> _state;
};

template<class T>
class Future {
    friend class Promise<T>;
private:
    Future(std::shared_ptr<SharedState<T>> state)
        : _state(std::move(state))
    {
    }
public:

    void add_finish_callback(std::function<void(T&)> callback) {
        _state->add_finish_callback(std::move(callback));
    }
private:
    std::shared_ptr<SharedState<T>> _state;
};

template<class T>
Future<T> Promise<T>::get_future() {
    return Future<T>(_state);
}

class Schedular {
    template<class T>
    friend class SharedState;
public:
    Schedular() = default;
    Schedular(Schedular&&) = delete;
    Schedular(const Schedular&) = delete;
    Schedular& operator=(Schedular&&) = delete;
    Schedular& operator=(const Schedular&) = delete;

    void poll() {
        size_t sz = pending_states.size();
        for (size_t i = 0; i != sz; i++) {
            auto state = std::move(pending_states[i]);
            state->invoke_all_callback();
        }
        pending_states.erase(pending_states.begin(), pending_states.begin()+sz);
    }
private:
    void post_call_state(std::shared_ptr<SharedStateBase> state) {
        pending_states.push_back(std::move(state));
    }

    std::vector<std::shared_ptr<SharedStateBase>> pending_states;
};

template<class T>
void SharedState<T>::post_all_callbacks() {
    if (callback_posted) {
        return;
    }
    callback_posted = true;
    schedular->post_call_state(shared_from_this());
}

int main() {
    Schedular schedular;
    Promise<int> promise(schedular);
    Future<int> future = promise.get_future();
    std::cout << "future get\n";
    promise.set_result(10);
    std::cout << "promise result set\n";
    future.add_finish_callback([](int v) {
        std::cout << "callback 1 got result " << v << "\n";
    });
    std::cout << "future callback add\n";
    std::cout << "tick 1\n";
    schedular.poll();
    std::cout << "tick 2\n";
    future.add_finish_callback([](int v) {
        std::cout << "callback 2 got result " << v << "\n";
    });
    std::cout << "future callback 2 add\n";
    schedular.poll();

    std::cout << "\n";

    Promise<double> promise2(schedular);
    promise2.set_result(12.34);
    std::cout << "promise result2 set\n";
    Future<double> future2 = promise2.get_future();
    std::cout << "future2 get\n";
    future2.add_finish_callback([&](double v) {
        std::cout << "future2 callback 1 got result" << v << "\n";
        future2.add_finish_callback([](double v) {
            std::cout << "future2 callback 2 got result" << v << "\n";
        });
        std::cout << "future2 callback 2 add inside callback\n";
    });
    std::cout << "future2 callback add\n";
    std::cout << "tick 3\n";
    schedular.poll();
    std::cout << "tick 4\n";
    schedular.poll();
}

posted on 2020-05-16 14:36  PointerSMQ  阅读(1272)  评论(0编辑  收藏  举报