rust:axum+sqlx实战学习笔记5

2,301 阅读5分钟

携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第5天,点击查看活动详情

上一章,我们学习了sqlx的使用,再加上一开始对axum的学习,如今万事俱备,只欠东风。现在,我们正式开始使用axum和sqlx搭建一个简单的web后端服务。

首先是mian.rs文件:

use dotenv;
use linken_server::common::{utils::shutdown_signal, BANNER};
use linken_server::db;
use linken_server::routers::init::routers;
use sqlx;
use std::{net::SocketAddr, str::FromStr};
use tracing_subscriber::{fmt, layer::SubscriberExt, registry, util::SubscriberInitExt, EnvFilter};

#[tokio::main]
async fn main() -> Result<(), sqlx::Error> {
    //初始化环境变量
    dotenv::dotenv().ok();
    //启动日志
    start_tracing();
    eprintln!("{}", BANNER);
    tracing::info!("服务启动");
    //连接数据库
    db::init_db_pool().await?;
    //构建路由
    let app = routers();
    // //监听端口
    let addr = SocketAddr::from_str("127.0.0.1:3000").unwrap();
    tracing::debug!("listening on {}", addr);
    //绑定ip
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .with_graceful_shutdown(shutdown_signal())
        .await
        .unwrap();
    Ok(())
}

//日志记录
fn start_tracing() {
    //获取日志等级环境变量
    let env_filter = EnvFilter::try_from_env("RUST_LOG").unwrap_or_else(|_| EnvFilter::new("info"));
    // 输出到控制台中
    let formatting_layer = fmt::layer().pretty().with_writer(std::io::stderr);
    // 注册
    registry().with(env_filter).with(formatting_layer).init();
}

在mian.rs中启动日志服务,并监听127.0.0.1:3000地址。把路由提取到routers模块,数据库放到db模块,一些通用函数和数据放在common模块。

接下来,先处理数据库部分,在db模块中,建立数据库连接池,并声明公共函数:

use dotenv;
use once_cell::sync::OnceCell;
use sqlx::{mysql::MySqlPoolOptions, Error, MySqlPool};

static MYSQL_POOL: OnceCell<MySqlPool> = OnceCell::new();
//建立mysql连接
pub async fn init_db_pool() -> Result<(), Error> {
    let key = "DATABASE_URL";
    let db_url = dotenv::var(key).unwrap();

    let pool = MySqlPoolOptions::new()
        .max_connections(5)
        .connect(&db_url)
        .await?;
    assert!(MYSQL_POOL.set(pool).is_ok());
    Ok(())
}
//获取数据库
pub fn get_pool() -> Option<&'static MySqlPool> {
    MYSQL_POOL.get()
}

然后把前几章所写的RespVO响应数据结构体和jwt认证分别作为common模块和middleware模块。这样就搭建好了整体框架。

下面,重点放在routers模块,首先看routers的文件结构:

1660290020768.png

入口为init.rs文件,里面对路由进行注册。

use super::auth;
use super::todo;
use crate::middleware::auth::Claims;
use axum::{middleware::from_extractor, routing::MethodRouter, Router};

//构建路由公共方法
pub fn handle_router(path: &str, method_router: MethodRouter) -> Router {
    Router::new().route(path, method_router)
}

//api
pub fn routers() -> Router {
    auth_init_router().merge(init_router())
}

//需要权限认证的路由
fn auth_init_router() -> Router {
    let app = Router::new()
        .merge(auth::login()) //登录
        .merge(auth::update_user()) //更新用户信息
        .merge(auth::get_user_list()) //查询用户信息列表
        .merge(todo::create_todo()) //创建todoList
        .merge(todo::handle_todo_list()) //处理todoList
        .merge(todo::create_todo_item()) //创建TodoItem
        .merge(todo::handle_todo_items()) //处理TodoItems
        .merge(todo::handle_todo_item()) //处理TodoItem
        .layer(from_extractor::<Claims>());
    return app;
}

//不需要权限认证的路由
fn init_router() -> Router {
    let app = Router::new().merge(auth::register()); //注册
    return app;
}

这里用到了路由的merge方法,这个方法其实是把多个路由进行合并,有兴趣的同学可以查看官方文档获取详细说明。同时在路由后面使用了前面写好的权限认证中间件。

可以看到,其实真正的路由是在各自的模块中导出的。那我们先看auth路由模块。 在mod.rs中引入文件并导出路由:

mod api;
mod handler;
mod dto;
mod servers;

pub use api::{get_user_list, login, register, update_user};

在api.rs中进行路由注册:

use super::handler::{
    authorize, get_user_list as get_user_lists, login as user_login,
    update_user as update_user_info,
};
use crate::routers::init::handle_router;

use axum::{
    routing::{get, post},
    Router,
};

//注册
pub fn register() -> Router {
    //构建注册路由
    handle_router("/register", post(authorize))
}

//登录
pub fn login() -> Router {
    //构建登录路由
    handle_router("/login", post(user_login))
}

//更新
pub fn update_user() -> Router {
    //构建登录路由
    handle_router("/user/:id", post(update_user_info))
}

//查询用户信息列表
pub fn get_user_list() -> Router {
    //构建登录路由
    handle_router("/user-list", get(get_user_lists))
}

处理路由逻辑的函数则放在handler文件:

use super::servers;
use super::dto::{AuthPayload, AuthToken, LoginPayload};
use crate::common::response::RespVO;
use crate::jwt::KEYS;
use crate::middleware::auth::Claims;
use axum::{extract::Path, response::IntoResponse, Json};
use jsonwebtoken::{encode, Header};

//注册
pub async fn authorize(Json(payload): Json<AuthPayload>) -> impl IntoResponse {
    // 检查用户名
    if payload.username.is_empty() {
        return Json(RespVO::<AuthToken>::from_error("用户名不能为空!"));
    } else if payload.password.is_empty() {
        return Json(RespVO::<AuthToken>::from_error("密码不能为空!"));
    }

    let claims = AuthPayload {
        username: payload.username.to_owned(),
        password: payload.password.to_owned(),
        phone: payload.phone.to_owned(),
        sex: payload.sex.to_owned(),
        email: payload.email.to_owned(),
        avatar: payload.avatar.to_owned(),
        role: payload.role.clone(),
        status: payload.status.clone(),
        create_time: None,
        // token到期时间
        // exp: Some(2000000000), // May 2033
    };
    //创建token, Create the authorization token
    let token = encode(&Header::default(), &claims, &KEYS.encoding)
        .map_err(|_| Json(RespVO::<AuthToken>::from_error("token创建失败!")))
        .unwrap();
    let result = servers::create(payload.clone()).await;
    match result {
        Ok(res) => {
            if res == 1 {
                //返回token, Send the authorized token
                let arg = AuthToken::new(token);
                Json(RespVO::<AuthToken>::from_result(&arg))
            } else {
                Json(RespVO::<AuthToken>::from_error("写入数据库失败!"))
            }
        }
        Err(err) => {
            tracing::error!("authorize: {:?}", err);
            let info = err.to_string();
            Json(RespVO::<AuthToken>::from_error(&info))
        }
    }
}

//登录
pub async fn login(Json(body): Json<LoginPayload>, claims: Claims) -> impl IntoResponse {
    tracing::info!("登录token信息:{:?}", claims);
    let result = servers::show(&body.username).await;
    match result {
        Ok(res) => {
            if body.username != res.username {
                return Json(RespVO::<AuthPayload>::from_error("用户错误!"));
            } else if body.password != res.password {
                return Json(RespVO::<AuthPayload>::from_error("密码错误!"));
            }
            Json(RespVO::<AuthPayload>::from_result(&res))
        }
        Err(err) => {
            tracing::error!("login: {:?}", err);
            let info = err.to_string();
            Json(RespVO::<AuthPayload>::from_error(&info))
        }
    }
}

//更新用户信息
pub async fn update_user(Path(id): Path<i32>, Json(body): Json<AuthPayload>) -> impl IntoResponse {
    let result = servers::update(id, body).await;
    match result {
        Ok(res) => {
            if res == 0 {
                let msg = "更新成功!".to_string();
                Json(RespVO::<String>::from_result(&msg))
            } else {
                Json(RespVO::<String>::from_error("数据库没有该用户!"))
            }
        }
        Err(err) => {
            tracing::error!("update_user: {:?}", err);
            let info = err.to_string();
            Json(RespVO::<String>::from_error(&info))
        }
    }
}

//查询用户信息列表
pub async fn get_user_list() -> impl IntoResponse {
    let result = servers::list().await;
    match result {
        Ok(res) => Json(RespVO::<Vec<AuthPayload>>::from_result(&res)),
        Err(err) => {
            tracing::error!("get_user_list: {:?}", err);
            let info = err.to_string();
            Json(RespVO::<Vec<AuthPayload>>::from_error(&info))
        }
    }
}

同时数据库操作方法提取到servers.rs文件:

use super::dto::AuthPayload;
use crate::db;
use sqlx::{self, Error};
use std::vec::Vec;

/**
 * 测试接口: 查 列表
 */
pub async fn list() -> Result<Vec<AuthPayload>, Error> {
    let pool = db::get_pool().unwrap();
    let sql =
        "select username, password, phone, email, sex, role, status, avatar, create_time from auth";
    let list = sqlx::query_as::<_, AuthPayload>(sql)
        .fetch_all(pool)
        .await?;
    Ok(list)
}

/**
 * 测试接口: 增
 */
pub async fn create(user: AuthPayload) -> Result<u64, Error> {
    let sql = "insert into auth(username, password, sex, phone, email, avatar, role, status) values (?, ?, ?, ?, ?, ?, ?, ?)";
    let pool = db::get_pool().unwrap();
    let affect_rows = sqlx::query(sql)
        .bind(&user.username)
        .bind(&user.password)
        .bind(&user.sex)
        .bind(&user.phone)
        .bind(&user.email)
        .bind(&user.avatar)
        .bind(&user.role)
        .bind(&user.status)
        .execute(pool)
        .await?
        .rows_affected();
    Ok(affect_rows)
}

/**
 * 测试接口: 查
 */
pub async fn show(username: &str) -> Result<AuthPayload, Error> {
    let sql = "select username,password,phone,email,sex,role,status,avatar,create_time from auth where username = ?";
    let pool = db::get_pool().unwrap();
    let res = sqlx::query_as::<_, AuthPayload>(sql)
        .bind(username)
        .fetch_one(pool)
        .await?;
    Ok(res)
}

/**
 * 测试接口: 改
 */
pub async fn update(id: i32, user: AuthPayload) -> Result<u64, Error> {
    let pool = db::get_pool().unwrap();
    let sql="update auth set username = ?, password = ?, phone = ?, email = ?, sex = ?, role = ? , status = ?, avatar = ? where id = ?";
    let pg_done = sqlx::query(sql)
        .bind(&user.username)
        .bind(&user.password)
        .bind(&user.phone)
        .bind(&user.email)
        .bind(&user.sex)
        .bind(&user.role)
        .bind(&user.status)
        .bind(&user.avatar)
        .bind(id.clone())
        .execute(pool)
        .await?
        .rows_affected();
    Ok(pg_done)
}

最后,所用到的数据结构放在dto.rs文件:

use crate::common::types::{Role, Sex, UserState};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::{Decode, Encode, FromRow, Type};

//token结构体
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AuthToken {
    pub access_token: String,
    pub token_type: String,
}

impl AuthToken {
    pub fn new(access_token: String) -> Self {
        Self {
            access_token,
            token_type: "Bearer".to_string(),
        }
    }
}

//注册请求体
#[derive(Debug, Clone, Deserialize, Serialize, Decode, Encode, Type, FromRow)]
pub struct AuthPayload {
    pub username: String,
    pub password: String,
    pub phone: String,
    pub sex: Sex,
    pub email: String,
    pub avatar: Option<String>,
    pub role: Option<Role>,
    pub status: Option<UserState>,
    pub create_time: Option<DateTime<Utc>>,
}

//注册请求体
#[derive(Debug, Clone, Deserialize, Serialize, Decode, Encode, Type, FromRow)]
pub struct LoginPayload {
    pub username: String,
    pub password: String,
}

另外,有些通用的结构,可以提取出来放进common模块。

就这样,一个简单的web服务诞生了。其中大部分的难点,其实都在前面的章节进行了一一介绍,这里不过是对前面知识进行归纳运用。