C++协程: 封装通用异步任务Task

1,412 阅读10分钟

协程主要用来降低异步任务的编写复杂度,避免回调地狱,异步任务各种各样,但最终目标都是对一个结果的获取。

实现目标

为了方便介绍后续的内容,我们需要再定义一个类型Task来作为协程的返回值。Task类型可以用来封装任务返回结果的异步行为(持续返回值的情况可能更适合使用序列生成器)。

实现的效果如下:

Task<int> simple_task2() {
    // sleep 1 秒
    using namespace std::chrono_literals;
    std::this_thread::sleep_for(1s);
    
    co_return 2;
}

Task<int> simple_task3() {
    //sleep 2 秒
    using namespace std::chrono_literals;
    std::this_thread::sleep_for(2s);
    
    co_return 3;
}

Task<int> simple_task() {
    // result2 == 2
    auto result2 = co_await simple_task2();
    
    // result3 == 3
    auto result3 = co_await simple_task3();
    
    co_return 1 + result2 + result3;
}

定义Task<ResultType>为返回值类型的协程,并且可以在协程内部使用co_await来等待其他Task的执行。

外部非协程内的函数当中访问Task的结果时,我们可以通过回调或者同步阻塞调用这两种方式来实现:

int main() {
    auto simpleTask = simple_task();

    // 异步方式
    simpleTask.then([](int i) {
        ... // i == 6
    }).catching([](std::exception &e) {
        ...
    });

    // 同步方式
    try {
        auto i = simpleTask.get_result();
        ... // i == 6
    } catch (std::exception &e) {
        ...
    }

    return 0;
    
}

按照这个效果,我们大致可以分析得到:

  1. 需要一个结果类型来承载正常返回和异常抛出的情况。
  2. 需要为Task定义相应的promise_type类型来支持co_returnco_await
  3. Task实现获取结果的阻塞函数get_result或者用于获取返回值的回调then以及用于获取抛出的异常的回调catching.

结果类型的定义

描述Task正常返回的结果和抛出的异常,只需要定义一个持有二者的类型即可:

#include <exception>

template<typename T>
struct Result {
    
  // 初始化为默认值
  explicit Result() = default;

  // 当 Task 正常返回时用结果初始化 Result
  explicit Result(T &&value) : _value(value) {}

  // 当 Task 抛异常时用异常初始化 Result
  explicit Result(std::exception_ptr &&exception_ptr) : _exception_ptr(exception_ptr) {}

  // 读取结果,有异常则抛出异常
  T get_or_throw() {
    if (_exception_ptr) {
      std::rethrow_exception(_exception_ptr);
    }
    return _value;
  }

 private:
  T _value{};
  std::exception_ptr _exception_ptr;

};

其中,Result的模板参数T对应于Task的返回值类型。有了这个结果类型,我们就可以很方便地在需要读取结果的时候调用get_or_throw

promise_type的定义

promise_type的定义是最为重要的一部分。

基本结构

基于前面几篇文章的基础,我们可以轻松地给出它的基本结构:

template<typename ResultType>
struct TaskPromise {
  // 协程立即执行
  std::suspend_always initial_suspend() { return {}; }
  
  // 执行结束后挂起,等待外部销毁,该逻辑与前面的 Generator 类似
  std::suspend_always final_suspend() noexcept { return {}; }
  
  // 构造协程的返回值对象 Task
  Task<ResultType> get_return_object() {
    return Task{std::coroutine_handle<TaskPromise>::from_promise(*this)};
  }
  
  void unhandled_exception() {
    // 将异常存入 result
    result = Result<ResultType>(std::current_exception());
  }
  
  void return_value(ResultType value) {
    // 将返回值存入result, 对应于协程内部的`co_return value`
    result = Result<ResultType>(std::move(value));
  }
  
  private:
    // 使用std::optional 可以区分协程是否执行完成
    std::optional<Result<ResultType>> result;
    
};

await_transform

光有这些还不够,我们还需要为Task添加co_await的支持。这里我们有两个选择:

  1. Task实现co_await运算符
  2. promise_type当中定义await_transform

从效果上来看,二者都可以实现对co_await的支持。但区别在于,await_transformpromise_type的内部函数,可以直接访问到promise内部的状态;同时,await_transform的定义也会限制协程内部对于其他类型的co_await的支持,将协程内部的挂起行为更好的管控起来,方便后续我们做统一的线程调度。因此此处我们采用await_transform来为Task提供co_await支持:

template<typename ResultType>
struct TaskPromise{
  ...
  
  // 注意此处的模板参数
  template<typename _ResultType>
  TaskAwaiter<_ResultType> await_transform(Task<_ResultType> &&task) {
    return TaskAwaiter<_ResultType>(std::move(task));
  }
  
  
  ...
};

代码很简单,返回了一个TaskAwaiter对象。不过再次请大家注意,这里存在来嗯个Task,一个是TaskPromise对应的Task,一个是co_await表达式的操作数Task,后者是await_transform的参数。

下面是TaskAwaiter的定义:

template<typename R>
struct TaskAwait {
  explicit TaskAwaiter(Task<R> &&task) noexpect
    : task(std::move(task)) {}
    
  TaskAwaiter(TaskAwaiter &&completion) noexcept
    : task(std::exchange(completion.task, {})) {}
    
  TaskAwaiter(TaskAwaiter &) = delete;
  
  TaskAwaiter &operator=(TaskAwaiter &) = delete;
  
  constexpr bool await_ready() const noexcept {
    return false;
  }
  
  void await_suspend(std::coroutine_handle<> handle) noexcept {
    // 当task执行完之后调用resume
    task.finally([handle]() {
      handle.resume();
    });
  }
  
  // 协程恢复执行时,被等待的 Task已经执行完,调用get_result来获取结果
  R await_resume() noexcept {
    return task.get_result();
  }
  
  private:
    Task<R> task;
    
};

当一个Task实例被co_await时,意味着它在co_await表达式返回之前已经执行完毕,当co_await表达式返回时,Task的结果也就被取到,Task 实例在后续就没有意义了。因此TaskAwaiter的构造器当中接收Task&&,防止co_await表达式之后继续对Task进行操作。

同步阻塞获取结果

为了防止result被外部随意访问,我们特意将其改为私有成员。接下来我们还需要提供相应的方式方便外部访问result

先来查看如何实现同步阻塞的结果返回:

template<typename ResultType>
struct TaskPromise {
  ...
  
  void unhandled_exception() {
    std::lock_guard lock(completion_lock);
    result = Result<ResultType>(std::current_exception());
    // 通知 get_result 当中的wait
    completion.notify_all();
  }
  
  void return_value(ResultType value) {
    std::lock_guard lock(completion_lock);
    result = Result<ResultType>(std::move(value));
    // 通知get_result 当中的 wait
    completion.notify_all();
  }
  
  ResultType get_result() {
    // 如果 result 没有值,说明协程还没有运行完,等待值被写入再返回
    std::unique_lock lock(completion_lock);
    if(!result.has_value()) {
      // 等待写入志之后调用 notify_all
      completion.wait(lock);
    }
    // 如果有值,则直接返回(或者抛出异常)
    return result->get_or_throw()
  }
  
  private:
    std::optional<Result<ResultType>> result;
    
    std::mutex completion_lock;
    std::condition_variable completion;
};

既然要阻塞,就免不了使用锁(mutex)和条件变量(conditon_variable),熟悉它们的读者一定觉得事情不那么简单了:这些工具一般是在多线程并发的环境中使用的。现在这么写也是为了后续应对多线程的场景,有关多线程调度的问题在下一篇文章中介绍。

异步结果回调

异步回调的实现比较复杂些,主要复杂在对于函数的运用。实际上对于回调的支持,主要就是支持回调的注册和回调的调用。根据结果类型的不同,回调又分为返回值的回调或者抛出异常的回调:

template<typename ResultType>
struct TaskPromise {
  ...

  void unhandled_exception() {
    std::lock_guard lock(completion_lock);
    result = Result<ResultType>(std::current_exception());
    completion.notify_all();
    // 调用回调
    notify_callbacks();
  }

  void return_value(ResultType value) {
    std::lock_guard lock(completion_lock);
    result = Result<ResultType>(std::move(value));
    completion.notify_all();
    // 调用回调
    notify_callbacks();
  }

  void on_completed(std::function<void(Result<ResultType>)> &&func) {
    std::unique_lock lock(completion_lock);
    // 加锁判断 result
    if (result.has_value()) {
      // result 已经有值
      auto value = result.value();
      // 解锁之后再调用 func
      lock.unlock();
      func(value);
    } else {
      // 否则添加回调函数,等待调用
      completion_callbacks.push_back(func);
    }
  }

 private:
  ...

  // 回调列表,我们允许对同一个 Task 添加多个回调
  std::list<std::function<void(Result<ResultType>)>> completion_callbacks;
  
  void notify_callbacks() {
    auto value = result.value();
    for (auto &callback : completion_callbacks) {
      callback(value);
    }
    // 调用完成,清空回调
    completion_callbacks.clear();
  }

}

同样地,如果只是在单线城环境内运行协程,这里的异步回调的作用可能并不明显。这里只是先给出定义,待我们后续支持线程调度之后,这些回调支持就会非常有价值了。

Task的实现

现在我们已经实现了最为关键的 promise_type,接下来给出 Task 类型的完整定义。实际上,Task 不过就是个摆设,它的能力大多都是通过用promise_type 来实现的。

template<typename ResultType>
struct Task {

  // 声明 promise_type 为 TaskPromise 类型
  using promise_type = TaskPromise<ResultType>;

  ResultType get_result() {
    return handle.promise().get_result();
  }

  Task &then(std::function<void(ResultType)> &&func) {
    handle.promise().on_completed([func](auto result) {
      try {
        func(result.get_or_throw());
      } catch (std::exception &e) {
        // 忽略异常
      }
    });
    return *this;
  }

  Task &catching(std::function<void(std::exception &)> &&func) {
    handle.promise().on_completed([func](auto result) {
      try {
        // 忽略返回值
        result.get_or_throw();
      } catch (std::exception &e) {
        func(e);
      }
    });
    return *this;
  }

  Task &finally(std::function<void()> &&func) {
    handle.promise().on_completed([func](auto result) { func(); });
    return *this;
  }

  explicit Task(std::coroutine_handle<promise_type> handle) noexcept: handle(handle) {}

  Task(Task &&task) noexcept: handle(std::exchange(task.handle, {})) {}

  Task(Task &) = delete;

  Task &operator=(Task &) = delete;

  ~Task() {
    if (handle) handle.destroy();
  }

 private:
  std::coroutine_handle<promise_type> handle;
};

现在我们完成了 Task 的第一个通用版本的实现,这个版本的实现当中尽管我们对 Task 的结果做了加锁,但考虑到目前我们仍没有提供线程切换的能力,因此这实际上是一个无调度器版本的 Task 实现。

Task的void特化

前面讨论的 Task 有一个作为返回值类型的模板参数 ResultType。实际上有些时候我们只是希望一段任务可以异步执行完,而不关注它的结果,这时候 ResultType 就需要是 void。例如:

Task<void> Producer(Channel<int> &channel) {
  ...
}

但很快你就会发现问题。编译器会告诉你模板实例化错误,因为我们没法用 void 来声明变量;编译器还会告诉你协程体里面如果没有返回值,你应该提供为 promise_type 提供 return_void 函数。

看来情况没有那么简单。C++ 的模板经常会遇到这种需要特化的情况,我们只需要对之前的 Task<ResultType> 版本的定义稍作修改,就可以给出 Task<void> 的版本:

template<>
struct Task<void> {
  // 用 void 作为第一个模板参数实例化 TaskPromise
  using promise_type = TaskPromise<void>;

  // 返回 void
  void get_result() {
    // 因为是 void,因此不用 return
    // 这时这个函数的作用就是阻塞当前线程等待协程执行完成
    handle.promise().get_result();
  }

  // func 的类型参数 void(),注意之前这个模板类型构造器还有个参数 ResultType
  Task &then(std::function<void()> &&func) {
    handle.promise().on_completed([func](auto result) {
      try {
        // 我们也会对 result 做 void 版本的实例化,这里只是检查有没有异常抛出
        result.get_or_throw();
        func();
      } catch (std::exception &e) {
        // ignore.
      }
    });
    return *this;
  }

  Task &catching(std::function<void(std::exception &)> &&func) {
    handle.promise().on_completed([func](auto result) {
      try {
        result.get_or_throw();
      } catch (std::exception &e) {
        func(e);
      }
    });
    return *this;
  }
  ...
};

你会发现变化的只是跟结果相关的部分。相应的,TaskPromise 也需要做出修改:

template<>
struct TaskPromise<void> {
  ...

  // 注意 Task 的模板参数
  Task<void> get_return_object() {
    return Task{std::coroutine_handle<TaskPromise>::from_promise(*this)};
  }

  ...

  // 返回值类型改成 void
  void get_result() {
    ...
    // 不再需要 return
    result->get_or_throw();
  }

  void unhandled_exception() {
    std::lock_guard lock(completion_lock);
    // Result 的模板参数变化
    result = Result<void>(std::current_exception());
    completion.notify_all();
    notify_callbacks();
  }

  // 不再是 return_value 了
  void return_void() {
    std::lock_guard lock(completion_lock);
    result = Result<void>();
    completion.notify_all();
    notify_callbacks();
  }

  // 注意 Result 的模板参数 void 
  void on_completed(std::function<void(Result<void>)> &&func) {
    ... 
  }

 private:
  // 注意 Result 的模板参数 void
  std::optional<Result<void>> result;
  std::list<std::function<void(Result<void>)>> completion_callbacks;

  ...
};

还有 Result 也有对应的 void 实例化版本,其实就是把存储返回值相关的逻辑全部删掉,只保留异常相关的部分:

template<>
struct Result<void> {

  explicit Result() = default;

  explicit Result(std::exception_ptr &&exception_ptr) : _exception_ptr(exception_ptr) {}

  void get_or_throw() {
    if (_exception_ptr) {
      std::rethrow_exception(_exception_ptr);
    }
  }

 private:
  std::exception_ptr _exception_ptr;
};

至此,我们进一步完善了 Task 对不同类型的结果的支持,理论上我们可以使用 Task 来构建各式各样的协程了。

小试牛刀

接下来我们写一个简单的代码运行一下:

Task<int> simple_task2() {
  std::cout << "task 2 start ..." << std::endl;
  using namespace std::chrono_literals;
  std::this_thread::sleep_for(1s);
  std::cout << "task 2 returns after 1s." << std::endl;
  co_return 2;
}

Task<int> simple_task3() {
  std::cout << "in task 3 start ..." << std::endl;
  using namespace std::chrono_literals;
  std::this_thread::sleep_for(2s);
  std::cout << "task 3 returns after 2s." << std::endl;
  ;
  co_return 3;
}

Task<int> simple_task() {
  std::cout << "task start ..." << std::endl;
  auto result2 = co_await simple_task2();
  std::cout << "returns from task2: " << result2 << std::endl;
  auto result3 = co_await simple_task3();
  std::cout << "returns from task3: " << result3 << std::endl;
  co_return 1 + result2 + result3;
}

int main() {
  auto simpleTask = simple_task();
  simpleTask
      .then([](int i) { std::cout << "simple task end: " << i << std::endl; })
      .catching([](std::exception &e) {
        std::cout << "error occurred" << e.what() << std::endl;
      });
  try {
    auto i = simpleTask.get_result();
    std::cout << "simple task end from get: " << i << std::endl;
  } catch (std::exception &e) {
    std::cout << "error: " << e.what() << std::endl;
  }
  return 0;
}

运行结果如下:

task start ...
task 2 start ...
task 2 returns after 1s.
returns from task2: 2
in task 3 start ...
task 3 returns after 2s.
returns from task3: 3
simple task end: 6
simple task end from get: 6

由于我们的任务在执行过程中没有进行任何线程的切换或者协程的调度,因此各个Task的执行实际是穿行的,就如同我们调用普通函数一样。当然这显然不是我们的最终目的,下一篇文章将给Task增加调度器的支持。

总结

本文我们详细介绍了无调度器版本的通用Task 的实现。尽管程序还没实现真正的异步执行,但是已经开始接近了,下一篇文章中将实现真正的异步。