【零基础 Rust 入门 05】并发 - Async

58 阅读9分钟

这是【零基础 Rust 入门】系列的第 5 章。本系列由前端技术专家零弌分享。想要探索前端技术的无限可能,就请关注我们吧!🤗

并发模型

类型描述开销适合任务数
OS thread操作系统线程,代码几乎不需要大的改动即可迁移到多线程模型,但是多线程保持同步比较困难。
Event-driven事件驱动,一般通过 callback 编程,复杂度高,编写控制流程的成本较高。
Coroutines协程,和线程类似,但是开销小,可以支持大的并发数。抹掉了底层细节。
The actor modelactor model,通过消息通信来实现并发,类似于分布式系统。但是对于编写控制流程、重试等功能比较困难。
Async代码改动大,能支持的任务数多。

Rust 中 Async 的特点:

  • Futures 是惰性的,只有 pool 了(调用了 .await )才会执行。
    • JS 中的 Promise 是立即执行的,无论是否 await。
  • Async 是 0 开销的,没有堆上内存分配和 dynamic dispatch 的开销。(高性能)
    • JS 中的 Promise 是一个对象,有自己的状态,因此有额外内存的开销,JS 中几乎所有的方法调用都是 dynamic dispatch(先不考虑 V8 中的 JIT)。
  • 没有内置的 runtime,所有的 runtime 都是社区的,比如 tokio、async-std。(此条我不认为是优势,现在 rust 中的 async runtime 互相不兼容,选择成本高,内置的优势更大)
  • 支持单\多线程 runtime,对照上一条。是一条中立的特性,有好有坏。

和上一章的 Thread 对比,Thread 会带来额外的创建的开销和切换的开销,因此 Async 比较适合任务数更多的任务,比如说 IO。

但是 Async 和 Rust 并不互斥,可以同时使用。

Async/Await

简单看一下 async 和 await 的语法。

// block_on 阻塞当前线程,直到 future 结束
use futures::executor::block_on;

async fn learn_and_sing() {
    let song = learn_song().await;
    sing_song(song).await;
}

async fn async_main() {
    let f1 = learn_and_sing();
    let f2 = dance();

    // 等待 f1 和 f2 结束
    futures::join!(f1, f2);
}

fn main() {
    block_on(async_main());
}

通过 async 即可定义一个函数,其本质上会返回一个 Future。

async fn foo() {
    // ....
}

通过 .await 即可等待一个 Future 结束。

async fn bar() {
    foo().await;
}

Future

Future 是一个 trait,我们可以看一下其标准定义。

pub trait Future {
    type Output;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>;
}
  • Output: Associated Type(在范型里提过),定义了这个 Future 的返回值

  • poll: Future 方法的核心,runtime 会来调用 poll 获取 Future 方法的状态

    • self: self 注意是一个 Pin 类型,后续会解释为什么是 Pin 的

    • cx: 当前 Future 上下文,只能用来获取 waker。

Poll

熟悉的枚举又来了,Pollpoll 方法的返回值,代表了 Future 的情况。

pub enum Poll<T> {
    Ready(T),
    Pending,
}
  • Ready(T): 代表 Future 已经计算好,T 就是 Future 的 Output。
  • Pending: 代表 Future 还在计算。

Context

通过 Context 可以获取 waker

pub struct Context<'a>

pub fn from_waker(waker: &'a Waker) -> Context<'a>

pub fn waker(&self) -> &'a Waker

Waker

通知 Runtime,Future 已经计算完成,可以再次调用 poll 方法。

pub struct Waker
pub fn wake(self)
pub fn wake_by_ref(&self)
pub fn will_wake(&self, other: &Waker) -> bool
  • wake: 通知 Runtime 再次调用 poll 方法,多次调用可能会被合并为一次。
  • wake_by_ref: 持有 waker 引用时,不需要消耗掉 waker 本身,不需要额外调用一次 clone。
  • will_wake: 优化向,减少 wake 的调用次数。

0 开销

Join

通过维护一个状态机,即可实现 0 开销的等待所有的 Future 结束。下列代码中没有任何 heap allocation。

pub struct Join<FutureA, FutureB> {
    a: Option<FutureA>,
    b: Option<FutureB>,
}

impl<FutureA, FutureB> SimpleFuture for Join<FutureA, FutureB>
where
    FutureA: SimpleFuture<Output = ()>,
    FutureB: SimpleFuture<Output = ()>,
{
    type Output = ();
    fn poll(&mut self, wake: fn()) -> Poll<Self::Output> {
        if let Some(a) = &mut self.a {
            if let Poll::Ready(()) = a.poll(wake) {
                self.a.take();
            }
        }

        if let Some(b) = &mut self.b {
            if let Poll::Ready(()) = b.poll(wake) {
                self.b.take();
            }
        }

        if self.a.is_none() && self.b.is_none() {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}

实现一个简单的定时器及其 Runtime

实现 Future

use std::{
    future::Future,
    pin::Pin,
    sync::{Arc, Mutex},
    task::{Context, Poll, Waker},
    thread,
    time::Duration,
};


pub struct TimerFuture {
    shared_state: Arc<Mutex<SharedState>>,
}

/// 共享的状态
struct SharedState {
    /// 定时器是否已经结束
    completed: bool,

    /// 唤醒状态
    waker: Option<Waker>,
}

impl Future for TimerFuture {
    type Output = ();
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // 确认是否已经结束
        let mut shared_state = self.shared_state.lock().unwrap();
        if shared_state.completed {
            Poll::Ready(())
        } else {
            // 每次更新状态中的 waker,因为 Future 可以在 runtime
            // 中的多个 executor move,为了避免 wake 错 executor
            // 因此每次都 clone
            // 注意上面提过 wake_by_ref 来避免 clone
            // 可以通过 will_wake 来优化,这里为了简单才直接用了 clone。
            shared_state.waker = Some(cx.waker().clone());
            Poll::Pending
        }
    }
}

impl TimerFuture {
    pub fn new(duration: Duration) -> Self {
        let shared_state = Arc::new(Mutex::new(SharedState {
            completed: false,
            waker: None,
        }));

        let thread_shared_state = shared_state.clone();
        // 启动线程
        thread::spawn(move || {
            thread::sleep(duration);
            let mut shared_state = thread_shared_state.lock().unwrap();
            // 设置完成
            shared_state.completed = true;
            // 如果 waker 存在 wake 一下
            if let Some(waker) = shared_state.waker.take() {
                waker.wake()
            }
        });

        TimerFuture { shared_state }
    }
}

实现 Executor

use futures::{
    future::{BoxFuture, FutureExt},
    task::{waker_ref, ArcWake},
};
use std::{
    future::Future,
    sync::mpsc::{sync_channel, Receiver, SyncSender},
    sync::{Arc, Mutex},
    task::Context,
    time::Duration,
};
use timer_future::TimerFuture;

/// Executor 运行 Task
struct Executor {
    // 对应下面的 task_sender
    ready_queue: Receiver<Arc<Task>>,
}

/// spawn futures 向 sender 发送
#[derive(Clone)]
struct Spawner {
    task_sender: SyncSender<Arc<Task>>,
}

/// A future that can reschedule itself to be polled by an `Executor`.
struct Task {
    /// 这里简单写了,可以用 UnsafeCell 来避免 Mutext 的使用
    future: Mutex<Option<BoxFuture<'static, ()>>>,

    /// 向 sender 发送自身
    task_sender: SyncSender<Arc<Task>>,
}

fn new_executor_and_spawner() -> (Executor, Spawner) {
    const MAX_QUEUED_TASKS: usize = 10_000;
    let (task_sender, ready_queue) = sync_channel(MAX_QUEUED_TASKS);
    (Executor { ready_queue }, Spawner { task_sender })
}

impl Spawner {
    fn spawn(&self, future: impl Future<Output = ()> + 'static + Send) {
        let future = future.boxed();
        let task = Arc::new(Task {
            future: Mutex::new(Some(future)),
            task_sender: self.task_sender.clone(),
        });
        self.task_sender.send(task).expect("too many tasks queued");
    }
}

impl ArcWake for Task {
    fn wake_by_ref(arc_self: &Arc<Self>) {
        // 发送自身
        let cloned = arc_self.clone();
        arc_self
            .task_sender
            .send(cloned)
            .expect("too many tasks queued");
    }
}

impl Executor {
    fn run(&self) {
        while let Ok(task) = self.ready_queue.recv() {
            // 获取 Future
            let mut future_slot = task.future.lock().unwrap();
            if let Some(mut future) = future_slot.take() {
                // 创建 Context
                let waker = waker_ref(&task);
                let context = &mut Context::from_waker(&waker);

                // 如果 Future 还是 pending 的,将 Future 放回 Task
                if future.as_mut().poll(context).is_pending() {
                    *future_slot = Some(future);
                }
            }
        }
    }
}

main

fn main() {
    let (executor, spawner) = new_executor_and_spawner();

    spawner.spawn(async {
        println!("howdy!");
        // 等待两秒
        TimerFuture::new(Duration::new(2, 0)).await;
        println!("done!");
    });

    // drop 代表队列结束
    drop(spawner);

    // 运行 executor
    executor.run();
}

深入 Async

Async Fn/Async Block

// `foo()` returns a type that implements `Future<Output = u8>`.
// `foo().await` will result in a value of type `u8`.
async fn foo() -> u8 { 5 }

fn bar() -> impl Future<Output = u8> {
    // This `async` block results in a type that implements
    // `Future<Output = u8>`.
    async {
        let x: u8 = foo().await;
        x + 5
    }
}

foo: 返回了 Future<Output = u8>

bar: 不是一个 async fn,使用 async block expression 返回了 impl Future<Output = u8>

生命周期

async fn foo(x: &u8) -> u8 { *x }

fn foo<'a>(x: &'a u8) -> impl Future<Output = u8> + 'a {
    async move { *x }
}

两个方法等价,下面的 foo 相当于完整的定义。注意看返回值。impl Future<Output = u8> + 'a, 返回的 Future 加上了入参的生命周期。

因此入参的生命周期必须在 .await 调用之后结束。

fn bad() -> impl Future<Output = u8> {
    // x 在 await 之前就释放了,编译报错
    let x = 5;
    borrow_x(&x)
}

fn good() -> impl Future<Output = u8> {
    async {
        // x 将会存活到 await 之后
        let x = 5;
        borrow_x(&x).await
    }
}

move

async block 和闭包一样,都允许 move 参数,这个 my_string 的 ownership 被 async block 获取了。

fn move_block() -> impl Future<Output = ()> {
    let my_string = "foo".to_string();
    async move {
        // ...
        println!("{my_string}");
    }
}

Pin

参考上面实现简单定时器及其 runtime,一个 Future 可能被在不同的 Task 中 move。如果他是一个自引用的类型,他的内存就会指向错误的地址。

#[derive(Debug)]
struct Test {
    a: String,
    b: *const String,
}

image.png

swap 之前

a: test1, b: test1
a: test2, b: test2

swap 之后

a: test2, b: test1
a: test1, b: test2

预期的是

a: test1, b: test1
a: test2, b: test2

如何解决?使用 Pin trait 和 PhantomPinned。编译器会失败,因为 swap 方法会要求入参是实现 Unpin 的。而 Pin 类型是 !Unpin 的。

use std::pin::Pin;
use std::marker::PhantomPinned;

#[derive(Debug)]
struct Test {
    a: String,
    b: *const String,
    _marker: PhantomPinned,
}

impl Test {
    fn init(self: Pin<&mut Self>) {
        let self_ptr: *const String = &self.a;
        let this = unsafe { self.get_unchecked_mut() };
        this.b = self_ptr;
    }
}

Join/Select

使用 join 来同时等待多个 Future 结束。

use futures::join;

async fn get_book_and_music() -> (Book, Music) {
    let book_fut = get_book();
    let music_fut = get_music();
    join!(book_fut, music_fut)
}

使用 try_join 在某个 Future 抛出 error 时停止等待。

use futures::try_join;

async fn get_book() -> Result<Book, String> { /* ... */ Ok(Book) }
async fn get_music() -> Result<Music, String> { /* ... */ Ok(Music) }

async fn get_book_and_music() -> Result<(Book, Music), String> {
    let book_fut = get_book();
    let music_fut = get_music();
    try_join!(book_fut, music_fut)
}

使用 select 来控制每个 Future 结束之后的行为。

use futures::{future, select};

async fn count() {
    let mut a_fut = future::ready(4);
    let mut b_fut = future::ready(6);
    let mut total = 0;

    loop {
        select! {
            // a_fut 结束, a 为 a_fut 的 Output
            a = a_fut => total += a,
            // b_fut 结束, b 为 b_fut 的 Output
            b = b_fut => total += b,
            // 所有 future 结束
            complete => break,
            // 如果没有 Future 结束走到 default 分支
            default => unreachable!(), // never runs (futures are ready, then complete)
        };
    }
    assert_eq!(total, 10);
}

Stream

和 Iterator 类似,用 map,filter, fold 和 try。但是 stream 和 for 还不能使用,只能用 while + next 的配合。

async fn sum_with_next(mut stream: Pin<&mut dyn Stream<Item = i32>>) -> i32 {
    use futures::stream::StreamExt; // for `next`
    let mut sum = 0;
    while let Some(item) = stream.next().await {
        sum += item;
    }
    sum
}

async fn sum_with_try_next(
    mut stream: Pin<&mut dyn Stream<Item = Result<i32, io::Error>>>,
) -> Result<i32, io::Error> {
    use futures::stream::TryStreamExt; // for `try_next`
    let mut sum = 0;
    while let Some(item) = stream.try_next().await? {
        sum += item;
    }
    Ok(sum)
}

FusedFuture

select 在 Future 结束之后不能再次 poll,而 FusedFuture 可以告诉 select 已经 ready 使其可以再次 select。这样就能在 loop 中使用了。

use futures::{
    stream::{Stream, StreamExt, FusedStream},
    select,
};

async fn add_two_streams(
    mut s1: impl Stream<Item = u8> + FusedStream + Unpin,
    mut s2: impl Stream<Item = u8> + FusedStream + Unpin,
) -> u8 {
    let mut total = 0;

    loop {
        let item = select! {
            x = s1.next() => x,
            x = s2.next() => x,
            complete => break,
        };
        if let Some(next_num) = item {
            total += next_num;
        }
    }

    total
}

FuturesUnordered

将大量的 Futures 收集在 FuturesUnordered 中,select_next_some可以用这个方法来获取已经完成的 Future。

Fuse::terminated() 可以先设置一个 Fuse,在 loop 中去填充。

use futures::{
    future::{Fuse, FusedFuture, FutureExt},
    stream::{FusedStream, FuturesUnordered, Stream, StreamExt},
    pin_mut,
    select,
};

async fn get_new_num() -> u8 { /* ... */ 5 }

async fn run_on_new_num(_: u8) -> u8 { /* ... */ 5 }

async fn run_loop(
    mut interval_timer: impl Stream<Item = ()> + FusedStream + Unpin,
    starting_num: u8,
) {
    let mut run_on_new_num_futs = FuturesUnordered::new();
    run_on_new_num_futs.push(run_on_new_num(starting_num));
    let get_new_num_fut = Fuse::terminated();
    pin_mut!(get_new_num_fut);
    loop {
        select! {
            () = interval_timer.select_next_some() => {
                // The timer has elapsed. Start a new `get_new_num_fut`
                // if one was not already running.
                if get_new_num_fut.is_terminated() {
                    get_new_num_fut.set(get_new_num().fuse());
                }
            },
            new_num = get_new_num_fut => {
                // A new number has arrived -- start a new `run_on_new_num_fut`.
                run_on_new_num_futs.push(run_on_new_num(new_num));
            },
            // Run the `run_on_new_num_futs` and check if any have completed
            res = run_on_new_num_futs.select_next_some() => {
                println!("run_on_new_num_fut returned {:?}", res);
            },
            // panic if everything completed, since the `interval_timer` should
            // keep yielding values indefinitely.
            complete => panic!("`interval_timer` completed unexpectedly"),
        }
    }
}

课后作业

使用 async/await 实现一个 HTTP server。支持返回 hello.html 和 404.html。

需要解析 HTTP Request 中的 path,如果和 hello.html 匹配,则返回 hello.html 否则返回 404.html

<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="utf-8">
    <title>Hello!</title>
  </head>
  <body>
    <h1>Hello!</h1>
    <p>Hi from Rust</p>
  </body>
</html>
<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="utf-8">
    <title>Hello!</title>
  </head>
  <body>
    <h1>Oops!</h1>
    <p>Sorry, I don't know what you're asking for.</p>
  </body>
</html>