C++20协程之Task实现

342 阅读10分钟

C++ 20 协程之Task实现

本文继续之前的文章,来实现我们的Task类。

实现基础的Promise和Task

首先实现基础的Promise类

template <typename T> class Promise {
public:
    auto get_return_object() { return std::coroutine_handle<Promise>::from_promise(*this); }
    auto initial_suspend() noexcept { return std::suspend_always{}; }
    auto final_suspend() noexcept { return std::suspend_always{}; }
    auto unhandled_exception() { m_ptr = std::current_exception(); }
    auto return_value(int value) { m_value = value; }
​
    T result() { return m_value; }
​
private:
    T m_value;
    std::exception_ptr m_ptr;
};

这是一个非常基本的Promise类,有两个私有变量,m_valuem_ptr,分别表示结果的值和可能的异常。同时包含五个必须实现的函数和一个用于返回结果的reuslt()函数。

接下来让我们来实现Task类,这里先直接给出一部分代码。

template <typename T> class Task {
public:
    using promise_type = Promise<T>;
    using coroutine_handle = std::coroutine_handle<promise_type>;
    Task();
    Task(coroutine_handle coroutine) : m_coroutine{coroutine} {}
​
    coroutine_handle handle() { return m_coroutine; }
    promise_type promise() { return m_coroutine.promise(); }
​
private:
    std::coroutine_handle<promise_type> m_coroutine;
};

接着,让我们来做个小测试。代码如下:

Task<int> hello() { co_return 1; }
​
int main() {
    std::cout << "开始构建hello\n";
    auto t = hello();
    std::cout << "构建hello完成\n";
​
    while (!t.handle().done()) {
        t.handle().resume();
        std::cout << "main得到hello中的结果为: " << t.promise().result() << '\n';
    }
    return 0;
}
​
​
// 输出结果为:
/*
开始构建hello
构建hello完成
main得到hello中的结果为: 1
*/

这样我们就得到了一个最基本的Task类。为了能够在Task类中co_await其他协程,我们还需要实现对co_await的重载。先写我们的需求测试代码。

Task<int> world() {
    std::cout << "world\n";
    co_return 41;
}
​
Task<double> hello() {
    int i = co_await world();
    std::cout << std::format("hello得到world结果为 {}\n", i);
    co_return i + 1.99;
}
​
int main() {
    std::cout << "开始构建hello\n";
    auto t = hello();
    std::cout << "构建hello完成\n";
​
    while (!t.done()) {
        t.resume();
        std::cout << "main得到hello中的结果为: " << t.promise().result() << '\n';
    }
    return 0;
}

这段代码并不能通过编译,int i = co_await world(); 会报错Task<int> 不具有成员”await_ready"。这是因为world()也就是Task对象并不是可等待对象,因此不能被co_await。为了让代码能通过编译,我们需要重载operator co_await()

co_await 之旅

struct CurrentAwaiter {
    auto await_ready() { return false; }
    void await_suspend(std::coroutine_handle<> h) {}
    auto await_resume() { return m_current.promise().result(); }
​
    std::coroutine_handle<promise_type> m_current;
};
​
auto operator co_await() { return CurrentAwaiter{m_coroutine}; }

现在让我们仔细看这段代码,operator co_await()返回我们自定义的等待体CurrentAwaiter。当在hello()协程中co_await world()的时候,程序会调用operator co_await(),然后构建一个等待体对象,此时因为await_ready()返回false,所以hello()协程会被挂起,然后程序继续执行await_suspend(),因此这里我们什么都不做,函数返回void,因此程序的执行权会返回到恢复hello()的函数,也就是main()函数中,由于此时还没有执行到co_return i + 1; hello().promise().reuslt()并没有结果,main()函数会得到0。然后由于hello()并没有执行结束,main()会继续恢复hello()协程,此时会发现我们找不到hello()co_awaitworld()协程了!我们并没有任何可以恢复它的代码逻辑,因此我们将来也永远不会执行到world()协程,hello()得到world()的结果只能为初始值0,然后hello()结束co_return i+1,main()函数得到1

输出如下:

开始构建hello
构建hello完成
main得到hello中的结果为: 0
hello得到world结果为 0
main得到hello中的结果为: 1

按常规想法,我们有必要保存被co_await 协程的句柄,然后在后续恢复执行它。为了保存被co_await 协程的句柄,我们首先在Promise类中新增std::coroutine_handle<> m_handle{};用于存放协程的句柄,然后增加Setter ,也就是void continuation(std::coroutine_handle<> h) { m_handle = h; },然后看我们修改后的CurrentAwaiter代码。

struct CurrentAwaiter {
    auto await_ready() { return false; }
    std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) {
        m_current.promise().continuation(h);
        return m_current;
    }
    auto await_resume() { return m_current.promise().result(); }
​
    std::coroutine_handle<promise_type> m_current;
};
​
auto operator co_await() { return CurrentAwaiter{m_coroutine}; }

其实这里有一个地方不容易分辨,也就是CurrentAwaiter中的m_currentawait_suspend()参数中的h,分别指代哪个协程,这是重中之重。首先m_current来自operator co_await()。也就是说,当我们co_await world()的时候,m_coroutine也就是world协程的句柄,传给了CurrentAwaiter中的m_current。而await_suspend()参数中的h其实保存的是调用co_await()所在当前协程的句柄,也就是说参数中的句柄是hello()的句柄。因此,其实我们调用co_await就已经能拿到world()协程的句柄,只不过在修改代码之前我们并没有恢复它。为了能恢复执行world()协程,我们只需要在await_suspend()return m_current就好了。那如果是这样的话,我们为什么还要保存hello()的句柄呢?我们明明已经做到恢复world()协程了。是的,这没错,但是如果不保存hello()的句柄,你会发现在world()执行后,程序没有返回到hello()协程中,反而是直接回到了main()函数,因为程序的执行权会返回到恢复hello()的函数。这不是我们想要的,我们想要的是co_await world()后,先执行完world(),然后返回执行hello()。所以其实我们需要保存的是hello()协程的句柄。 为了能恢复执行hello(),我们在final_suspend()上做文章。

world()执行完成之后,由于我们将hello()协程的句柄保存到了world()Promise中,也就是说world()现在知道自己的上级协程,因此在world()执行完成后,如果有上级协程,我们就恢复执行上级协程。代码如下:

struct FinalAwaiter {
    auto await_ready() noexcept { return false; }
    std::coroutine_handle<> await_suspend(std::coroutine_handle<>) noexcept {
        if (handle)
            return handle;
        else
            return std::noop_coroutine();
    }
    auto await_resume() noexcept {}
​
    std::coroutine_handle<> handle
};
​
auto final_suspend() noexcept { return FinalAwaiter{m_handle}; }

这样我们就做到现在所有的事情了,我们之前的需求测试代码也终于能正常编译运行。

目前完整代码如下:

#include <coroutine>
#include <exception>
#include <format>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <variant>template <typename T> class Task;
​
template <typename T> class Promise {
    struct FinalAwaiter {
        auto await_ready() noexcept { return false; }
        std::coroutine_handle<> await_suspend(std::coroutine_handle<>) noexcept {
            if (handle)
                return handle;
            else
                return std::noop_coroutine();
        }
        auto await_resume() noexcept {}
​
        std::coroutine_handle<> handle;
    };
​
public:
    using variant_type = std::variant<T, std::exception_ptr>;
​
    auto initial_suspend() noexcept { return std::suspend_always{}; }
​
    auto final_suspend() noexcept { return FinalAwaiter{m_handle}; }
    auto get_return_object() noexcept {
        return Task<T>{std::coroutine_handle<Promise>::from_promise(*this)};
    }
    auto unhandled_exception() { new (&m_storage) variant_type(std::current_exception()); }
    auto return_value(T value) { m_storage.template emplace<T>(value); }
​
    T result() { return std::get<T>(m_storage); }
​
    void continuation(std::coroutine_handle<> h) { m_handle = h; }
​
private:
    std::coroutine_handle<> m_handle{};
    std::variant<T, std::exception_ptr> m_storage{};
};
​
template <typename T> class Task {
public:
    using promise_type = Promise<T>;
    using coroutine_handle = std::coroutine_handle<promise_type>;
    Task() : m_coroutine{nullptr} {};
    explicit Task(coroutine_handle coroutine) : m_coroutine{coroutine} {}
​
    auto done() { return !m_coroutine || m_coroutine.done(); }
    void resume() {
        if (!m_coroutine.done())
            m_coroutine.resume();
    }
​
    struct CurrentAwaiter {
        auto await_ready() { return false; }
        std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) {
            m_current.promise().continuation(h);
            return m_current;
        }
        auto await_resume() { return m_current.promise().result(); }
​
        std::coroutine_handle<promise_type> m_current;
    };
​
    auto operator co_await() { return CurrentAwaiter{m_coroutine}; }
​
    coroutine_handle handle() { return m_coroutine; }
    promise_type promise() { return m_coroutine.promise(); }
​
private:
    std::coroutine_handle<promise_type> m_coroutine;
};
​
Task<int> world() {
    std::cout << "world\n";
    co_return 41;
}
​
Task<double> hello() {
    int i = co_await world();
    std::cout << std::format("hello得到world结果为 {}\n", i);
    co_return i + 1.99;
}
​
int main() {
    std::cout << "开始构建hello\n";
    auto t = hello();
    std::cout << "构建hello完成\n";
​
    while (!t.done()) {
        t.resume();
        std::cout << "main得到hello中的结果为: " << t.promise().result() << '\n';
    }
    return 0;
}

运行结果如下:

开始构建hello
构建hello完成
world
hello得到world结果为 41
main得到hello中的结果为: 42.99

我们还对代码进行了优化,引入了variant,确保Tstd::exception_ptr只会有一个,节约了代码空间。

特化void版本的Task

细心的读者可能还注意到了,要是我们的Task没有返回值该怎么办呢?没错,我们还必须特化void版本的Taskvoid特化版本不需要存储值了,也就不需要使用variant,保存出现的异常即可,因此代码非常简单。

template <> class Promise<void> {
    struct FinalAwaiter {
        auto await_ready() noexcept { return false; }
        std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) noexcept {
            if (handle)
                return handle;
            else
                return std::noop_coroutine();
        }
        auto await_resume() noexcept {}
​
        std::coroutine_handle<> handle;
    };
​
public:
    using coroutine_handle = std::coroutine_handle<Promise<void>>;
    auto initial_suspend() noexcept { return std::suspend_always{}; }
​
    auto final_suspend() noexcept { return FinalAwaiter{m_handle}; }
    auto get_return_object();
    auto unhandled_exception() { m_ptr = std::current_exception(); }
    auto return_void() {}
    auto result() {
        if (m_ptr) {
            std::rethrow_exception(m_ptr);
        }
    }
    void continuation(std::coroutine_handle<> h) { m_handle = h; }
​
private:
    std::coroutine_handle<> m_handle{};
    std::exception_ptr m_ptr{};
};

现在终于大功告成了吧,还有没有可以优化的地方呢?可能有读者又发现了,我们的void特化版本又把FinalAwaiter给实现了一遍,当然还有其他一些函数,都重复实现了,这显然不符合C++的哲学。因此在这里可以用上继承,把这些公共的部分放在基类中,减少代码的书写。

struct PromiseBase {
    struct FinalAwaiter {
        auto await_ready() noexcept { return false; }
        template <typename promise_type>
        std::coroutine_handle<> await_suspend(std::coroutine_handle<promise_type> h) noexcept {
            if (handle)
                return handle;
            else
                return std::noop_coroutine();
        }
        auto await_resume() noexcept {}
​
        std::coroutine_handle<> handle;
    };
    auto initial_suspend() noexcept { return std::suspend_always{}; }
    auto final_suspend() noexcept { return FinalAwaiter{m_handle}; }
​
    void continuation(std::coroutine_handle<> h) { m_handle = h; }
​
protected:
    std::coroutine_handle<> m_handle{};
};

最后,Promise还有Task都是不能复制的,我们必须删除它的拷贝构造和赋值运算符,并考虑是否需要重载移动构造和移动赋值。最后,我们为Task添加几个工具函数。

最终成果

最终完整的代码的如下:

#include <coroutine>
#include <exception>
#include <format>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <variant>template <typename T = void> class Task;
​
namespace detail {
    struct PromiseBase {
        struct FinalAwaiter {
            auto await_ready() noexcept { return false; }
            template <typename promise_type>
            std::coroutine_handle<> await_suspend(std::coroutine_handle<promise_type> h) noexcept {
                if (handle)
                    return handle;
                else
                    return std::noop_coroutine();
            }
            auto await_resume() noexcept {}
​
            std::coroutine_handle<> handle;
        };
        auto initial_suspend() noexcept { return std::suspend_always{}; }
        auto final_suspend() noexcept { return FinalAwaiter{m_handle}; }
​
        void continuation(std::coroutine_handle<> h) { m_handle = h; }
​
    protected:
        std::coroutine_handle<> m_handle{};
    };
​
    template <typename T> class Promise : public PromiseBase {
​
    public:
        using variant_type = std::variant<T, std::exception_ptr>;
        using coroutine_handle = std::coroutine_handle<Promise<T>>;
​
        Promise() noexcept {}
        Promise(const Promise &) = delete;
        Promise(Promise &&other) = delete;
        Promise &operator=(const Promise &) = delete;
        Promise &operator=(Promise &&other) = delete;
        ~Promise() = default;
​
        auto get_return_object() noexcept { return Task<T>{coroutine_handle::from_promise(*this)}; }
        auto unhandled_exception() { new (&m_storage) variant_type(std::current_exception()); }
        auto return_value(T value) { m_storage.template emplace<T>(value); }
​
        T result() {
            if (std::holds_alternative<T>(m_storage)) {
                return std::get<T>(m_storage);
            } else if (std::holds_alternative<std::exception_ptr>(m_storage)) {
                std::rethrow_exception(std::get<std::exception_ptr>(m_storage));
            } else {
                throw std::runtime_error{"The return value was never set"};
            }
        }
​
    private:
        std::variant<T, std::exception_ptr> m_storage{};
    };
​
    template <> class Promise<void> : public PromiseBase {
​
    public:
        using coroutine_handle = std::coroutine_handle<Promise<void>>;
​
        auto get_return_object() noexcept;
        auto unhandled_exception() { m_ptr = std::current_exception(); }
        auto return_void() {}
        auto result() {
            if (m_ptr) {
                std::rethrow_exception(m_ptr);
            }
        }
​
    private:
        std::exception_ptr m_ptr{};
    };
} // namespace detailltemplate <typename T> class Task {
public:
    using promise_type = detail::Promise<T>;
    using coroutine_handle = std::coroutine_handle<promise_type>;
    Task() : m_coroutine{nullptr} {};
    explicit Task(coroutine_handle coroutine) : m_coroutine{coroutine} {}
​
    Task(const Task &) = delete;
    Task(Task &&other) noexcept : m_coroutine(std::exchange(other.m_coroutine, nullptr)) {}
​
    ~Task() {
        if (m_coroutine != nullptr) {
            m_coroutine.destroy();
        }
    }
    
    Task& operator=(const Task&) = delete;
    Task& operator=(Task&& other) noexcept {
        if (std::addressof(other) != this) {
            if (m_coroutine != nullptr) {
                m_coroutine.destroy();
            }
​
            m_coroutine = std::exchange(other.m_coroutine, nullptr);
        }
​
        return *this;
    }
​
    bool done() { return !m_coroutine || m_coroutine.done(); }
    
    bool resume() {
        if (!m_coroutine.done())
            m_coroutine.resume();
        return !m_coroutine.done();
    }
    
    bool destroy() {
        if (m_coroutine != nullptr) {
            m_coroutine.destroy();
            m_coroutine = nullptr;
            return true;
        }
​
        return false;
    }
​
    struct CurrentAwaiter {
        auto await_ready() { return false; }
        std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) {
            m_current.promise().continuation(h);
            return m_current;
        }
        auto await_resume() { return m_current.promise().result(); }
​
        std::coroutine_handle<promise_type> m_current;
    };
​
    auto operator co_await() { return CurrentAwaiter{m_coroutine}; }
​
    coroutine_handle handle() { return m_coroutine; }
    promise_type &promise() { return m_coroutine.promise(); }
​
private:
    std::coroutine_handle<promise_type> m_coroutine;
};
​
namespace detail {
    auto Promise<void>::get_return_object() noexcept {
        return Task<>{coroutine_handle::from_promise(*this)};
    }
} // namespace detaillTask<> voidTask() {
    std::cout << "这是一个没有返回值的Task\n";
    co_return;
}
​
Task<int> world() {
    std::cout << "world\n";
    co_await voidTask();
    std::cout << "等到voidTask返回\n";
    co_return 41;
}
​
Task<double> hello() {
    int i = co_await world();
    std::cout << std::format("hello得到world结果为 {}\n", i);
    co_return i + 1.99;
}
​
int main() {
    std::cout << "开始构建hello\n";
    auto t = hello();
    std::cout << "构建hello完成\n";
​
    while (!t.done()) {
        t.resume();
        std::cout << "main得到hello中的结果为: " << t.promise().result() << '\n';
    }
    return 0;
}

运行结果为:

开始构建hello
构建hello完成
world
这是一个没有返回值的Task
等到voidTask返回
hello得到world结果为 41
main得到hello中的结果为: 42.99

总结

我们实现了一个基本的Task,但是想要发挥出协程真正的本领,这些还远远不够。我们还需要进一步地封装实现其他的工具类。