C++20协程解糖 - 动手实现协程2 - 实现co_await和co_return
在开始之前,我们先修复上一篇文章中的一个bug,SharedState::add_finish_callback中post_all_callbacks应当提前判断settled,否则会在未设置结果的情况下添加callback,callback也会被立即post
template<class T>
class SharedState : public SharedStateBase {
// ...
// private
void add_finish_callback(std::function<void(T&)> callback) {
finish_callbacks.push_back(std::move(callback));
if (settled) {
post_all_callbacks();
}
}
};
概述
今天我们要实现的东西包括
- 给schedular加上timer支持
- 给Future和Promise补充必要设施以支持C++20协程
如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢
开始动手
首先是Schedular的timer支持,我们这里使用一个简单的优先队列用来管理所有timer,并在poll函数中处理完当帧pending_state后,视情况sleep到最近的timer到期,并处理所有到期的timer
class Schedular {
// ...
// public
using timer_callback = std::function<void()>;
using timer_item = std::tuple<bool, float, chrono::time_point<chrono::steady_clock, chrono::duration<float>>, timer_callback>;
using timer = std::chrono::steady_clock;
struct timer_item_cmp {
bool operator()(const timer_item& a, const timer_item& b) const {
return std::get<2>(a) > std::get<2>(b);
}
};
// ...
// public
void poll() {
size_t sz = pending_states.size();
for (size_t i = 0; i != sz; i++) {
auto state = std::move(pending_states[i]);
state->invoke_all_callback();
}
pending_states.erase(pending_states.begin(), pending_states.begin() + sz);
if (timer_queue.empty()) {
return;
}
if (pending_states.empty()) { //如果pending_states为空,则可以sleep较长的时间,等待第一个将要完成的timer
std::this_thread::sleep_until(std::get<2>(timer_queue.front()));
auto now = timer::now();
do {
deal_one_timer();
} while (!timer_queue.empty() && std::get<2>(timer_queue.front()) <= now);
} else { //否则只能处理当帧到期的timer,不能sleep,要及时返回给caller,让caller及时下一次poll处理剩下的pending_states
auto now = timer::now();
while (!timer_queue.empty() && std::get<2>(timer_queue.front()) <= now) {
deal_one_timer();
}
}
}
void add_timer(bool repeat, float delay, timer_callback callback) {
auto cur_time = chrono::time_point_cast<chrono::duration<float>>(timer::now());
auto timeout = cur_time + chrono::duration<float>(delay);
timer_queue.emplace_back(repeat, delay, timeout, callback);
std::push_heap(timer_queue.begin(), timer_queue.end(), timer_item_cmp{});
}
// ...
// private
void deal_one_timer() {
std::pop_heap(timer_queue.begin(), timer_queue.end(), timer_item_cmp{});
auto item = std::move(timer_queue.back());
timer_queue.pop_back();
std::get<3>(item)();
if (std::get<0>(item)) {
add_timer(true, std::get<1>(item), std::move(std::get<3>(item)));
}
}
std::deque<timer_item> timer_queue;
};
这样之后,基于当前调度器的delay函数就可以写出来了
class Schedular {
// ...
// public
Future<float> delay(float second) {
auto promise = Promise<float>(*this);
add_timer(false, second, [=]() mutable {
promise.set_result(second);
});
return promise.get_future();
}
// ...
};
因为之前我们设计的Future和Promise并不支持void,于是这里简单用Future<float>代替,返回的是等待的秒数。
需要注意的是,这个delay函数虽然返回Future,但并不是协程,协程的判断标准是当且仅当函数中使用了co_await/co_yield/co_return,和返回类型无关。
这个函数同样展示了将回调式API封装为Future的做法,就是把Promise.set_result作为回调传入给API,并返回Promise.get_future,使用者在Future这边等待就好了。
有了这些东西之后,我们可以先把本次的测试代码写出来了
如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢
Future<float> func(Schedular& schedular) {
std::cout << "start sleep\n";
auto r = co_await schedular.delay(1.2);
co_return r;
}
Future<int> func2(Schedular& schedular) {
auto r = co_await func(schedular);
std::cout << "slept for " << r << "s\n";
co_return 42;
}
这里需要注意的是C++协程在编译器实现中,会自动构造一个Promise对象,而我们的Promise并不支持默认构造,必须传入一个Schedular参数。好在C++会替我们自动将协程参数作为作为构造函数参数来构造Promise,因此要在协程参数中指定Schedular,相当于指定Schedular构造了Promise。为一个协程显式指定调度器,是一个很合理的设计,python也是类似的设计。C#将协程调度器隐藏进了Task,因为它有一个全局的默认调度器。如果我们的实现中提供一个全局构造的Schedular,让Promise自动去找他调度,那这里的协程也可以没有参数。
如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢
为了让Future支持协程,代码中还需要补充一系列的内容,列举在下面
template<class T>
class Future {
// ...
// public
// 协程接口
using promise_type = Promise<T>;
bool await_ready() { return _state->settled; }
void await_suspend(exp::coroutine_handle<> handle) {
add_finish_callback([=](T&) mutable { handle.resume(); });
}
T await_resume() { return _state->value; }
// 协程接口
// ...
};
- promise_type用来指定本future对应的promise,结果的输入端
- await_ready检查future是否已经完成
- await_suspend用来通知future,协程为了等待它完成,已经暂停,future需要在自己完成的时候,主动恢复协程
- await_resume用来通知future,协程已经恢复执行,需要从future中取出结果,用作co_await表达式的结果。这里我们直接返回拷贝,实现中较为合理的是把future持有的对象移动出去,但这样的话被await的future就不能再单独获取结果了。
为了让Promise支持协程,需要补充的内容在下面
// ...
// 在最开头
// 如果你的编译器已经不需要std::experimental了,那就去掉这行,后面使用std而不是exp
namespace exp = std::experimental;
template<class T>
class Promise {
// ...
// public
// 协程接口
Future<T> get_return_object();
exp::suspend_never initial_suspend() { return {}; }
exp::suspend_never final_suspend() noexcept { return {}; }
void return_value(T v) { set_result(v); }
void unhandled_exception() { std::terminate(); }
// 协程接口
// ...
};
// ...
// 在Future定义后面
template<class T>
Future<T> Promise<T>::get_return_object() {
return get_future();
}
- initial_suspend用来表明协程是否在调用时暂停,异步任务一般返回suepend_never,调用时立即启动
- final_suspend用来表明协程是否在co_return后暂停(延迟销毁),我们是使用shared_state的异步任务,因此可以不暂停协程,直接自动销毁协程,让shared_state留在空中靠引用计数清零销毁
- return_value用于co_return将结果传入
- unhandled_exception用于协程中出现了未处理异常的情况,这里面可以通过std::current_exception来获取当前异常,我们的简化版不可能出现异常,出了就直接terminate
- get_return_object就是get_future。大家不要忘记一个协程是先构造promise,后从promise获取future的
有了这些东西之后,编译就不应该再出现错误了,我的编译选项是 clang++-9 test.cpp -stdlib=libc++ -std=c++2a
运行?还差最后一点
为了方便,我们效法python,给Schedular补一个run_until_compete的方法
如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢
class Schedular {
// ...
// public
template<class F, class... Args>
auto run_until_complete(F&& fn, Args&&... args) -> typename std::invoke_result_t<F&&, Args&&...>::result_type {
auto future = std::forward<F>(fn)(std::forward<Args>(args)...);
while (!future.await_ready()) {
poll();
}
return future.await_resume();
}
};
然后main
int main() {
Schedular schedular;
auto r = schedular.run_until_complete(func2, schedular);
std::cout << "run complete with " << r << "\n";
}
运行结果就有了
start sleep
slept for 1.2s
run complete with 42
怎么样,是不是很简单呢,赶紧自己写一个吧!
如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢
附录 - 全部代码
#include <vector>
#include <deque>
#include <memory>
#include <iostream>
#include <functional>
#include <chrono>
#include <thread>
#include <algorithm>
#include <experimental/coroutine>
namespace exp = std::experimental;
template<class T>
class Future;
template<class T>
class Promise;
class Schedular;
class SharedStateBase : public std::enable_shared_from_this<SharedStateBase> {
friend class Schedular;
public:
virtual ~SharedStateBase() = default;
private:
virtual void invoke_all_callback() = 0;
};
template<class T>
class SharedState : public SharedStateBase {
friend class Future<T>;
friend class Promise<T>;
public:
SharedState(Schedular& schedular)
: schedular(&schedular)
{}
SharedState(const SharedState&) = delete;
SharedState(SharedState&&) = delete;
SharedState& operator=(const SharedState&) = delete;
SharedState& operator=(SharedState&&) = delete;
private:
template<class U>
void set(U&& v) {
if (settled) {
return;
}
settled = true;
value = std::forward<U>(v);
post_all_callbacks();
}
T& get() { return value; }
void add_finish_callback(std::function<void(T&)> callback) {
finish_callbacks.push_back(std::move(callback));
if (settled) {
post_all_callbacks();
}
}
void post_all_callbacks();
virtual void invoke_all_callback() override {
callback_posted = false;
size_t sz = finish_callbacks.size();
for (size_t i = 0; i != sz; i++) {
auto v = std::move(finish_callbacks[i]);
v(value);
}
finish_callbacks.erase(finish_callbacks.begin(), finish_callbacks.begin()+sz);
}
bool settled = false;
bool callback_posted = false;
Schedular* schedular = nullptr;
T value;
std::vector<std::function<void(T&)>> finish_callbacks;
};
template<class T>
class Promise {
public:
Promise(Schedular& schedular)
: _schedular(&schedular)
, _state(std::make_shared<SharedState<T>>(*_schedular))
{}
Future<T> get_future();
// 协程接口
Future<T> get_return_object();
exp::suspend_never initial_suspend() { return {}; }
exp::suspend_never final_suspend() noexcept { return {}; }
void return_value(T v) { set_result(v); }
void unhandled_exception() { std::terminate(); }
// 协程接口
template<class U>
void set_result(U&& value) {
if (_state->settled) {
throw std::invalid_argument("already set result");
}
_state->set(std::forward<U>(value));
}
private:
Schedular* _schedular;
std::shared_ptr<SharedState<T>> _state;
};
template<class T>
class Future {
public:
using result_type = T;
using promise_type = Promise<T>;
friend class Promise<T>;
private:
Future(std::shared_ptr<SharedState<T>> state)
: _state(std::move(state))
{
}
public:
// 协程接口
bool await_ready() { return _state->settled; }
void await_suspend(exp::coroutine_handle<> handle) {
add_finish_callback([=](T&) mutable { handle.resume(); });
}
T await_resume() { return _state->value; }
// 协程接口
void add_finish_callback(std::function<void(T&)> callback) {
_state->add_finish_callback(std::move(callback));
}
private:
std::shared_ptr<SharedState<T>> _state;
};
template<class T>
Future<T> Promise<T>::get_future() {
return Future<T>(_state);
}
template<class T>
Future<T> Promise<T>::get_return_object() {
return get_future();
}
namespace chrono = std::chrono;
class Schedular {
template<class T>
friend class SharedState;
public:
using timer_callback = std::function<void()>;
using timer_item = std::tuple<bool, float, chrono::time_point<chrono::steady_clock, chrono::duration<float>>, timer_callback>;
using timer = std::chrono::steady_clock;
struct timer_item_cmp {
bool operator()(const timer_item& a, const timer_item& b) const {
return std::get<2>(a) > std::get<2>(b);
}
};
Schedular() = default;
Schedular(Schedular&&) = delete;
Schedular(const Schedular&) = delete;
Schedular& operator=(Schedular&&) = delete;
Schedular& operator=(const Schedular&) = delete;
void poll() {
size_t sz = pending_states.size();
for (size_t i = 0; i != sz; i++) {
auto state = std::move(pending_states[i]);
state->invoke_all_callback();
}
pending_states.erase(pending_states.begin(), pending_states.begin() + sz);
if (timer_queue.empty()) {
return;
}
if (pending_states.empty()) { //如果pending_states为空,则可以sleep较长的时间,等待第一个将要完成的timer
std::this_thread::sleep_until(std::get<2>(timer_queue.front()));
auto now = timer::now();
do {
deal_one_timer();
} while (!timer_queue.empty() && std::get<2>(timer_queue.front()) <= now);
} else { //否则只能处理当帧到期的timer,不能sleep,要及时返回给caller,让caller及时下一次poll处理剩下的pending_states
auto now = timer::now();
while (!timer_queue.empty() && std::get<2>(timer_queue.front()) <= now) {
deal_one_timer();
}
}
}
template<class F, class... Args>
auto run_until_complete(F&& fn, Args&&... args) -> typename std::invoke_result_t<F&&, Args&&...>::result_type {
auto future = std::forward<F>(fn)(std::forward<Args>(args)...);
while (!future.await_ready()) {
poll();
}
return future.await_resume();
}
void add_timer(bool repeat, float delay, timer_callback callback) {
auto cur_time = chrono::time_point_cast<chrono::duration<float>>(timer::now());
auto timeout = cur_time + chrono::duration<float>(delay);
timer_queue.emplace_back(repeat, delay, timeout, callback);
std::push_heap(timer_queue.begin(), timer_queue.end(), timer_item_cmp{});
}
Future<float> delay(float second) {
auto promise = Promise<float>(*this);
add_timer(false, second, [=]() mutable {
promise.set_result(second);
});
return promise.get_future();
}
private:
void deal_one_timer() {
std::pop_heap(timer_queue.begin(), timer_queue.end(), timer_item_cmp{});
auto item = std::move(timer_queue.back());
timer_queue.pop_back();
std::get<3>(item)();
if (std::get<0>(item)) {
add_timer(true, std::get<1>(item), std::move(std::get<3>(item)));
}
}
void post_call_state(std::shared_ptr<SharedStateBase> state) {
pending_states.push_back(std::move(state));
}
std::vector<std::shared_ptr<SharedStateBase>> pending_states;
std::deque<timer_item> timer_queue;
};
template<class T>
void SharedState<T>::post_all_callbacks() {
if (callback_posted) {
return;
}
callback_posted = true;
schedular->post_call_state(shared_from_this());
}
Future<float> func(Schedular& schedular) {
std::cout << "start sleep\n";
auto r = co_await schedular.delay(1.2);
co_return r;
}
Future<int> func2(Schedular& schedular) {
auto r = co_await func(schedular);
std::cout << "slept for " << r << "s\n";
co_return 42;
}
int main() {
Schedular schedular;
auto r = schedular.run_until_complete(func2, schedular);
std::cout << "run complete with " << r << "\n";
}