rust如何使用多线程

0 阅读4分钟

在 Rust 中,多线程编程是其核心优势之一。Rust 通过所有权系统(Ownership)和类型系统,在编译阶段就消除了数据竞争(Data Race),实现了“无畏并发”(Fearless Concurrency)。 以下是 Rust 使用多线程的核心方法、常用模式及代码示例:

  1. 基础:创建线程 (std::thread::spawn)
  2. 使用 std::thread::spawn 函数创建新线程。它接受一个闭包(closure),该闭包包含新线程要执行的代码。
  3. 关键点:通常需要使用 move 关键字,将变量的所有权移入新线程。
  4. 等待线程结束:使用 join() 方法阻塞当前线程,直到子线程执行完毕。
  5. rust

编辑

use std::thread; use std::time::Duration;

fn main() { // 创建一个新线程 let handle = thread::spawn(|| { for i in 1..5 { println!("子线程打印: {}", i); thread::sleep(Duration::from_millis(100)); } });

                                                    // 主线程做一些工作
                                                        for i in 1..3 {
                                                                println!("主线程打印: {}", i);
                                                                        thread::sleep(Duration::from_millis(100));
                                                                            }
                                                                            
                                                                                // 等待子线程结束 (join 返回 Result)
                                                                                    handle.join().unwrap();
                                                                                        
                                                                                            println!("所有线程已完成");
                                                                                            }
                                                                                            2. 线程间通信:消息传递 (Message Passing)
                                                                                            Rust 推崇“不要通过共享内存来通信,而要通过通信来共享内存”。标准库提供了 通道 (Channels) (std::sync::mpsc) 来实现这一点。
                                                                                            mpsc: Multiple Producer, Single Consumer(多生产者,单消费者)。
                                                                                            tx (Transmitter): 发送端。
                                                                                            rx (Receiver): 接收端。
                                                                                            rust
                                                                                            
                                                                                            编辑
                                                                                            
                                                                                            
                                                                                            
                                                                                            use std::thread;
                                                                                            use std::sync::mpsc;
                                                                                            use std::time::Duration;
                                                                                            
                                                                                            fn main() {
                                                                                                // 创建通道
                                                                                                    let (tx, rx) = mpsc::channel();
                                                                                                    
                                                                                                        // 启动线程,移动 tx 到线程中
                                                                                                            thread::spawn(move || {
                                                                                                                    let vals = vec![
                                                                                                                                String::from("你好"),
                                                                                                                                            String::from("来自"),
                                                                                                                                                        String::from("线程"),
                                                                                                                                                                ];
                                                                                                                                                                
                                                                                                                                                                        for val in vals {
                                                                                                                                                                                    tx.send(val).unwrap(); // 发送数据
                                                                                                                                                                                                thread::sleep(Duration::from_millis(200));
                                                                                                                                                                                                        }
                                                                                                                                                                                                                // tx 在这里 drop,通道关闭
                                                                                                                                                                                                                    });
                                                                                                                                                                                                                    
                                                                                                                                                                                                                        // 在主线程接收数据
                                                                                                                                                                                                                            // recv 会阻塞,直到收到消息或通道关闭
                                                                                                                                                                                                                                // 也可以使用 rx.try_recv() 非阻塞接收
                                                                                                                                                                                                                                    for received in rx {
                                                                                                                                                                                                                                            println!("收到: {}", received);
                                                                                                                                                                                                                                                }
                                                                                                                                                                                                                                                }
                                                                                                                                                                                                                                                3. 共享状态:Arc + Mutex
                                                                                                                                                                                                                                                如果必须共享内存(例如多个线程需要修改同一个计数器),Rust 要求使用智能指针来保证安全:
                                                                                                                                                                                                                                                Mutex<T> (互斥锁): 保证同一时间只有一个线程能访问数据。
                                                                                                                                                                                                                                                Arc<T> (原子引用计数): 允许多个线程拥有同一个 Mutex 的所有权(Rc 不是线程安全的,必须用 Arc)。
                                                                                                                                                                                                                                                注意:必须先包裹 Arc,再包裹 Mutex,即 Arc<Mutex<T>>。
                                                                                                                                                                                                                                                rust
                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                编辑
                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                use std::sync::{Arc, Mutex};
                                                                                                                                                                                                                                                use std::thread;
                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                fn main() {
                                                                                                                                                                                                                                                    // 创建共享数据:Arc 用于多线程共享引用,Mutex 用于保证互斥访问
                                                                                                                                                                                                                                                        let counter = Arc::new(Mutex::new(0));
                                                                                                                                                                                                                                                            let mut handles = vec![];
                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                for _ in 0..10 {
                                                                                                                                                                                                                                                                        // 克隆 Arc 指针,增加引用计数,让每个线程都拥有所有权
                                                                                                                                                                                                                                                                                let counter_clone = Arc::clone(&counter);
                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                let handle = thread::spawn(move || {
                                                                                                                                                                                                                                                                                                            // lock() 获取锁,如果失败会 panic (也可以用 match 处理错误)
                                                                                                                                                                                                                                                                                                                        // num 是一个 MutexGuard,解引用后可修改内部数据
                                                                                                                                                                                                                                                                                                                                    let mut num = counter_clone.lock().unwrap();
                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                            *num += 1;
                                                                                                                                                                                                                                                                                                                                                                        println!("当前计数: {}", *num);
                                                                                                                                                                                                                                                                                                                                                                                    // num 离开作用域时,锁自动释放
                                                                                                                                                                                                                                                                                                                                                                                            });
                                                                                                                                                                                                                                                                                                                                                                                                    
                                                                                                                                                                                                                                                                                                                                                                                                            handles.push(handle);
                                                                                                                                                                                                                                                                                                                                                                                                                }
                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                    // 等待所有线程完成
                                                                                                                                                                                                                                                                                                                                                                                                                        for handle in handles {
                                                                                                                                                                                                                                                                                                                                                                                                                                handle.join().unwrap();
                                                                                                                                                                                                                                                                                                                                                                                                                                    }
                                                                                                                                                                                                                                                                                                                                                                                                                                    
                                                                                                                                                                                                                                                                                                                                                                                                                                        println!("最终结果: {}", *counter.lock().unwrap());
                                                                                                                                                                                                                                                                                                                                                                                                                                        }
                                                                                                                                                                                                                                                                                                                                                                                                                                        4. 进阶:线程池 (Thread Pool)
                                                                                                                                                                                                                                                                                                                                                                                                                                        频繁创建和销毁线程开销很大。在实际生产环境中(如 Web 服务器),通常使用线程池。 
                                                                                                                                                                                                                                                                                                                                                                                                                                        Rust 标准库没有直接提供线程池,但可以通过上述原语自行实现,或使用流行的 crate:
                                                                                                                                                                                                                                                                                                                                                                                                                                        推荐库: rayon (数据并行), tokio (异步运行时,也处理并发), threadpool。
                                                                                                                                                                                                                                                                                                                                                                                                                                        使用 rayon 进行数据并行的简单示例(比手动管理线程更简单高效):
                                                                                                                                                                                                                                                                                                                                                                                                                                        rust
                                                                                                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                                                                                                                                                        编辑
                                                                                                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                                                                                                                                                        // Cargo.toml: rayon = "1.8"
                                                                                                                                                                                                                                                                                                                                                                                                                                        use rayon::prelude::*;
                                                                                                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                                                                                                                                                        fn main() { 
                                                                                                                                                                                                                                                                                                                                                                                                                                            let numbers: Vec<i32> = (0..10000).collect();
                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                    // par_iter() 创建并行迭代器,自动利用多线程处理
                                                                                                                                                                                                                                                                                                                                                                                                                                                        let sum: i32 = numbers.par_iter()
                                                                                                                                                                                                                                                                                                                                                                                                                                                                .map(|&x| x * x)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                        .sum();
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    println!("平方和: {}", sum);
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    }
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    总结与最佳实践
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    首选消息传递:尽量使用 mpsc 通道在不同线程间传递数据,避免共享状态带来的复杂性。
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    共享状态需谨慎:如果必须共享,务必使用 Arc<Mutex<T>>。尽量减少锁持有的时间,避免死锁。 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    理解 move:在线程闭包中,明确使用 move 捕获变量所有权,这是 Rust 线程安全的基础。
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    处理错误:join() 和 lock() 都会返回 Result,生产代码中应妥善处理线程恐慌(Panic)或锁中毒(PoisonError)。
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    异步 vs 多线程:如果是 I/O 密集型任务(如网络请求),考虑使用 tokio 或 async-std 的异步模型;如果是 CPU 密集型任务,使用原生多线程或 rayon。
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    Rust 的编译器会在你写出有数据竞争风险的代码时报错,虽然上手曲线较陡,但一旦编译通过,运行时的并发安全性极高。