C++ 20 协程之Task实现
本文继续之前的文章,来实现我们的Task类。
实现基础的Promise和Task
首先实现基础的Promise类
template <typename T> class Promise {
public:
auto get_return_object() { return std::coroutine_handle<Promise>::from_promise(*this); }
auto initial_suspend() noexcept { return std::suspend_always{}; }
auto final_suspend() noexcept { return std::suspend_always{}; }
auto unhandled_exception() { m_ptr = std::current_exception(); }
auto return_value(int value) { m_value = value; }
T result() { return m_value; }
private:
T m_value;
std::exception_ptr m_ptr;
};
这是一个非常基本的Promise类,有两个私有变量,m_value和m_ptr,分别表示结果的值和可能的异常。同时包含五个必须实现的函数和一个用于返回结果的reuslt()函数。
接下来让我们来实现Task类,这里先直接给出一部分代码。
template <typename T> class Task {
public:
using promise_type = Promise<T>;
using coroutine_handle = std::coroutine_handle<promise_type>;
Task();
Task(coroutine_handle coroutine) : m_coroutine{coroutine} {}
coroutine_handle handle() { return m_coroutine; }
promise_type promise() { return m_coroutine.promise(); }
private:
std::coroutine_handle<promise_type> m_coroutine;
};
接着,让我们来做个小测试。代码如下:
Task<int> hello() { co_return 1; }
int main() {
std::cout << "开始构建hello\n";
auto t = hello();
std::cout << "构建hello完成\n";
while (!t.handle().done()) {
t.handle().resume();
std::cout << "main得到hello中的结果为: " << t.promise().result() << '\n';
}
return 0;
}
// 输出结果为:
/*
开始构建hello
构建hello完成
main得到hello中的结果为: 1
*/
这样我们就得到了一个最基本的Task类。为了能够在Task类中co_await其他协程,我们还需要实现对co_await的重载。先写我们的需求测试代码。
Task<int> world() {
std::cout << "world\n";
co_return 41;
}
Task<double> hello() {
int i = co_await world();
std::cout << std::format("hello得到world结果为 {}\n", i);
co_return i + 1.99;
}
int main() {
std::cout << "开始构建hello\n";
auto t = hello();
std::cout << "构建hello完成\n";
while (!t.done()) {
t.resume();
std::cout << "main得到hello中的结果为: " << t.promise().result() << '\n';
}
return 0;
}
这段代码并不能通过编译,int i = co_await world(); 会报错Task<int> 不具有成员”await_ready"。这是因为world()也就是Task对象并不是可等待对象,因此不能被co_await。为了让代码能通过编译,我们需要重载operator co_await()。
co_await 之旅
struct CurrentAwaiter {
auto await_ready() { return false; }
void await_suspend(std::coroutine_handle<> h) {}
auto await_resume() { return m_current.promise().result(); }
std::coroutine_handle<promise_type> m_current;
};
auto operator co_await() { return CurrentAwaiter{m_coroutine}; }
现在让我们仔细看这段代码,operator co_await()返回我们自定义的等待体CurrentAwaiter。当在hello()协程中co_await world()的时候,程序会调用operator co_await(),然后构建一个等待体对象,此时因为await_ready()返回false,所以hello()协程会被挂起,然后程序继续执行await_suspend(),因此这里我们什么都不做,函数返回void,因此程序的执行权会返回到恢复hello()的函数,也就是main()函数中,由于此时还没有执行到co_return i + 1; hello().promise().reuslt()并没有结果,main()函数会得到0。然后由于hello()并没有执行结束,main()会继续恢复hello()协程,此时会发现我们找不到hello()中co_await 的 world()协程了!我们并没有任何可以恢复它的代码逻辑,因此我们将来也永远不会执行到world()协程,hello()得到world()的结果只能为初始值0,然后hello()结束co_return i+1,main()函数得到1。
输出如下:
开始构建hello
构建hello完成
main得到hello中的结果为: 0
hello得到world结果为 0
main得到hello中的结果为: 1
按常规想法,我们有必要保存被co_await 协程的句柄,然后在后续恢复执行它。为了保存被co_await 协程的句柄,我们首先在Promise类中新增std::coroutine_handle<> m_handle{};用于存放协程的句柄,然后增加Setter ,也就是void continuation(std::coroutine_handle<> h) { m_handle = h; },然后看我们修改后的CurrentAwaiter代码。
struct CurrentAwaiter {
auto await_ready() { return false; }
std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) {
m_current.promise().continuation(h);
return m_current;
}
auto await_resume() { return m_current.promise().result(); }
std::coroutine_handle<promise_type> m_current;
};
auto operator co_await() { return CurrentAwaiter{m_coroutine}; }
其实这里有一个地方不容易分辨,也就是CurrentAwaiter中的m_current和await_suspend()参数中的h,分别指代哪个协程,这是重中之重。首先m_current来自operator co_await()。也就是说,当我们co_await world()的时候,m_coroutine也就是world协程的句柄,传给了CurrentAwaiter中的m_current。而await_suspend()参数中的h其实保存的是调用co_await()所在当前协程的句柄,也就是说参数中的句柄是hello()的句柄。因此,其实我们调用co_await就已经能拿到world()协程的句柄,只不过在修改代码之前我们并没有恢复它。为了能恢复执行world()协程,我们只需要在await_suspend()中return m_current就好了。那如果是这样的话,我们为什么还要保存hello()的句柄呢?我们明明已经做到恢复world()协程了。是的,这没错,但是如果不保存hello()的句柄,你会发现在world()执行后,程序没有返回到hello()协程中,反而是直接回到了main()函数,因为程序的执行权会返回到恢复hello()的函数。这不是我们想要的,我们想要的是co_await world()后,先执行完world(),然后返回执行hello()。所以其实我们需要保存的是hello()协程的句柄。 为了能恢复执行hello(),我们在final_suspend()上做文章。
当world()执行完成之后,由于我们将hello()协程的句柄保存到了world()的Promise中,也就是说world()现在知道自己的上级协程,因此在world()执行完成后,如果有上级协程,我们就恢复执行上级协程。代码如下:
struct FinalAwaiter {
auto await_ready() noexcept { return false; }
std::coroutine_handle<> await_suspend(std::coroutine_handle<>) noexcept {
if (handle)
return handle;
else
return std::noop_coroutine();
}
auto await_resume() noexcept {}
std::coroutine_handle<> handle
};
auto final_suspend() noexcept { return FinalAwaiter{m_handle}; }
这样我们就做到现在所有的事情了,我们之前的需求测试代码也终于能正常编译运行。
目前完整代码如下:
#include <coroutine>
#include <exception>
#include <format>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <variant>
template <typename T> class Task;
template <typename T> class Promise {
struct FinalAwaiter {
auto await_ready() noexcept { return false; }
std::coroutine_handle<> await_suspend(std::coroutine_handle<>) noexcept {
if (handle)
return handle;
else
return std::noop_coroutine();
}
auto await_resume() noexcept {}
std::coroutine_handle<> handle;
};
public:
using variant_type = std::variant<T, std::exception_ptr>;
auto initial_suspend() noexcept { return std::suspend_always{}; }
auto final_suspend() noexcept { return FinalAwaiter{m_handle}; }
auto get_return_object() noexcept {
return Task<T>{std::coroutine_handle<Promise>::from_promise(*this)};
}
auto unhandled_exception() { new (&m_storage) variant_type(std::current_exception()); }
auto return_value(T value) { m_storage.template emplace<T>(value); }
T result() { return std::get<T>(m_storage); }
void continuation(std::coroutine_handle<> h) { m_handle = h; }
private:
std::coroutine_handle<> m_handle{};
std::variant<T, std::exception_ptr> m_storage{};
};
template <typename T> class Task {
public:
using promise_type = Promise<T>;
using coroutine_handle = std::coroutine_handle<promise_type>;
Task() : m_coroutine{nullptr} {};
explicit Task(coroutine_handle coroutine) : m_coroutine{coroutine} {}
auto done() { return !m_coroutine || m_coroutine.done(); }
void resume() {
if (!m_coroutine.done())
m_coroutine.resume();
}
struct CurrentAwaiter {
auto await_ready() { return false; }
std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) {
m_current.promise().continuation(h);
return m_current;
}
auto await_resume() { return m_current.promise().result(); }
std::coroutine_handle<promise_type> m_current;
};
auto operator co_await() { return CurrentAwaiter{m_coroutine}; }
coroutine_handle handle() { return m_coroutine; }
promise_type promise() { return m_coroutine.promise(); }
private:
std::coroutine_handle<promise_type> m_coroutine;
};
Task<int> world() {
std::cout << "world\n";
co_return 41;
}
Task<double> hello() {
int i = co_await world();
std::cout << std::format("hello得到world结果为 {}\n", i);
co_return i + 1.99;
}
int main() {
std::cout << "开始构建hello\n";
auto t = hello();
std::cout << "构建hello完成\n";
while (!t.done()) {
t.resume();
std::cout << "main得到hello中的结果为: " << t.promise().result() << '\n';
}
return 0;
}
运行结果如下:
开始构建hello
构建hello完成
world
hello得到world结果为 41
main得到hello中的结果为: 42.99
我们还对代码进行了优化,引入了variant,确保T和std::exception_ptr只会有一个,节约了代码空间。
特化void版本的Task
细心的读者可能还注意到了,要是我们的Task没有返回值该怎么办呢?没错,我们还必须特化void版本的Task。void特化版本不需要存储值了,也就不需要使用variant,保存出现的异常即可,因此代码非常简单。
template <> class Promise<void> {
struct FinalAwaiter {
auto await_ready() noexcept { return false; }
std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) noexcept {
if (handle)
return handle;
else
return std::noop_coroutine();
}
auto await_resume() noexcept {}
std::coroutine_handle<> handle;
};
public:
using coroutine_handle = std::coroutine_handle<Promise<void>>;
auto initial_suspend() noexcept { return std::suspend_always{}; }
auto final_suspend() noexcept { return FinalAwaiter{m_handle}; }
auto get_return_object();
auto unhandled_exception() { m_ptr = std::current_exception(); }
auto return_void() {}
auto result() {
if (m_ptr) {
std::rethrow_exception(m_ptr);
}
}
void continuation(std::coroutine_handle<> h) { m_handle = h; }
private:
std::coroutine_handle<> m_handle{};
std::exception_ptr m_ptr{};
};
现在终于大功告成了吧,还有没有可以优化的地方呢?可能有读者又发现了,我们的void特化版本又把FinalAwaiter给实现了一遍,当然还有其他一些函数,都重复实现了,这显然不符合C++的哲学。因此在这里可以用上继承,把这些公共的部分放在基类中,减少代码的书写。
struct PromiseBase {
struct FinalAwaiter {
auto await_ready() noexcept { return false; }
template <typename promise_type>
std::coroutine_handle<> await_suspend(std::coroutine_handle<promise_type> h) noexcept {
if (handle)
return handle;
else
return std::noop_coroutine();
}
auto await_resume() noexcept {}
std::coroutine_handle<> handle;
};
auto initial_suspend() noexcept { return std::suspend_always{}; }
auto final_suspend() noexcept { return FinalAwaiter{m_handle}; }
void continuation(std::coroutine_handle<> h) { m_handle = h; }
protected:
std::coroutine_handle<> m_handle{};
};
最后,Promise还有Task都是不能复制的,我们必须删除它的拷贝构造和赋值运算符,并考虑是否需要重载移动构造和移动赋值。最后,我们为Task添加几个工具函数。
最终成果
最终完整的代码的如下:
#include <coroutine>
#include <exception>
#include <format>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <variant>
template <typename T = void> class Task;
namespace detail {
struct PromiseBase {
struct FinalAwaiter {
auto await_ready() noexcept { return false; }
template <typename promise_type>
std::coroutine_handle<> await_suspend(std::coroutine_handle<promise_type> h) noexcept {
if (handle)
return handle;
else
return std::noop_coroutine();
}
auto await_resume() noexcept {}
std::coroutine_handle<> handle;
};
auto initial_suspend() noexcept { return std::suspend_always{}; }
auto final_suspend() noexcept { return FinalAwaiter{m_handle}; }
void continuation(std::coroutine_handle<> h) { m_handle = h; }
protected:
std::coroutine_handle<> m_handle{};
};
template <typename T> class Promise : public PromiseBase {
public:
using variant_type = std::variant<T, std::exception_ptr>;
using coroutine_handle = std::coroutine_handle<Promise<T>>;
Promise() noexcept {}
Promise(const Promise &) = delete;
Promise(Promise &&other) = delete;
Promise &operator=(const Promise &) = delete;
Promise &operator=(Promise &&other) = delete;
~Promise() = default;
auto get_return_object() noexcept { return Task<T>{coroutine_handle::from_promise(*this)}; }
auto unhandled_exception() { new (&m_storage) variant_type(std::current_exception()); }
auto return_value(T value) { m_storage.template emplace<T>(value); }
T result() {
if (std::holds_alternative<T>(m_storage)) {
return std::get<T>(m_storage);
} else if (std::holds_alternative<std::exception_ptr>(m_storage)) {
std::rethrow_exception(std::get<std::exception_ptr>(m_storage));
} else {
throw std::runtime_error{"The return value was never set"};
}
}
private:
std::variant<T, std::exception_ptr> m_storage{};
};
template <> class Promise<void> : public PromiseBase {
public:
using coroutine_handle = std::coroutine_handle<Promise<void>>;
auto get_return_object() noexcept;
auto unhandled_exception() { m_ptr = std::current_exception(); }
auto return_void() {}
auto result() {
if (m_ptr) {
std::rethrow_exception(m_ptr);
}
}
private:
std::exception_ptr m_ptr{};
};
} // namespace detaill
template <typename T> class Task {
public:
using promise_type = detail::Promise<T>;
using coroutine_handle = std::coroutine_handle<promise_type>;
Task() : m_coroutine{nullptr} {};
explicit Task(coroutine_handle coroutine) : m_coroutine{coroutine} {}
Task(const Task &) = delete;
Task(Task &&other) noexcept : m_coroutine(std::exchange(other.m_coroutine, nullptr)) {}
~Task() {
if (m_coroutine != nullptr) {
m_coroutine.destroy();
}
}
Task& operator=(const Task&) = delete;
Task& operator=(Task&& other) noexcept {
if (std::addressof(other) != this) {
if (m_coroutine != nullptr) {
m_coroutine.destroy();
}
m_coroutine = std::exchange(other.m_coroutine, nullptr);
}
return *this;
}
bool done() { return !m_coroutine || m_coroutine.done(); }
bool resume() {
if (!m_coroutine.done())
m_coroutine.resume();
return !m_coroutine.done();
}
bool destroy() {
if (m_coroutine != nullptr) {
m_coroutine.destroy();
m_coroutine = nullptr;
return true;
}
return false;
}
struct CurrentAwaiter {
auto await_ready() { return false; }
std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) {
m_current.promise().continuation(h);
return m_current;
}
auto await_resume() { return m_current.promise().result(); }
std::coroutine_handle<promise_type> m_current;
};
auto operator co_await() { return CurrentAwaiter{m_coroutine}; }
coroutine_handle handle() { return m_coroutine; }
promise_type &promise() { return m_coroutine.promise(); }
private:
std::coroutine_handle<promise_type> m_coroutine;
};
namespace detail {
auto Promise<void>::get_return_object() noexcept {
return Task<>{coroutine_handle::from_promise(*this)};
}
} // namespace detaill
Task<> voidTask() {
std::cout << "这是一个没有返回值的Task\n";
co_return;
}
Task<int> world() {
std::cout << "world\n";
co_await voidTask();
std::cout << "等到voidTask返回\n";
co_return 41;
}
Task<double> hello() {
int i = co_await world();
std::cout << std::format("hello得到world结果为 {}\n", i);
co_return i + 1.99;
}
int main() {
std::cout << "开始构建hello\n";
auto t = hello();
std::cout << "构建hello完成\n";
while (!t.done()) {
t.resume();
std::cout << "main得到hello中的结果为: " << t.promise().result() << '\n';
}
return 0;
}
运行结果为:
开始构建hello
构建hello完成
world
这是一个没有返回值的Task
等到voidTask返回
hello得到world结果为 41
main得到hello中的结果为: 42.99
总结
我们实现了一个基本的Task,但是想要发挥出协程真正的本领,这些还远远不够。我们还需要进一步地封装实现其他的工具类。