Rust使用axum结合Actor模型实现异步发送SSE

1,310 阅读7分钟

SSE服务器推送技术是一种从服务器向客户端浏览器推送消息的技术,是h5规范中的其中一个部分。SSE是基于http协议的,一般来说短连接是没办法做实时推送的,但是当服务端回复客户端的响应头类型为event-stream事件流时,说明服务端发送的是流数据,因此两端会保持长连接。

好,假设我们建立了一个SSE长连接,那么何时服务端会向客户端发送数据就变得不可预测了,这样就不可避免的涉及到了异步,如何异步地发送响应数据成为了一个问题。

tokio核心库下的async-stream就能很好的解决这个问题,它提供了稳定生成异步元素流的能力。库本身的代码并不多,它提供两个宏,stream!和try_stream!,区别呢就是try_stream!多了Result,如果想要详细了解源码,可以参考一下这篇文章Rust Async: async-stream源码分析

当然stream!只是提供了执行异步代码块的能力,要实现并发传值,还需要一个消息传递模型,Actor模型的思想是这样的,我们说每个实体都是一个actor,actor自己保存有状态,actor与actor之间通过行为来约束状态,行为有发送和接收,而发送和接收是通过邮箱来实现,邮箱通过消息队列先进先出的原则来发送和接收数据。这种模型隔离性好,可以实现无锁异步。而在rust中通过channel就可以构造出这种模型。

另外我们还需要axum来实现http server,axum是基于tower-http生态构建出来的,和tokio同属一个家族,用axum再合适不过了。

下面直接上案例。

案例

main.rs

type TokioUnboundedSender<T> = tokio::sync::mpsc::UnboundedSender<T>;

#[tokio::main(flavor = "multi_thread")]
async fn main() {
    // Use an unbounded channel to handle buffering and flushing of messages
    // to the event source...
    let (collect_tx, collect_rx) = tokio::sync::mpsc::unbounded_channel::<MyEvent>();
    let collect_rx = tokio_stream::wrappers::UnboundedReceiverStream::new(collect_rx);
    
    tokio::task::spawn(async move { sessions::process(collect_rx).await });
    let app = Router::new()
        .route("/sse", get(services::sse_handler))
        .route("/", get(|| async { "Hello, World!" }))
        .layer(Extension(collect_tx.clone()));
    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
        .unwrap();
}

使用tokio的多生产单消费通道,并将发送者和接收者转化为流发送者和流接收者。这里的发送者和接收者就是Actor模型中的actor。

layer(Extension(collect_tx.clone()))相当于我为每一个handler添加了一个提取器,提取器可以放一些共享对象,比如数据库连接池、通道发送者等等。这里我传入了一个发送者。

services.rs

pub(super) async fn sse_handler(
    Extension(collect_tx): Extension<TokioUnboundedSender<MyEvent>>,
) -> Sse<impl futures::stream::Stream<Item = Result<Event, Infallible>>> {
    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
    let _ = collect_tx.send(MyEvent::SSE(tx));
    
    //...中间代码
    Sse::new(stream).keep_alive(
        axum::response::sse::KeepAlive::new()
            .interval(Duration::from_secs(15))
            .text("keep-alive-text"),
    )
}

handler的返回值是一个axum::response::sse::Sse结构体,我们看下Sse的关联函数new()

axum::response::sse::Sse

pub fn new(stream: S) -> Self  
where  
S: TryStream<Ok = Event> + Send + 'static,  
S::Error: Into<BoxError>,

它要求传入的泛型参数是实现了TryStream的,并且可以在线程间传递所有权。 TryStream trait继承了Stream trait,并且futures核心库已经帮我们实现了这个条件,也就是为任意实现了Stream trait的对象实现TryStream。

impl<S, T, E> TryStream for S
where
    S: ?Sized + Stream<Item = Result<T, E>>,
{
    type Ok = T;
    type Error = E;
    fn try_poll_next(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Ok, Self::Error>>> {
        self.poll_next(cx)
    }
}

Stream trait

#[must_use = "streams do nothing unless polled"]
pub trait Stream {
    /// Values yielded by the stream.
    type Item;
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>;
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, None)
    }
}

如果Future<Output= T>是T的异步版本,那么Stream<Item = T>就是Iterator<Item = T>的异步版本。Stream表示的是一系列产生值的事件异步的传递给调用方。它是以Future为模型的,Future有poll(),类比Iterator的next(),Stream就是poll_next()。

因此我们实际上只要传递一个实现了Stream trait的对象给new函数就行。

接着我们来看中间部分的代码,就是从开始就铺垫的async_stream::stream!

中间代码

    use tokio_stream::StreamExt as _;
    let stream = async_stream::stream! {
        loop {
            let signal = rx.recv().await;
            yield signal
        };
    }
    .map(|signal| {
        if let Some(signal) = signal {
            let event = Event::default()
                .event(signal.event.clone())
                .data(signal.data);
            println!("发送event: {:?}", event);
            event
        } else {
            println!("发送event: None");
            Event::default()
        }
    })
    .map(Ok)
    .throttle(Duration::from_secs(3));

我们直接看stream!宏的源码来分析

#[macro_export]
macro_rules! stream {
    ($($tt:tt)*) => {
        $crate::__private::stream_inner!(($crate) $($tt)*)
    }
}

这里看出它直接调用了私有模块__private下的stream_inner!,又是个宏,接着进去。

/// The first token tree in the stream must be a group containing the path to the `async-stream`
/// crate.
#[proc_macro]
#[doc(hidden)]
pub fn stream_inner(input: TokenStream) -> TokenStream {
    let (crate_path, mut stmts) = match parse_input(input) {
        Ok(x) => x,
        Err(e) => return e.to_compile_error().into(),
    };
    let mut scrub = Scrub::new(false, &crate_path);
    for stmt in &mut stmts {
        scrub.visit_stmt_mut(stmt);
    }
    let dummy_yield = if scrub.has_yielded {
        None
    } else {
        Some(quote!(if false {
            __yield_tx.send(()).await;
        }))
    };
    quote!({
        let (mut __yield_tx, __yield_rx) = unsafe { #crate_path::__private::yielder::pair() };
        #crate_path::__private::AsyncStream::new(__yield_rx, async move {
            #dummy_yield
            #(#stmts)*
        })
    })
    .into()
}

直接看最后一段,可以看到它实际上调用了AsyncStream::new(),返回一个AsyncStream struct。 而且AsyncStream实现了Stream trait,这样就正好满足了我们的需求。

rust是不支持yield关键字的,stream!的yield语法是通过过程宏实现的,它在语法树中搜索包含yield $expr 的实例,然后将他们替换为sender.send($expr).await。

宏展开后就是

AsyncStream::new(sender, async move {
    loop {
        let signal = rx.recv().await;
        sender.send(signal).await 
    };
})

async块内部是一个循环,通道接收者不断接收从process的事件循环处理过的数据,然后yield出去。

map组合器和标准库的map差不多,只不过它是对future的stream做包装、映射。map返回出来的结构体Map也实现了Stream,因此它是内联于poll_next的调用而执行的,每次poll_next得到的值它都做一次包装再返回出去。

impl<St, F, T> Stream for Map<St, F>
where
    St: Stream,
    F: FnMut(St::Item) -> T,
{
    type Item = T;
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
        self.as_mut()
            .project()
            .stream
            .poll_next(cx)
            .map(|opt| opt.map(|x| (self.as_mut().project().f)(x)))
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}

sessions.rs

use crate::TokioUnboundedSender;
#[derive(Default, Debug, Clone)]
pub(crate) struct Signal {
    pub(crate) event: String,
    pub(crate) data: String,
}

impl Signal {
    pub(crate) fn new(event: String, data: String) -> Self {
        Signal { event, data }
    }
}

Signal是自定义的包装结构,可以根据业务需要修改。

pub(super) enum MyEvent {
    SSE(TokioUnboundedSender<Signal>),
}

MyEvent是一个事件枚举,这里简单写了一项,SSE成员是一个发送者,接下来就利用这个发送者给handler发送事件信号。

pub(super) async fn process(
    mut collect_rx: tokio_stream::wrappers::UnboundedReceiverStream<MyEvent>,
) {
    use tokio_stream::StreamExt as _;
    while let Some(sessions) = collect_rx.next().await {
        match sessions {
            MyEvent::SSE(sender) => {
                for i in 0..5 {
                    let _ = sender.send(Signal::new("send".to_string(), i.to_string()));
                }
            }
        }
    }
}

process方法内部是一个事件循环,从主线程传入的唯一接收者collect_rx,在此不断接收handler里的发送者发送的MyEvent。对MyEvent进行模式匹配,我们的业务就跟随匹配到的项做相应的逻辑。这里我只是简单用for循环发送了5次数据包来模拟无规律的响应。

运行

打开浏览器访问地址,在控制台输入脚本。

var eventSource = new EventSource('/sse');

eventSource.addEventListener("send",(e) => {
    console.log("event: ", e.type, ", data: ", e.data);
})

eventSource.onmessage = function(e) {
    console.log('event: ', e.type, ", data: ", e.data);
}

可以看到,不断能接收到服务端发送的数据 localhost_3000 - Google Chrome 2023-05-24 10-54-45.gif

services.rs - sse_demo - Visual Studio Code [管理员] 2023-05-24 10-56-40.gif

完整代码

main.rs

mod services;
mod sessions;
use std::net::SocketAddr;
use axum::{routing::get, Extension, Router};
use sessions::MyEvent;

type TokioUnboundedSender<T> = tokio::sync::mpsc::UnboundedSender<T>;

#[tokio::main(flavor = "multi_thread")]
async fn main() {
    // Use an unbounded channel to handle buffering and flushing of messages
    // to the event source...
    let (collect_tx, collect_rx) = tokio::sync::mpsc::unbounded_channel::<MyEvent>();
    let collect_rx = tokio_stream::wrappers::UnboundedReceiverStream::new(collect_rx);

    tokio::task::spawn(async move { sessions::process(collect_rx).await });
    let app = Router::new()
        .route("/sse", get(services::sse_handler))
        .route("/", get(|| async { "Hello, World!" }))
        .layer(Extension(collect_tx.clone()));
    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
        .unwrap();
}

services.rs

use std::{convert::Infallible, sync::Arc, time::Duration};
use axum::{
    response::{sse::Event, Sse},
    Extension,
};
use crate::{sessions::MyEvent, TokioUnboundedSender};

pub(super) async fn sse_handler(
    Extension(collect_tx): Extension<Arc<TokioUnboundedSender<MyEvent>>>,
) -> Sse<impl futures::stream::Stream<Item = Result<Event, Infallible>>> {
    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
    use tokio_stream::StreamExt as _;
    let _ = collect_tx.send(MyEvent::SSE(tx));

    let stream = async_stream::stream! {
        loop {
            let signal = rx.recv().await;
            yield signal
        };
    }
    .map(|signal| {
        let event = if let Some(signal) = signal {
            Event::default()
                .event(signal.event.clone())
                .data(signal.data)
        } else {
            Event::default().data(format!("None"))
        };
        println!("发送event: {:?}", event);
        event
    })
    .map(Ok)
    .throttle(Duration::from_secs(1));
    Sse::new(stream)
        .keep_alive(axum::response::sse::KeepAlive::new().interval(Duration::from_secs(15)))
}

sessions.rs

use crate::TokioUnboundedSender;
#[derive(Default, Debug, Clone)]
pub(crate) struct Signal {
    pub(crate) event: String,
    pub(crate) data: String,
}

impl Signal {
    pub(crate) fn new(event: String, data: String) -> Self {
        Signal { event, data }
    }
}
pub(super) enum MyEvent {
    SSE(TokioUnboundedSender<Signal>),
}

pub(super) async fn process(
    mut collect_rx: tokio_stream::wrappers::UnboundedReceiverStream<MyEvent>,
) {
    use tokio_stream::StreamExt as _;

    while let Some(sessions) = collect_rx.next().await {
        match sessions {
            MyEvent::SSE(sender) => {
                for i in 0..5 {
                    let _ = sender.send(Signal::new("send".to_string(), i.to_string()));
                }
            }
        }
    }
}

总结

代码中有很多值得优化的点,比如layer中间件传递的是一个被clone过的sender,如果用Arc包一下,它克隆的就只是智能指针,而不是深拷贝,就可以省掉这个消耗。

这个案例其实融合了很多知识点,但是我希望传达的更多的是我们在遇到问题时该如何切入。一开始是选择合适的前后端通信技术,在选定了SSE并且理解它的通信原理后,选择axum作为实现基础,通过方法签名得知参数是以Stream trait为泛型约束,选定满足该条件的async-stream crate。最后根据具体业务,选择合适的并发模型方案。技术固然很重要,但是我们还要学会把技术像链表一样串联起来为我所用的能力。

Rust:axum学习笔记(6) SSE(Server Send Event)服务端推送

Rust Async: async-stream源码分析