在Rust语言中,通过 trait 实现“继承”

1,499 阅读3分钟

在工作中,经常会对数据库的表进行增、删、改、查操作。如果能够对这些基础的方法进行重用

为什么要实现“继承”:

因为有提高了代码的复用性,提升开发效率的需求。

如果把对表的增删改查,抽象维护在一个trait中,通过impl的方式赋值给表结构体,则结构体则拥有了trait中的方法,能够大大提升开发效率。

通过trait实现“继承”(本质是代码复用)

实现原理:通过 【trait组合 + 泛型】 实现对数据库增、删、改、查方法的抽象.

具体实现:以mongodb为例

  1. 在rust中创建结构体 Table1 和 Table2,分别对应操作mongodb中table1和table2

Table1 定义如下:

#[derive(Debug, Serialize, Deserialize, Default)]
pub struct Table1 {
    pub title: String,
    pub author: String,
}

Table2 定义如下:

#[derive(Debug, Serialize, Deserialize, Default)]
pub struct Table2 {
    pub name: String,
    pub age: Option<u32>,
}
  1. 在trait IExample 中定义对MongoDB的表的增删改查

trait IExample 定义如下: 在IExample中实现增删改查

use core::any::type_name;
use mongodb::bson::{doc, Bson, Document};
use mongodb::sync::{Client, Database};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};

// 定义DB增删改查接口
pub trait IExample<T>
where
    T: Serialize + DeserializeOwned + Unpin + Send + Sync,
{
    fn insert(doc: T) -> Bson {
        let table_name = table_name::<T>();
        let collection = db().collection::<T>(table_name.as_str());
        collection.insert_one(doc, None).unwrap().inserted_id
    }

    fn update(filter: Document, doc: Document) -> u64 {
        let table_name = table_name::<T>();
        let collection = db().collection::<T>(table_name.as_str());
        collection
            .update_one(filter, doc, None)
            .unwrap()
            .modified_count
    }

    fn find(filter: Document) -> T {
        let table_name = table_name::<T>();
        let collection = db().collection::<T>(table_name.as_str());
        collection.find_one(filter, None).unwrap().unwrap()
    }

    fn delete(filter: Document) -> u64 {
        let table_name = table_name::<T>();
        let collection = db().collection::<T>(table_name.as_str());
        collection.delete_one(filter, None).unwrap().deleted_count
    }
}
  1. 用trait IExample 实现 Table1,Table2

核心代码:将已经实现功能的trait,impl给struct

// 实现Table1增删改查
impl IExample<Table1> for Table1 {}
// 实现Table2增删改查
impl IExample<Table2> for Table2 {}
  1. 调用、测试

fn main() {
    // 操作Table1
    let t1 = Table1 {
        title: "The Grapes of Wrath".to_string(),
        author: "John Steinbeck".to_string(),
    };
    Table1::insert(t1); //新增

    let filter = doc! {
       "title": "The Grapes of Wrath"
    };
    let get_t1 = Table1::find(filter); //查找
    println!("{:?}", &get_t1);

    // 操作Table2
    let t2 = Table2 {
        name: "xunzi".to_string(),
        age: Some(235),
    };
    Table2::insert(t2); //新增

    let filter = doc! {
       "name": "xunzi"
    };
    let update = doc! {
       "$set": {
           "age":Some(75),
       }
    };
    let get_t2 = Table2::update(filter, update); //更新
    println!("{:?}", &get_t2);
}

以上代码实例即通过trait方式实现了增删改查的代码共用

完整代码 src/example/mod.rs

use core::any::type_name;
use mongodb::bson::{doc, Bson, Document};
use mongodb::sync::{Client, Database};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};

// 定义DB增删改查接口
pub trait IExample<T>
where
    T: Serialize + DeserializeOwned + Unpin + Send + Sync,
{
    fn insert(doc: T) -> Bson {
        let table_name = table_name::<T>();
        let collection = db().collection::<T>(table_name.as_str());
        collection.insert_one(doc, None).unwrap().inserted_id
    }

    fn update(filter: Document, doc: Document) -> u64 {
        let table_name = table_name::<T>();
        let collection = db().collection::<T>(table_name.as_str());
        collection
            .update_one(filter, doc, None)
            .unwrap()
            .modified_count
    }

    fn find(filter: Document) -> T {
        let table_name = table_name::<T>();
        let collection = db().collection::<T>(table_name.as_str());
        collection.find_one(filter, None).unwrap().unwrap()
    }

    fn delete(filter: Document) -> u64 {
        let table_name = table_name::<T>();
        let collection = db().collection::<T>(table_name.as_str());
        collection.delete_one(filter, None).unwrap().deleted_count
    }
}

// 定义Table1
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct Table1 {
    pub title: String,
    pub author: String,
}

// 定义Table2
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct Table2 {
    pub name: String,
    pub age: Option<u32>,
}

// 实现Table1增删改查
impl IExample<Table1> for Table1 {}
// 实现Table2增删改查
impl IExample<Table2> for Table2 {}

/**************************************************
 *  无关代码
 * ************************************************/
// 获取类型名称
fn table_name<T>() -> String {
    let table = type_name::<T>().split("::").last().unwrap();
    table.to_lowercase()
}
// 获取指定数据库
fn db() -> Database {
    // "mongodb://localhost:27017"
    let url = "mongodb://admin:X%5DpCY%29g8Hs%259hlss%2CY%5Di@localhost:27017";
    let client = Client::with_uri_str(url).unwrap();
    let database = client.database("mydb");
    database
}

完整代码库

github