Cornucopia: Generating Type-Checked Rust Code from PostgreSQL Queries


Introduction

Source: github.com/cornucopia-…

cornucopia is a small command-line tool built on tokio-postgres that turns PostgreSQL queries into type-checked Rust code. The workflow is:

  1. Write the PostgreSQL queries
  2. Generate Rust code from them with cornucopia
  3. Use the generated Rust code in your project to run the queries

Defining the table schema

CREATE TABLE users
(
    id    SERIAL PRIMARY KEY,
    uuid    UUID NOT NULL,
    username    VARCHAR NOT NULL UNIQUE,
    hashed_password    VARCHAR NOT NULL,
    email    VARCHAR UNIQUE,
    created_at    TIMESTAMP NOT NULL DEFAULT NOW(),
    updated_at    TIMESTAMP
);

Queries

--! get_users : (email?, updated_at?)
SELECT id,
       uuid,
       username,
       hashed_password,
       email,
       created_at,
       updated_at
FROM users;
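Because `email` and `updated_at` have no `NOT NULL` constraint, the generated row type should expose them as `Option`; cornucopia treats returned columns as non-null unless they are marked with `?` in the query annotation. The sketch below shows roughly what a row type for `get_users` looks like. The struct name `GetUsers` and both helper methods are hypothetical, and the real generated code uses types from the `uuid` and `time` crates, replaced here by std-only stand-ins so the sketch compiles on its own.

```rust
use std::time::SystemTime;

/// Hypothetical row type for the `get_users` query (not the actual generated code).
/// `[u8; 16]` and `SystemTime` stand in for the `uuid` / `time` crate types.
#[derive(Debug, Clone)]
pub struct GetUsers {
    pub id: i32,                        // SERIAL (int4), NOT NULL
    pub uuid: [u8; 16],                 // UUID NOT NULL (stand-in type)
    pub username: String,               // VARCHAR NOT NULL
    pub hashed_password: String,        // VARCHAR NOT NULL
    pub email: Option<String>,          // VARCHAR, nullable
    pub created_at: SystemTime,         // TIMESTAMP NOT NULL (stand-in type)
    pub updated_at: Option<SystemTime>, // TIMESTAMP, nullable
}

impl GetUsers {
    /// Nullable columns force callers to handle the missing-value case explicitly.
    pub fn email_or(&self, fallback: &str) -> String {
        self.email.clone().unwrap_or_else(|| fallback.to_string())
    }

    /// Returns true if `updated_at` is set.
    pub fn has_been_updated(&self) -> bool {
        self.updated_at.is_some()
    }
}
```

If a nullable column is left unannotated, the generated field is non-optional and a `NULL` in that column can only surface as an error at runtime, which is exactly what the type-checked approach is meant to avoid.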

Rust implementation

Generating the Rust code

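// build.rs
use std::env;
use std::path::Path;
use std::process::Command;

// Requires the cornucopia CLI on PATH (e.g. installed with `cargo install cornucopia`).
fn cornucopia() {
    // `Config` is the project's own configuration type (not part of cornucopia);
    // here it supplies the database URL from config.toml.
    let config = Config::parse("config.toml").expect("Failed to parse configuration file");

    let queries_path = "queries";
    let out_dir = env::var("OUT_DIR").unwrap();
    let file_path = Path::new(&out_dir).join("cornucopia.rs");

    // Re-run this build script when the queries or the config change.
    println!("cargo:rerun-if-changed={queries_path}");
    println!("cargo:rerun-if-changed=config.toml");

    // cornucopia -q <queries dir> -d <output file> live <database url>
    let output = Command::new("cornucopia")
        .arg("-q")
        .arg(queries_path)
        .arg("-d")
        .arg(&file_path)
        .arg("live")
        .arg(&config.storage.database_url)
        .output()
        .expect("Failed to run the cornucopia CLI");

    if !output.status.success() {
        panic!("{}", String::from_utf8_lossy(&output.stderr));
    }
}

fn main() {
    cornucopia();
}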
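The build script writes the generated module to `$OUT_DIR/cornucopia.rs`, which the crate still has to pull in, typically with `include!(concat!(env!("OUT_DIR"), "/cornucopia.rs"));` inside a module. The original does not show this step, so the exact module layout is an assumption. Separately, putting the CLI invocation in a small helper makes it easy to log or check without a database. `cornucopia_command` and `command_args` below are hypothetical helpers that reproduce the arguments used in build.rs.

```rust
use std::path::Path;
use std::process::Command;

/// Builds (without running) the same invocation as the build script:
/// `cornucopia -q <queries_dir> -d <dest> live <database_url>`.
pub fn cornucopia_command(queries_dir: &str, dest: &Path, database_url: &str) -> Command {
    let mut cmd = Command::new("cornucopia");
    cmd.arg("-q")
        .arg(queries_dir) // directory containing the .sql query files
        .arg("-d")
        .arg(dest) // file the generated Rust module is written to
        .arg("live") // introspect a running database
        .arg(database_url);
    cmd
}

/// Collects the command's arguments as strings so they can be logged or inspected.
pub fn command_args(cmd: &Command) -> Vec<String> {
    cmd.get_args()
        .map(|arg| arg.to_string_lossy().into_owned())
        .collect()
}
```

In build.rs the helper would replace the inline builder chain, followed by `.output()` exactly as before.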

Creating a connection pool

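use std::str::FromStr;
use std::sync::Arc;

use deadpool_postgres::{Manager, Pool};
use tokio_postgres::config::SslMode;

pub fn create_pool(database_url: &str) -> Pool {
    let config = tokio_postgres::Config::from_str(database_url).unwrap();

    // Use TLS unless the connection string sets sslmode=disable.
    let manager = if config.get_ssl_mode() != SslMode::Disable {
        // DummyTlsVerifier is defined elsewhere in the project; judging by its name
        // and the `.dangerous()` builder, it skips server certificate validation,
        // which is only acceptable in development.
        let tls_config = rustls::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(DummyTlsVerifier))
            .with_no_client_auth();

        let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
        Manager::new(config, tls)
    } else {
        Manager::new(config, tokio_postgres::NoTls)
    };

    Pool::builder(manager).build().unwrap()
}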
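The TLS branch in `create_pool` depends only on the connection string's `sslmode`. tokio_postgres defaults to `prefer` when `sslmode` is absent, so a URL without it takes the rustls branch. The simplified, std-only sketch below mirrors that decision; `SslMode`, `ssl_mode_from_url` and `uses_tls` are hypothetical names, and unlike `tokio_postgres::Config::from_str` it only understands URL-style strings and the three modes shown.

```rust
/// Simplified stand-in for tokio_postgres's SSL mode (not the real type).
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum SslMode {
    Disable,
    Prefer,
    Require,
}

/// Reads `sslmode` from the query string of a `postgres://...` URL.
pub fn ssl_mode_from_url(url: &str) -> Result<SslMode, String> {
    let query = match url.split_once('?') {
        Some((_, q)) => q,
        None => return Ok(SslMode::Prefer), // no parameters: default is Prefer
    };
    for pair in query.split('&') {
        if let Some(("sslmode", value)) = pair.split_once('=') {
            return match value {
                "disable" => Ok(SslMode::Disable),
                "prefer" => Ok(SslMode::Prefer),
                "require" => Ok(SslMode::Require),
                other => Err(format!("unsupported sslmode: {other}")),
            };
        }
    }
    Ok(SslMode::Prefer) // parameters present but no sslmode
}

/// Same condition as in `create_pool`: anything other than Disable uses TLS.
pub fn uses_tls(url: &str) -> Result<bool, String> {
    Ok(ssl_mode_from_url(url)? != SslMode::Disable)
}
```

In practice this means that a local development database without TLS should be reached with `?sslmode=disable` if you want the `NoTls` branch.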

Using it in the data access layer

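At build time, build.rs generates Rust code from the SQL statements in the queries directory; the data access layer then imports the generated query functions: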

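use db::queries::user::*;

pub struct UserDao<'a, T>
where
    T: db::GenericClient,
{
    executor: &'a T,
}

impl<'a, T> UserDao<'a, T>
where
    T: db::GenericClient,
{
    pub fn new(executor: &'a T) -> Self {
        UserDao { executor }
    }

    pub async fn get_users(&self) -> AppResult<Vec<User>> {
        // Run the generated `get_users` query and convert each row into the
        // project's `User` type. `?` assumes AppResult's error type implements
        // From<tokio_postgres::Error>.
        let users = get_users()
            .bind(self.executor)
            .all()
            .await?
            .into_iter()
            .map(|item| item.to_user())
            .collect();
        Ok(users)
    }
}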
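`to_user` is not shown in the original. A minimal sketch of such a conversion, assuming a generated row type `GetUsers` (trimmed to a few columns here) and a hypothetical domain `User` that deliberately leaves out `hashed_password`, could look like this:

```rust
/// Hypothetical generated row, trimmed to the columns this sketch needs.
#[derive(Debug, Clone)]
pub struct GetUsers {
    pub id: i32,
    pub username: String,
    pub hashed_password: String,
    pub email: Option<String>,
}

/// Hypothetical domain type returned by the DAO; it has no password-hash field,
/// so the hash cannot leak into API responses by accident.
#[derive(Debug, Clone, PartialEq)]
pub struct User {
    pub id: i32,
    pub username: String,
    pub email: Option<String>,
}

impl GetUsers {
    /// Consumes the row (the DAO iterates with `into_iter()`), dropping `hashed_password`.
    pub fn to_user(self) -> User {
        User {
            id: self.id,
            username: self.username,
            email: self.email,
        }
    }
}
```

Implementing `From<GetUsers> for User` instead would let the DAO write `.map(User::from)`; the original does not say which form the project uses.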

Calling it from the service layer:

#[derive(Clone)]
pub struct AppState {
    pub pool: Arc<db::Pool>,
}

impl AppState {
    pub async fn new(config: Config) -> AppResult<Self> {
        let pool = Arc::new(create_pool(&config.storage.database_url));
        Ok(Self {
            pool
        })
    }
}

pub async fn list(state: &AppState) -> AppResult<Vec<User>> {
    let client = state.pool.get().await?;
    let user_dao = UserDao::new(&client);
    let users = user_dao.get_users().await?;
    Ok(users)
}
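`AppState` derives `Clone`, so each handler can receive its own copy while all copies share one pool: cloning copies the `Arc` pointer, not the pool. The sketch below checks this with a stand-in `Pool` struct (hypothetical, so it compiles without deadpool). deadpool's `Pool` is itself cheap to clone because it wraps an `Arc` internally, so the extra `Arc` in `AppState` is optional.

```rust
use std::sync::Arc;

/// Stand-in for `db::Pool` so the sketch compiles without deadpool.
pub struct Pool {
    pub max_size: usize,
}

/// Same shape as the article's AppState.
#[derive(Clone)]
pub struct AppState {
    pub pool: Arc<Pool>,
}

impl AppState {
    pub fn new(pool: Pool) -> Self {
        Self { pool: Arc::new(pool) }
    }
}
```

Each `state.clone()` only bumps the reference count, so handing a clone to every request handler is cheap.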

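Note: the code above relies on project-specific structs and macros that are not shown here (for example `Config`, `AppResult`, `User`, `DummyTlsVerifier` and the `to_user` conversion).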