从零到一:用 Rust 和 Tokio 构建高性能 WebSocket 服务器

12 阅读1分钟

引言

在实时应用日益普及的今天,WebSocket 协议已成为构建聊天应用、实时协作工具、在线游戏和金融交易系统的核心技术。虽然 Node.js 和 Go 等语言在 WebSocket 服务器开发中很常见,但 Rust 凭借其卓越的性能和内存安全性,正在成为构建高性能实时服务的理想选择。

本文将带你从零开始,使用 Rust 和 Tokio 异步运行时,构建一个完整的、可扩展的 WebSocket 服务器。我们将不仅关注基础实现,还会深入探讨连接管理、消息广播、错误处理等生产级功能。

为什么选择 Rust 和 Tokio?

Rust 的优势

  • 零成本抽象:高性能无需牺牲安全性
  • 内存安全:编译时保证,无需垃圾回收
  • 并发安全:所有权系统防止数据竞争
  • 丰富的生态系统:Cargo 包管理器提供了大量高质量库

Tokio 的优势

  • 异步运行时:专为 I/O 密集型应用设计
  • 高并发:单线程可处理数十万连接
  • 生态系统完善:与 Rust 异步生态完美集成

项目搭建

首先创建新项目并添加依赖:

cargo new websocket-server
cd websocket-server

编辑 Cargo.toml

[package]
name = "websocket-server"
version = "0.1.0"
edition = "2021"

[dependencies]
tokio = { version = "1.0", features = ["full"] }
tungstenite = "0.20.0"
futures-util = "0.3.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tracing = "0.1.0"
tracing-subscriber = "0.3.0"
arc-swap = "1.0.0"
dashmap = "5.0.0"

核心架构设计

我们的 WebSocket 服务器将包含以下核心组件:

  1. 连接管理器:管理所有活跃连接
  2. 消息路由器:处理消息的接收和分发
  3. 广播系统:向多个客户端发送消息
  4. 错误处理:优雅地处理连接异常

实现连接管理器

use dashmap::DashMap;
use futures_util::{SinkExt, StreamExt};
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio_tungstenite::tungstenite::protocol::Message;
use tokio_tungstenite::WebSocketStream;
use uuid::Uuid;

// 客户端连接信息
#[derive(Debug, Clone)]
pub struct Client {
    pub id: Uuid,
    pub username: Option<String>,
    pub addr: std::net::SocketAddr,
}

// 连接管理器
#[derive(Clone)]
pub struct ConnectionManager {
    clients: Arc<DashMap<Uuid, ClientConnection>>,
}

// 客户端连接包装
pub struct ClientConnection {
    pub client: Client,
    pub sender: tokio::sync::mpsc::UnboundedSender<Message>,
}

impl ConnectionManager {
    pub fn new() -> Self {
        Self {
            clients: Arc::new(DashMap::new()),
        }
    }

    // 添加新连接
    pub async fn add_connection(
        &self,
        stream: WebSocketStream<TcpStream>,
        addr: std::net::SocketAddr,
    ) -> Uuid {
        let client_id = Uuid::new_v4();
        let client = Client {
            id: client_id,
            username: None,
            addr,
        };

        // 创建消息通道
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let (mut ws_sender, mut ws_receiver) = stream.split();

        // 存储连接
        self.clients.insert(
            client_id,
            ClientConnection {
                client: client.clone(),
                sender: tx.clone(),
            },
        );

        // 启动发送任务
        let clients = self.clients.clone();
        tokio::spawn(async move {
            while let Some(message) = rx.recv().await {
                if let Err(e) = ws_sender.send(message).await {
                    eprintln!("发送消息失败: {}", e);
                    break;
                }
            }
            // 连接关闭,移除客户端
            clients.remove(&client_id);
        });

        // 启动接收任务
        let clients = self.clients.clone();
        tokio::spawn(async move {
            while let Some(result) = ws_receiver.next().await {
                match result {
                    Ok(message) => {
                        Self::handle_message(&clients, client_id, message).await;
                    }
                    Err(e) => {
                        eprintln!("接收消息错误: {}", e);
                        break;
                    }
                }
            }
            // 连接关闭,移除客户端
            clients.remove(&client_id);
        });

        client_id
    }

    // 处理接收到的消息
    async fn handle_message(
        clients: &Arc<DashMap<Uuid, ClientConnection>>,
        sender_id: Uuid,
        message: Message,
    ) {
        match message {
            Message::Text(text) => {
                println!("收到消息: {}", text);
                
                // 解析 JSON 消息
                if let Ok(value) = serde_json::from_str::<serde_json::Value>(&text) {
                    if let Some(msg_type) = value.get("type").and_then(|v| v.as_str()) {
                        match msg_type {
                            "chat" => Self::broadcast_message(clients, sender_id, &text).await,
                            "login" => Self::handle_login(clients, sender_id, &value).await,
                            _ => println!("未知消息类型: {}", msg_type),
                        }
                    }
                }
            }
            Message::Close(_) => {
                println!("客户端 {} 断开连接", sender_id);
            }
            _ => {} // 忽略其他类型消息
        }
    }

    // 广播消息给所有客户端
    async fn broadcast_message(
        clients: &Arc<DashMap<Uuid, ClientConnection>>,
        sender_id: Uuid,
        message: &str,
    ) {
        for entry in clients.iter() {
            let client_id = entry.key();
            if *client_id != sender_id {
                if let Err(e) = entry.value().sender.send(Message::Text(message.to_string())) {
                    eprintln!("广播消息失败: {}", e);
                }
            }
        }
    }

    // 处理登录
    async fn handle_login(
        clients: &Arc<DashMap<Uuid, ClientConnection>>,
        client_id: Uuid,
        data: &serde_json::Value,
    ) {
        if let Some(username) = data.get("username").and_then(|v| v.as_str()) {
            if let Some(mut entry) = clients.get_mut(&client_id) {
                entry.client.username = Some(username.to_string());
                println!("用户 {} 登录成功", username);
                
                // 发送欢迎消息
                let welcome_msg = serde_json::json!({
                    "type": "system",
                    "message": format!("欢迎 {} 加入聊天室!", username)
                });
                
                if let Err(e) = entry.sender.send(Message::Text(welcome_msg.to_string())) {
                    eprintln!("发送欢迎消息失败: {}", e);
                }
            }
        }
    }

    // 获取活跃连接数
    pub fn active_connections(&self) -> usize {
        self.clients.len()
    }
}

主服务器实现

use tokio_tungstenite::accept_async;
use tracing::{info, error};

pub struct WebSocketServer {
    listener: TcpListener,
    connection_manager: ConnectionManager,
}

impl WebSocketServer {
    pub async fn new(addr: &str) -> Result<Self, Box<dyn std::error::Error>> {
        let listener = TcpListener::bind(addr).await?;
        let connection_manager = ConnectionManager::new();
        
        info!("WebSocket 服务器启动在 {}", addr);
        
        Ok(Self {
            listener,
            connection_manager,
        })
    }

    pub async fn run(&self) -> Result<(), Box<dyn std::error::Error>> {
        // 启动统计任务
        let manager = self.connection_manager.clone();
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(10));
            loop {
                interval.tick().await;
                let count = manager.active_connections();
                info!("活跃连接数: {}", count);
            }
        });

        loop {
            match self.listener.accept().await {
                Ok((stream, addr)) => {
                    info!("新连接来自: {}", addr);
                    
                    let manager = self.connection_manager.clone();
                    tokio::spawn(async move {
                        match accept_async(stream).await {
                            Ok(ws_stream) => {
                                let _ = manager.add_connection