引言
在实时应用日益普及的今天,WebSocket 协议已成为构建聊天应用、实时协作工具、在线游戏和金融交易系统的核心技术。虽然 Node.js 和 Go 等语言在 WebSocket 服务器开发中很常见,但 Rust 凭借其卓越的性能和内存安全性,正在成为构建高性能实时服务的理想选择。
本文将带你从零开始,使用 Rust 和 Tokio 异步运行时,构建一个完整的、可扩展的 WebSocket 服务器。我们将不仅关注基础实现,还会深入探讨连接管理、消息广播、错误处理等生产级功能。
为什么选择 Rust 和 Tokio?
Rust 的优势
- 零成本抽象:高性能无需牺牲安全性
- 内存安全:编译时保证,无需垃圾回收
- 并发安全:所有权系统防止数据竞争
- 丰富的生态系统:Cargo 包管理器提供了大量高质量库
Tokio 的优势
- 异步运行时:专为 I/O 密集型应用设计
- 高并发:单线程可处理数十万连接
- 生态系统完善:与 Rust 异步生态完美集成
项目搭建
首先创建新项目并添加依赖:
cargo new websocket-server
cd websocket-server
编辑 Cargo.toml:
[package]
name = "websocket-server"
version = "0.1.0"
edition = "2021"
[dependencies]
tokio = { version = "1.0", features = ["full"] }
tungstenite = "0.20.0"
futures-util = "0.3.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tracing = "0.1.0"
tracing-subscriber = "0.3.0"
arc-swap = "1.0.0"
dashmap = "5.0.0"
核心架构设计
我们的 WebSocket 服务器将包含以下核心组件:
- 连接管理器:管理所有活跃连接
- 消息路由器:处理消息的接收和分发
- 广播系统:向多个客户端发送消息
- 错误处理:优雅地处理连接异常
实现连接管理器
use dashmap::DashMap;
use futures_util::{SinkExt, StreamExt};
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio_tungstenite::tungstenite::protocol::Message;
use tokio_tungstenite::WebSocketStream;
use uuid::Uuid;
// 客户端连接信息
#[derive(Debug, Clone)]
pub struct Client {
pub id: Uuid,
pub username: Option<String>,
pub addr: std::net::SocketAddr,
}
// 连接管理器
#[derive(Clone)]
pub struct ConnectionManager {
clients: Arc<DashMap<Uuid, ClientConnection>>,
}
// 客户端连接包装
pub struct ClientConnection {
pub client: Client,
pub sender: tokio::sync::mpsc::UnboundedSender<Message>,
}
impl ConnectionManager {
pub fn new() -> Self {
Self {
clients: Arc::new(DashMap::new()),
}
}
// 添加新连接
pub async fn add_connection(
&self,
stream: WebSocketStream<TcpStream>,
addr: std::net::SocketAddr,
) -> Uuid {
let client_id = Uuid::new_v4();
let client = Client {
id: client_id,
username: None,
addr,
};
// 创建消息通道
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
let (mut ws_sender, mut ws_receiver) = stream.split();
// 存储连接
self.clients.insert(
client_id,
ClientConnection {
client: client.clone(),
sender: tx.clone(),
},
);
// 启动发送任务
let clients = self.clients.clone();
tokio::spawn(async move {
while let Some(message) = rx.recv().await {
if let Err(e) = ws_sender.send(message).await {
eprintln!("发送消息失败: {}", e);
break;
}
}
// 连接关闭,移除客户端
clients.remove(&client_id);
});
// 启动接收任务
let clients = self.clients.clone();
tokio::spawn(async move {
while let Some(result) = ws_receiver.next().await {
match result {
Ok(message) => {
Self::handle_message(&clients, client_id, message).await;
}
Err(e) => {
eprintln!("接收消息错误: {}", e);
break;
}
}
}
// 连接关闭,移除客户端
clients.remove(&client_id);
});
client_id
}
// 处理接收到的消息
async fn handle_message(
clients: &Arc<DashMap<Uuid, ClientConnection>>,
sender_id: Uuid,
message: Message,
) {
match message {
Message::Text(text) => {
println!("收到消息: {}", text);
// 解析 JSON 消息
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&text) {
if let Some(msg_type) = value.get("type").and_then(|v| v.as_str()) {
match msg_type {
"chat" => Self::broadcast_message(clients, sender_id, &text).await,
"login" => Self::handle_login(clients, sender_id, &value).await,
_ => println!("未知消息类型: {}", msg_type),
}
}
}
}
Message::Close(_) => {
println!("客户端 {} 断开连接", sender_id);
}
_ => {} // 忽略其他类型消息
}
}
// 广播消息给所有客户端
async fn broadcast_message(
clients: &Arc<DashMap<Uuid, ClientConnection>>,
sender_id: Uuid,
message: &str,
) {
for entry in clients.iter() {
let client_id = entry.key();
if *client_id != sender_id {
if let Err(e) = entry.value().sender.send(Message::Text(message.to_string())) {
eprintln!("广播消息失败: {}", e);
}
}
}
}
// 处理登录
async fn handle_login(
clients: &Arc<DashMap<Uuid, ClientConnection>>,
client_id: Uuid,
data: &serde_json::Value,
) {
if let Some(username) = data.get("username").and_then(|v| v.as_str()) {
if let Some(mut entry) = clients.get_mut(&client_id) {
entry.client.username = Some(username.to_string());
println!("用户 {} 登录成功", username);
// 发送欢迎消息
let welcome_msg = serde_json::json!({
"type": "system",
"message": format!("欢迎 {} 加入聊天室!", username)
});
if let Err(e) = entry.sender.send(Message::Text(welcome_msg.to_string())) {
eprintln!("发送欢迎消息失败: {}", e);
}
}
}
}
// 获取活跃连接数
pub fn active_connections(&self) -> usize {
self.clients.len()
}
}
主服务器实现
use tokio_tungstenite::accept_async;
use tracing::{info, error};
pub struct WebSocketServer {
listener: TcpListener,
connection_manager: ConnectionManager,
}
impl WebSocketServer {
pub async fn new(addr: &str) -> Result<Self, Box<dyn std::error::Error>> {
let listener = TcpListener::bind(addr).await?;
let connection_manager = ConnectionManager::new();
info!("WebSocket 服务器启动在 {}", addr);
Ok(Self {
listener,
connection_manager,
})
}
pub async fn run(&self) -> Result<(), Box<dyn std::error::Error>> {
// 启动统计任务
let manager = self.connection_manager.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(10));
loop {
interval.tick().await;
let count = manager.active_connections();
info!("活跃连接数: {}", count);
}
});
loop {
match self.listener.accept().await {
Ok((stream, addr)) => {
info!("新连接来自: {}", addr);
let manager = self.connection_manager.clone();
tokio::spawn(async move {
match accept_async(stream).await {
Ok(ws_stream) => {
let _ = manager.add_connection