第五章:泛型编程与 trait 抽象

79 阅读10分钟

第五章:泛型编程与 trait 抽象

教学目标

  • 理解泛型编程思想与 trait 机制的核心概念
  • 掌握泛型类型、trait 定义与实现的使用方法
  • 实现通用存储接口与多存储方式的切换
  • 理解接口抽象与多态在 Rust 中的应用

核心知识点

1. 泛型类型

函数泛型与类型参数

泛型允许我们编写不指定具体类型的函数,从而提高代码的复用性。在函数定义中使用类型参数(如T)来表示通用类型。

// 泛型函数:交换两个值
fn swap<T>(a: &mut T, b: &mut T) {
    let temp = *a;
    *a = *b;
    *b = temp;
}
// 带有多个类型参数的泛型函数
fn pair<T, U>(a: T, b: U) -> (T, U) {
    (a, b)
}
fn main() {
    let mut x = 10;
    let mut y = 20;
    swap(&mut x, &mut y);
    println!("x = {}, y = {}", x, y);  // 输出: x = 20, y = 10
    
    let pair = pair("Rust", 1.0);
    println!("{}: {:?}", pair.0, pair.1);  // 输出: Rust: 1.0
}
结构体泛型与生命周期参数

结构体也可以使用泛型,允许结构体存储不同类型的数据。生命周期参数用于确保引用的有效性。

// 泛型结构体:存储任意类型的链表节点
struct Node<T> {
    data: T,
    next: Option<Box<Node<T>>>,
}
// 带有生命周期参数的泛型结构体
struct Reference<'a, T> {
    value: &'a T,
}
fn main() {
    // 使用i32类型的Node
    let node1 = Node {
        data: 10,
        next: Some(Box::new(Node {
            data: 20,
            next: None,
        })),
    };
    
    // 使用String类型的Node
    let node2 = Node {
        data: "Hello".to_string(),
        next: Some(Box::new(Node {
            data: "World".to_string(),
            next: None,
        })),
    };
    
    let x = 5;
    let ref1 = Reference { value: &x };
    println!("{}", ref1.value);  // 输出: 5
}
泛型约束(where 子句)

使用where子句为泛型参数添加约束,限制泛型类型必须实现特定的 trait。

// 泛型函数,要求T实现Display trait
fn print<T: std::fmt::Display>(value: T) {
    println!("值: {}", value);
}
// 使用where子句的泛型函数
fn add<T, U, V>(a: T, b: U) -> V
where
    T: std::ops::Add<U, Output = V>,
{
    a + b
}
fn main() {
    print(10);              // 输出: 值: 10
    print("Rust".to_string());  // 输出: 值: Rust
    
    let sum = add(5, 3.5);  // i32 + f64 = f64
    println!("和: {}", sum);  // 输出: 和: 8.5
}

2. trait 机制

trait 定义与抽象方法

trait 用于定义一组方法的签名,实现 trait 的类型必须提供这些方法的具体实现。

// 定义Shape trait,包含计算面积和周长的方法
trait Shape {
    fn area(&self) -> f64;
    fn perimeter(&self) -> f64;
}
// 定义Drawable trait,包含绘制方法和默认实现
trait Drawable {
    fn draw(&self) {
        println!("绘制图形");
    }
    fn draw_with_color(&self, color: &str);
}
// 矩形结构体
struct Rectangle {
    width: f64,
    height: f64,
}
// 实现Shape trait
impl Shape for Rectangle {
    fn area(&self) -> f64 {
        self.width * self.height
    }
    
    fn perimeter(&self) -> f64 {
        2.0 * (self.width + self.height)
    }
}
// 实现Drawable trait
impl Drawable for Rectangle {
    fn draw_with_color(&self, color: &str) {
        println!("用{}颜色绘制矩形", color);
    }
}
fn main() {
    let rect = Rectangle { width: 4.0, height: 3.0 };
    println!("面积: {}", rect.area());       // 输出: 面积: 12
    println!("周长: {}", rect.perimeter()); // 输出: 周长: 14
    rect.draw();                            // 输出: 绘制图形
    rect.draw_with_color("红色");           // 输出: 用红色颜色绘制矩形
}
trait 实现与默认方法

trait 可以提供方法的默认实现,实现 trait 时可以选择覆盖或使用默认实现。

// 定义Iterator trait的简化版本
trait MyIterator {
    type Item;
    fn next(&mut self) -> Option<Self::Item>;
    
    // 默认实现:遍历所有元素
    fn for_each<F>(&mut self, mut f: F)
    where
        F: FnMut(Self::Item),
    {
        while let Some(item) = self.next() {
            f(item);
        }
    }
}
// 实现MyIterator for Vec<i32>
impl MyIterator for Vec<i32> {
    type Item = i32;
    
    fn next(&mut self) -> Option<Self::Item> {
        self.pop()
    }
}
fn main() {
    let mut vec = vec![1, 2, 3, 4, 5];
    vec.for_each(|x| println!("{}", x));  // 输出: 5 4 3 2 1
}
泛型 trait 与关联类型

泛型 trait 可以包含关联类型,用于抽象类型参数,使 trait 更灵活。

// 定义泛型trait,包含关联类型
trait Container<T> {
    type Iter: Iterator<Item = T>;
    fn new() -> Self;
    fn add(&mut self, item: T);
    fn iter(&self) -> Self::Iter;
}
// 实现Container for Vec<T>
impl<T> Container<T> for Vec<T> {
    type Iter = std::slice::Iter<T>;
    
    fn new() -> Self {
        Vec::new()
    }
    
    fn add(&mut self, item: T) {
        self.push(item)
    }
    
    fn iter(&self) -> Self::Iter {
        self.iter()
    }
}
fn main() {
    let mut vec: Vec<i32> = Container::new();
    vec.add(10);
    vec.add(20);
    
    for item in vec.iter() {
        println!("{}", item);  // 输出: 10 20
    }
}

3. 接口抽象与多态

通过 trait 实现存储接口抽象

使用 trait 定义存储接口,不同的存储方式实现该 trait,从而实现接口抽象。

use std::collections::HashMap;
use thiserror::Error;
// 定义存储接口trait
trait TaskStorage {
    type Error;
    fn save(&self, tasks: &HashMap<u32, Task>, next_id: u32) -> Result<(), Self::Error>;
    fn load(&self) -> Result<(HashMap<u32, Task>, u32), Self::Error>;
}
// 任务结构体
struct Task {
    id: u32,
    title: String,
    // 其他字段...
}
// 文件存储实现
struct FileStorage {
    path: String,
}
impl TaskStorage for FileStorage {
    type Error = std::io::Error;
    
    fn save(&self, tasks: &HashMap<u32, Task>, next_id: u32) -> Result<(), Self::Error> {
        // 实现文件保存逻辑
        Ok(())
    }
    
    fn load(&self) -> Result<(HashMap<u32, Task>, u32), Self::Error> {
        // 实现文件加载逻辑
        Ok((HashMap::new(), 1))
    }
}
// 内存存储实现
struct MemoryStorage {
    tasks: Option<(HashMap<u32, Task>, u32)>,
}
impl TaskStorage for MemoryStorage {
    type Error = &'static str;
    
    fn save(&self, tasks: &HashMap<u32, Task>, next_id: u32) -> Result<(), Self::Error> {
        self.tasks = Some((tasks.clone(), next_id));
        Ok(())
    }
    
    fn load(&self) -> Result<(HashMap<u32, Task>, u32), Self::Error> {
        match &self.tasks {
            Some((tasks, next_id)) => Ok((tasks.clone(), *next_id)),
            None => Ok((HashMap::new(), 1)),
        }
    }
}
多存储方式的统一接口

通过泛型实现支持多种存储方式的任务管理器,提高代码的可扩展性。

// 泛型任务管理器,支持任意实现TaskStorage的存储方式
struct TaskManager<S> {
    tasks: HashMap<u32, Task>,
    next_id: u32,
    storage: S,
}
impl<S: TaskStorage> TaskManager<S> {
    fn new(storage: S) -> Self {
        let (tasks, next_id) = storage.load().unwrap_or_else(|e| {
            eprintln!("加载任务失败: {}", e);
            (HashMap::new(), 1)
        });
        TaskManager { tasks, next_id, storage }
    }
    
    fn add_task(&mut self, task: Task) {
        self.tasks.insert(self.next_id, task);
        self.next_id += 1;
    }
    
    fn save(&self) -> Result<(), S::Error> {
        self.storage.save(&self.tasks, self.next_id)
    }
}
fn main() {
    // 使用文件存储
    let file_storage = FileStorage { path: "tasks.json".to_string() };
    let mut file_manager = TaskManager::new(file_storage);
    
    // 使用内存存储
    let memory_storage = MemoryStorage { tasks: None };
    let mut memory_manager = TaskManager::new(memory_storage);
}
trait 对象与动态分发

使用 trait 对象实现动态多态,在运行时确定具体的类型。

use std::any::Any;
// 定义Animal trait
trait Animal {
    fn speak(&self);
    fn as_any(&self) -> &dyn Any;
}
// 狗结构体
struct Dog {
    name: String,
}
impl Animal for Dog {
    fn speak(&self) {
        println!("{}: 汪汪汪", self.name);
    }
    
    fn as_any(&self) -> &dyn Any {
        self
    }
}
// 猫结构体
struct Cat {
    name: String,
}
impl Animal for Cat {
    fn speak(&self) {
        println!("{}: 喵喵喵", self.name);
    }
    
    fn as_any(&self) -> &dyn Any {
        self
    }
}
fn main() {
    let animals: Vec<Box<dyn Animal>> = vec![
        Box::new(Dog { name: "旺财".to_string() }),
        Box::new(Cat { name: "咪咪".to_string() }),
    ];
    
    for animal in animals {
        animal.speak();
    }
    
    // 动态类型转换
    if let Some(dog) = animals[0].as_any().downcast_ref::<Dog>() {
        println!("这是一只狗: {}", dog.name);
    }
}

项目实战:实现通用存储接口

1. 定义 TaskStorage trait

storage.rs中定义TaskStorage trait,作为所有存储方式的统一接口。

// src/storage.rs
use crate::task::{Task, TaskStatus};
use std::collections::HashMap;
use thiserror::Error;
// 定义存储接口trait,使用关联类型抽象错误类型
pub trait TaskStorage {
    type Error;
    fn save(&self, tasks: &HashMap<u32, Task>, next_id: u32) -> Result<(), Self::Error>;
    fn load(&self) -> Result<(HashMap<u32, Task>, u32), Self::Error>;
}

2. 重构 FileStorage 实现 TaskStorage

修改FileStorage以实现TaskStorage trait,确保错误类型和方法签名匹配。

// src/storage.rs
use super::TaskStorage;
use crate::task::{Task, TaskStatus};
use std::collections::HashMap;
use std::fs::File;
use std::io::{Read, Write};
use std::path::PathBuf;
use serde::{Deserialize, Serialize};
use thiserror::Error;
// 存储模块的自定义错误类型
#[derive(Error, Debug)]
pub enum StorageError {
    #[error("文件操作错误: {0}")]
    FileError(#[from] std::io::Error),
    #[error("JSON解析错误: {0}")]
    JsonError(#[from] serde_json::Error),
    #[error("任务数据格式错误: {0}")]
    DataFormatError(String),
}
// 序列化任务数据的结构体
#[derive(Serialize, Deserialize, Debug)]
struct TaskStorageData {
    tasks: HashMap<u32, Task>,
    next_id: u32,
}
// 文件存储实现
pub struct FileStorage {
    path: PathBuf,
}
impl TaskStorage for FileStorage {
    type Error = StorageError;
    
    fn save(&self, tasks: &HashMap<u32, Task>, next_id: u32) -> Result<(), Self::Error> {
        let storage = TaskStorageData {
            tasks: tasks.clone(),
            next_id,
        };
        
        // 序列化为JSON
        let data = serde_json::to_vec(&storage)?;
        
        // 写入文件
        let mut file = File::create(&self.path)?;
        file.write_all(&data)?;
        
        Ok(())
    }
    
    fn load(&self) -> Result<(HashMap<u32, Task>, u32), Self::Error> {
        // 文件不存在时返回空任务集合和next_id=1
        if !self.path.exists() {
            return Ok((HashMap::new(), 1));
        }
        
        // 读取文件内容
        let mut file = File::open(&self.path)?;
        let mut data = Vec::new();
        file.read_to_end(&mut data)?;
        
        // 反序列化JSON
        let storage: TaskStorageData = serde_json::from_slice(&data)?;
        
        Ok((storage.tasks, storage.next_id))
    }
}

3. 实现 MemoryStorage 内存存储

添加MemoryStorage结构体,实现TaskStorage trait,用于内存中的数据存储(主要用于测试)。

// src/storage.rs
use super::TaskStorage;
use crate::task::{Task, TaskStatus};
use std::collections::HashMap;
use std::error::Error;
use std::fmt;
// 内存存储实现,用于测试和演示泛型
pub struct MemoryStorage {
    tasks: Option<(HashMap<u32, Task>, u32)>,
}
// 内存存储的错误类型
#[derive(Debug)]
pub struct MemoryError;
impl fmt::Display for MemoryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "内存存储错误")
    }
}
impl Error for MemoryError {}
impl TaskStorage for MemoryStorage {
    type Error = MemoryError;
    
    fn save(&self, tasks: &HashMap<u32, Task>, next_id: u32) -> Result<(), Self::Error> {
        self.tasks = Some((tasks.clone(), next_id));
        Ok(())
    }
    
    fn load(&self) -> Result<(HashMap<u32, Task>, u32), Self::Error> {
        match &self.tasks {
            Some((tasks, next_id)) => Ok((tasks.clone(), *next_id)),
            None => Ok((HashMap::new(), 1)),
        }
    }
}

4. 更新 TaskManager 使用泛型存储

修改cli.rs中的TaskManager,使其成为泛型结构体,接受任意实现TaskStorage的存储类型。

// src/cli.rs
use crate::storage::TaskStorage;
use crate::task::{Task, TaskStatus};
use std::collections::HashMap;
use std::time::SystemTime;
use std::path::PathBuf;
// 泛型TaskManager,支持任意实现TaskStorage的存储类型
pub struct TaskManager<S> {
    tasks: HashMap<u32, Task>,
    next_id: u32,
    storage: S,
}
impl<S: TaskStorage> TaskManager<S> {
    // 从存储创建TaskManager实例
    pub fn new(storage: S) -> Result<Self, S::Error> {
        let (tasks, next_id) = storage.load()?;
        
        Ok(TaskManager {
            tasks,
            next_id,
            storage,
        })
    }
    
    // 添加新任务
    pub fn add_task(&mut self, title: String, description: String, due_date: SystemTime) {
        let task = Task::new(
            self.next_id,
            title,
            description,
            due_date,
        );
        self.tasks.insert(self.next_id, task);
        println!("任务已添加,ID: {}", self.next_id);
        self.next_id += 1;
    }
    
    // 列出所有任务
    pub fn list_tasks(&self) {
        if self.tasks.is_empty() {
            println!("暂无任务");
            return;
        }
        
        println!("ID\t状态\t标题\t\t截止日期");
        println!("----------------------------------------");
        
        for (_, task) in &self.tasks {
            let status = match task.status {
                TaskStatus::Todo => "待办",
                TaskStatus::InProgress => "进行中",
                TaskStatus::Completed => "已完成",
            };
            // 转换截止日期为字符串
            let due_date = match task.due_date.duration_since(SystemTime::UNIX_EPOCH) {
                Ok(dur) => format!("{}", dur.as_secs()),
                Err(_) => "未知日期".to_string(),
            };
            println!("{}\t{}\t{}\t{}", task.id, status, task.title, due_date);
        }
    }
    
    // 更新任务状态
    pub fn update_task_status(&mut self, task_id: u32, status: TaskStatus) {
        if let Some(task) = self.tasks.get_mut(&task_id) {
            task.update_status(status);
            println!("任务状态已更新");
        } else {
            println!("未找到 ID 为 {} 的任务", task_id);
        }
    }
    
    // 保存任务数据到存储
    pub fn save(&self) -> Result<(), S::Error> {
        self.storage.save(&self.tasks, self.next_id)
    }
}

5. 更新 main.rs 使用不同存储方式

修改main.rs,演示如何使用FileStorage和MemoryStorage两种存储方式。

// src/main.rs
mod task;
mod cli;
mod storage;
use cli::TaskManager;
use storage::{FileStorage, MemoryStorage};
use task::TaskStatus;
use std::io::{self, BufRead};
use std::path::PathBuf;
use std::time::SystemTime;
fn main() {
    // 示例1:使用文件存储
    println!("=== 使用文件存储 ===");
    let storage_path = PathBuf::from("tasks.json");
    let file_storage = FileStorage::new(storage_path.clone());
    
    let mut file_manager = match TaskManager::new(file_storage) {
        Ok(manager) => manager,
        Err(e) => {
            eprintln!("加载文件存储失败: {}", e);
            return;
        }
    };
    
    // 添加任务
    file_manager.add_task(
        "学习Rust".to_string(),
        "完成第五章内容".to_string(),
        SystemTime::now() + std::time::Duration::from_secs(86400),
    );
    
    // 保存任务
    if let Err(e) = file_manager.save() {
        eprintln!("保存文件存储失败: {}", e);
    }
    
    // 示例2:使用内存存储
    println!("\n=== 使用内存存储 ===");
    let memory_storage = MemoryStorage { tasks: None };
    let mut memory_manager = match TaskManager::new(memory_storage) {
        Ok(manager) => manager,
        Err(e) => {
            eprintln!("加载内存存储失败: {}", e);
            return;
        }
    };
    
    // 添加任务
    memory_manager.add_task(
        "测试任务".to_string(),
        "内存存储测试".to_string(),
        SystemTime::now() + std::time::Duration::from_secs(3600),
    );
    
    // 更新任务状态
    memory_manager.update_task_status(1, TaskStatus::Completed);
    
    // 列出任务
    memory_manager.list_tasks();
    
    // 保存到内存
    if let Err(e) = memory_manager.save() {
        eprintln!("保存内存存储失败: {}", e);
    }
    
    // 重新加载内存存储(模拟程序重启)
    let memory_storage_loaded = MemoryStorage { tasks: None };
    let mut memory_manager_loaded = match TaskManager::new(memory_storage_loaded) {
        Ok(manager) => manager,
        Err(e) => {
            eprintln!("重新加载内存存储失败: {}", e);
            return;
        }
    };
    
    // 列出重新加载后的任务
    println!("\n=== 重新加载内存存储 ===");
    memory_manager_loaded.list_tasks();
}

6. 编译与测试

编译并运行程序,观察文件存储和内存存储的不同表现:

cargo build
cargo run

程序运行示例

=== 使用文件存储 ===
任务已添加,ID: 1
=== 使用内存存储 ===
任务已添加,ID: 1
ID        状态        标题                截止日期
----------------------------------------
1        已完成        测试任务        1687683600
=== 重新加载内存存储 ===
ID        状态        标题                截止日期
----------------------------------------
1        已完成        测试任务        1687683600

实践作业

实现基于数据库(SQLite)的存储方式,遵循TaskStorage接口,具体要求:

  1. storage.rs中添加SqliteStorage结构体
  1. 使用rusqlite库实现 SQLite 数据库操作
  1. 实现TaskStorage trait for SqliteStorage
  1. main.rs中添加使用SqliteStorage的示例
  1. 测试数据库存储的添加、更新和查询功能
# Cargo.toml
[ dependencies ]
rusqlite = "0.29"
// 在storage.rs中添加SqliteStorage实现
// 在main.rs中添加使用SqliteStorage的代码
fn main() {
    // 测试SqliteStorage功能
}

通过完成这个作业,你将进一步巩固泛型编程和 trait 机制的知识,学习如何通过接口抽象实现不同的存储策略,提高代码的可扩展性和复用性。