C++20协程解糖 - 动手实现协程1 - Future和Promise
std::future和promise在C++20里面没法直接用的唯一原因就是不支持then,虽然MSVC有一个弱智版开线程阻塞实现的future.then,能then了但不保序,而且libstdc++也用不了。folly之类的库有靠谱的实现,但是功能太齐全太复杂,不适合新手学习。因此我们先从弱智版future promise schedular开始,从源头讲解如何实现协程相关设施。
如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢
基本结构
我们要实现的功能很简单:
- 单线程模型
- promise是入口,future是出口
- promise支持set_result, get_future
- future支持add_finish_callback,在promise set_result之后按序调用
- callback在下次调度时调用而不是立即调用
结构大概是这样
状态都存在shared_state里面,future和promise实际上只是个空壳
先搭框架
如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢
首先是shared_state,最直观的,shared_state需要存储最终设置的结果T,以及记录结果有没有设置。这里要求T必须支持默认构造,省事。
template<class T>
class SharedState {
friend class Future<T>;
friend class Promise<T>;
public:
SharedState()
{}
SharedState(const SharedState&) = delete;
SharedState(SharedState&&) = delete;
SharedState& operator=(const SharedState&) = delete;
SharedState& operator=(SharedState&&) = delete;
private:
template<class U>
void set(U&& v) {
if (settled) {
return;
}
settled = true;
value = std::forward<U>(v);
}
bool settled = false;
T value;
};
然后是引用了shared_state的promise和future
template<class T>
class Promise {
public:
Promise()
: _state(std::make_shared<SharedState<T>>())
{}
Future<T> get_future();
template<class U>
void set_result(U&& value) {
if (_state->settled) {
throw std::invalid_argument("already set result");
}
_state->set(std::forward<U>(value));
}
private:
std::shared_ptr<SharedState<T>> _state;
};
template<class T>
class Future {
friend class Promise<T>;
private:
Future(std::shared_ptr<SharedState<T>> state)
: _state(std::move(state))
{
}
private:
std::shared_ptr<SharedState<T>> _state;
};
template<class T>
Future<T> Promise<T>::get_future() {
return Future<T>(_state);
}
先把调度器的壳写上,以后会用到
class Schedular {
template<class T>
friend class SharedState;
public:
Schedular() = default;
Schedular(Schedular&&) = delete;
Schedular(const Schedular&) = delete;
Schedular& operator=(Schedular&&) = delete;
Schedular& operator=(const Schedular&) = delete;
void poll() {
// TODO
}
};
再补功能
如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢
首先future要支持add_finish_callback
template<class T>
class Future {
// ...
// public
void add_finish_callback(std::function<void(T&)> callback) {
_state->add_finish_callback(std::move(callback));
}
};
为什么要把callback实际加到_state里面去呢,因为之后需要post所有callback给schedular,而schedular要接受各种Future<T>的callback,要做类型擦除太麻烦了,所以索性,把callback存入_state,_state正好是堆对象,把他的类型擦除了,丢给schedular去post,还是因为简单
既然callback实际加给了shared_state,那SharedState也得补充对应的功能
template<class T>
class SharedState {
// ...
// private
void add_finish_callback(std::function<void(T&)> callback) {
finish_callbacks.push_back(std::move(callback));
// TODO
}
std::vector<std::function<void(T&)>> finish_callbacks;
};
然后,就是在shared_state set的时候,或者shared_state已有结果,但刚刚新增了callback的时候,把shared_state自己发送给Schedular,等待下一帧被调度到时调用callback
为此,SharedState本身要存储一个Schedular的指针,那么Promise就得接受Schedular作为构造函数参数,SharedState还要记录自己是否已经被post给Schedular,不需要重复post
template<class T>
class SharedState {
// ...
// public
// 构造函数增加一个参数
SharedState(Schedular& schedular)
: schedular(&schedular)
{}
// ...
// private
// set增加内容
template<class U>
void set(U&& v) {
if (settled) {
return;
}
settled = true;
value = std::forward<U>(v);
post_all_callbacks();
}
void add_finish_callback(std::function<void(T&)> callback) {
finish_callbacks.push_back(std::move(callback));
post_all_callbacks();
}
void post_all_callbacks();
bool settled = false;
bool callback_posted = false;
Schedular* schedular = nullptr;
T value;
std::vector<std::function<void(T&)>> finish_callbacks;
};
template<class T>
class Promise {
// ...
// public
// 构造函数增加参数
Promise(Schedular& schedular)
: _schedular(&schedular)
, _state(std::make_shared<SharedState<T>>(*_schedular))
{}
};
// ...
// 在Schedular定义的后面
template<class T>
void SharedState<T>::post_all_callbacks() {
if (callback_posted) {
return;
}
callback_posted = true;
schedular->post_call_state(shared_from_this());
}
可以发现,在post_all_callback时,SharedState把自己shared_from_this()后发送给了schedular,显然这里既要enable_shared_from_this,又要类型擦除,于是
class SharedStateBase : public std::enable_shared_from_this<SharedStateBase> {
friend class Schedular;
public:
virtual ~SharedStateBase() = default;
};
template<class T>
class SharedState : public SharedStateBase {
// ...
};
Schedular里面也得补充存储shared_ptr<SharedStateBase>的东西
class Schedular {
// ...
// private
void post_call_state(std::shared_ptr<SharedStateBase> state) {
pending_states.push_back(std::move(state));
}
std::vector<std::shared_ptr<SharedStateBase>> pending_states;
};
下面,就轮到Schedular的poll函数在每帧调用被post过来的SharedStateBase了
class Schedular {
// ...
// public
void poll() {
size_t sz = pending_states.size();
for (size_t i = 0; i != sz; i++) {
auto state = std::move(pending_states[i]);
state->invoke_all_callback();
}
pending_states.erase(pending_states.begin(), pending_states.begin()+sz);
}
};
- 之所以这里使用下标循环,是因为迭代过程中还可能有callback继续往pending_states里面新增元素
- 之所以不用while !empty()而是预先获取size,是为了避免调度是callback内无限post callback导致无限循环
- 之所以调用前先move出来,是为了避免调用callback期间callback继续往pending_states里面新增元素导致容器扩容,内容物失效
这里对state调用了invoke_all_callback,显然这是一个虚函数,需要给SharedState补上
class SharedStateBase : public std::enable_shared_from_this<SharedStateBase> {
// ...
// private
virtual void invoke_all_callback() = 0;
};
template<class T>
class SharedState : public SharedStateBase {
// ...
// private
virtual void invoke_all_callback() override {
callback_posted = false;
size_t sz = finish_callbacks.size();
for (size_t i = 0; i != sz; i++) {
auto v = std::move(finish_callbacks[i]);
v(value);
}
finish_callbacks.erase(finish_callbacks.begin(), finish_callbacks.begin()+sz);
}
};
这里在invoke_all_callbacks里面,使用了和上面schedular poll里面类似的代码结构,可以让本帧callback里新增的callback post到下一帧调用
好了,全部功能都齐了,下面可以测试了,完整的代码在文章最后贴出
测试
如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢
int main() {
Schedular schedular;
Promise<int> promise(schedular);
Future<int> future = promise.get_future();
std::cout << "future get\n";
promise.set_result(10);
std::cout << "promise result set\n";
future.add_finish_callback([](int v) {
std::cout << "callback 1 got result " << v << "\n";
});
std::cout << "future callback add\n";
std::cout << "tick 1\n";
schedular.poll();
std::cout << "tick 2\n";
future.add_finish_callback([](int v) {
std::cout << "callback 2 got result " << v << "\n";
});
std::cout << "future callback 2 add\n";
schedular.poll();
std::cout << "\n";
Promise<double> promise2(schedular);
promise2.set_result(12.34);
std::cout << "promise result2 set\n";
Future<double> future2 = promise2.get_future();
std::cout << "future2 get\n";
future2.add_finish_callback([&](double v) {
std::cout << "future2 callback 1 got result" << v << "\n";
future2.add_finish_callback([](double v) {
std::cout << "future2 callback 2 got result" << v << "\n";
});
std::cout << "future2 callback 2 add inside callback\n";
});
std::cout << "future2 callback add\n";
std::cout << "tick 3\n";
schedular.poll();
std::cout << "tick 4\n";
schedular.poll();
}
输出
future get
promise result set
future callback add
tick 1
callback 1 got result 10
tick 2
future callback 2 add
callback 2 got result 10
promise result2 set
future2 get
future2 callback add
tick 3
future2 callback 1 got result12.34
future2 callback 2 add inside callback
tick 4
future2 callback 2 got result12.34
怎么样,是不是很简单呢,赶紧自己回家造一个吧!
附录
如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢
完整代码#include <vector>
#include <memory>
#include <iostream>
#include <functional>
template<class T>
class Future;
template<class T>
class Promise;
class Schedular;
class SharedStateBase : public std::enable_shared_from_this<SharedStateBase> {
friend class Schedular;
public:
virtual ~SharedStateBase() = default;
private:
virtual void invoke_all_callback() = 0;
};
template<class T>
class SharedState : public SharedStateBase {
friend class Future<T>;
friend class Promise<T>;
public:
SharedState(Schedular& schedular)
: schedular(&schedular)
{}
SharedState(const SharedState&) = delete;
SharedState(SharedState&&) = delete;
SharedState& operator=(const SharedState&) = delete;
SharedState& operator=(SharedState&&) = delete;
private:
template<class U>
void set(U&& v) {
if (settled) {
return;
}
settled = true;
value = std::forward<U>(v);
post_all_callbacks();
}
T& get() { return value; }
void add_finish_callback(std::function<void(T&)> callback) {
finish_callbacks.push_back(std::move(callback));
post_all_callbacks();
}
void post_all_callbacks();
virtual void invoke_all_callback() override {
callback_posted = false;
size_t sz = finish_callbacks.size();
for (size_t i = 0; i != sz; i++) {
auto v = std::move(finish_callbacks[i]);
v(value);
}
finish_callbacks.erase(finish_callbacks.begin(), finish_callbacks.begin()+sz);
}
bool has_owner = false;
bool settled = false;
bool callback_posted = false;
Schedular* schedular = nullptr;
T value;
std::vector<std::function<void(T&)>> finish_callbacks;
};
template<class T>
class Promise {
public:
Promise(Schedular& schedular)
: _schedular(&schedular)
, _state(std::make_shared<SharedState<T>>(*_schedular))
{}
Future<T> get_future();
template<class U>
void set_result(U&& value) {
if (_state->settled) {
throw std::invalid_argument("already set result");
}
_state->set(std::forward<U>(value));
}
private:
Schedular* _schedular;
std::shared_ptr<SharedState<T>> _state;
};
template<class T>
class Future {
friend class Promise<T>;
private:
Future(std::shared_ptr<SharedState<T>> state)
: _state(std::move(state))
{
}
public:
void add_finish_callback(std::function<void(T&)> callback) {
_state->add_finish_callback(std::move(callback));
}
private:
std::shared_ptr<SharedState<T>> _state;
};
template<class T>
Future<T> Promise<T>::get_future() {
return Future<T>(_state);
}
class Schedular {
template<class T>
friend class SharedState;
public:
Schedular() = default;
Schedular(Schedular&&) = delete;
Schedular(const Schedular&) = delete;
Schedular& operator=(Schedular&&) = delete;
Schedular& operator=(const Schedular&) = delete;
void poll() {
size_t sz = pending_states.size();
for (size_t i = 0; i != sz; i++) {
auto state = std::move(pending_states[i]);
state->invoke_all_callback();
}
pending_states.erase(pending_states.begin(), pending_states.begin()+sz);
}
private:
void post_call_state(std::shared_ptr<SharedStateBase> state) {
pending_states.push_back(std::move(state));
}
std::vector<std::shared_ptr<SharedStateBase>> pending_states;
};
template<class T>
void SharedState<T>::post_all_callbacks() {
if (callback_posted) {
return;
}
callback_posted = true;
schedular->post_call_state(shared_from_this());
}
int main() {
Schedular schedular;
Promise<int> promise(schedular);
Future<int> future = promise.get_future();
std::cout << "future get\n";
promise.set_result(10);
std::cout << "promise result set\n";
future.add_finish_callback([](int v) {
std::cout << "callback 1 got result " << v << "\n";
});
std::cout << "future callback add\n";
std::cout << "tick 1\n";
schedular.poll();
std::cout << "tick 2\n";
future.add_finish_callback([](int v) {
std::cout << "callback 2 got result " << v << "\n";
});
std::cout << "future callback 2 add\n";
schedular.poll();
std::cout << "\n";
Promise<double> promise2(schedular);
promise2.set_result(12.34);
std::cout << "promise result2 set\n";
Future<double> future2 = promise2.get_future();
std::cout << "future2 get\n";
future2.add_finish_callback([&](double v) {
std::cout << "future2 callback 1 got result" << v << "\n";
future2.add_finish_callback([](double v) {
std::cout << "future2 callback 2 got result" << v << "\n";
});
std::cout << "future2 callback 2 add inside callback\n";
});
std::cout << "future2 callback add\n";
std::cout << "tick 3\n";
schedular.poll();
std::cout << "tick 4\n";
schedular.poll();
}