rust:axum+sqlx实战学习笔记8

870 阅读6分钟

经过最近几周的忙碌,现在终于有空把最后的坑填上了,也就是本次实战的核心内容——博客文章的增删改查。

那么,废话少说,show me the code。

首先,还是api.rs文件,写上增删改查需要的路由和对应的handler函数。

use super::handlers::{
    add_article, delete_articles, get_article_item, get_articles, get_articles_by_cate,
    get_articles_by_read, get_articles_by_tag, update_articles, update_read_num,
};
use axum::{
    routing::{delete, get, post, put},
    Router,
};

/// Builds the router that exposes every article endpoint.
///
/// The routes are registered in stages purely for readability; the set of
/// paths, HTTP methods and handlers is the complete article API.
pub fn handler_articles() -> Router {
    let router = Router::new();

    // Paged listing and the "most read" ranking.
    let router = router
        .route("/page", get(get_articles))
        .route("/top_read", get(get_articles_by_read));

    // Article creation.
    let router = router.route("/add", post(add_article));

    // Listings filtered by tag name / category name.
    let router = router
        .route("/page_by_tag", get(get_articles_by_tag))
        .route("/page_by_category", get(get_articles_by_cate));

    // Update of an existing article and deletion by path id.
    let router = router
        .route("/update", put(update_articles))
        .route("/delete/:id", delete(delete_articles));

    // Single-article lookup and read-counter increment.
    router
        .route("/detail", get(get_article_item))
        .route("/update_read_num", put(update_read_num))
}

然后就是handlers.rs文件里面把用到的handler方法一一补充完整:

use super::dto::{AddArticle, ArticleList, ArticleReadList};
use super::servers;
use crate::common::{
    response::RespVO,
    types::{PageInfo, PageResult, RequestId},
};
use crate::routers::{category, tag};
use axum::{
    extract::{Path, Query},
    http::StatusCode,
    response::IntoResponse,
    Json,
};
use std::collections::HashMap;

/// Returns one page of articles together with the total article count.
///
/// Paging is taken from the `PageInfo` query string; a missing `pageSize` /
/// `pageNum` is echoed back as 10 / 1. If either the list query or the count
/// query fails, the error is logged and an error body carrying `NOT_FOUND`
/// is returned — a failing count query no longer panics the handler.
pub async fn get_articles(Query(req): Query<PageInfo>) -> impl IntoResponse {
    // Read the paging values first so the request can be moved into the
    // service call without cloning it.
    let page_size = req.pageSize.unwrap_or(10);
    let page_num = req.pageNum.unwrap_or(1);

    let result = servers::list(req).await;
    let total = servers::total().await;

    match (result, total) {
        (Ok(list), Ok(total)) => {
            let response = PageResult {
                list,
                total,
                pageSize: page_size,
                pageNum: page_num,
            };
            Json(RespVO::<PageResult<ArticleList>>::from_result(&response))
        }
        // Either query failing makes the page unusable; report the first error.
        (Err(err), _) | (_, Err(err)) => {
            tracing::error!("get_articles_lists: {:?}", err);
            let info = err.to_string();
            let code = StatusCode::NOT_FOUND;
            Json(RespVO::<PageResult<ArticleList>>::from_error_info(
                code, &info,
            ))
        }
    }
}

/// Returns a single article selected by the `id` query parameter.
///
/// A request without `id` is answered with a `BAD_REQUEST` error body
/// instead of panicking. A failed lookup is logged and answered with a
/// `NOT_FOUND` error body.
pub async fn get_article_item(Query(args): Query<HashMap<String, i32>>) -> impl IntoResponse {
    let id = match args.get("id") {
        Some(id) => id,
        None => {
            let info = "missing required query parameter: id".to_string();
            let code = StatusCode::BAD_REQUEST;
            return Json(RespVO::<ArticleList>::from_error_info(code, &info));
        }
    };

    match servers::item(id).await {
        Ok(res) => Json(RespVO::<ArticleList>::from_result(&res)),
        Err(err) => {
            tracing::error!("get_article_item: {:?}", err);
            let info = err.to_string();
            let code = StatusCode::NOT_FOUND;
            Json(RespVO::<ArticleList>::from_error_info(code, &info))
        }
    }
}

//通过标签查询文章列表
pub async fn get_articles_by_tag(Query(req): Query<PageInfo>) -> impl IntoResponse {
    let result = servers::list_by_tag(req.clone()).await;
    let total = servers::total_tag(req.clone()).await;
    match result {
        Ok(res) => {
            let response = PageResult {
                list: res,
                total: total.unwrap(),
                pageSize: req.pageSize.unwrap_or(10),
                pageNum: req.pageNum.unwrap_or(1),
            };
            Json(RespVO::<PageResult<ArticleList>>::from_result(&response))
        }
        Err(err) => {
            tracing::error!("get_articles_lists: {:?}", err);
            let info = err.to_string();
            let code = StatusCode::NOT_FOUND;
            Json(RespVO::<PageResult<ArticleList>>::from_error_info(
                code, &info,
            ))
        }
    }
}

//通过分类查询文章列表
pub async fn get_articles_by_cate(Query(req): Query<PageInfo>) -> impl IntoResponse {
    let result = servers::list_by_cate(req.clone()).await;
    let total = servers::total_cate(req.clone()).await;
    match result {
        Ok(res) => {
            let response = PageResult {
                list: res,
                total: total.unwrap(),
                pageSize: req.pageSize.unwrap_or(10),
                pageNum: req.pageNum.unwrap_or(1),
            };
            Json(RespVO::<PageResult<ArticleList>>::from_result(&response))
        }

        Err(err) => {
            tracing::error!("get_articles_lists: {:?}", err);
            let info = err.to_string();
            let code = StatusCode::NOT_FOUND;
            Json(RespVO::<PageResult<ArticleList>>::from_error_info(
                code, &info,
            ))
        }
    }
}

//通过阅读量查询文章列表
pub async fn get_articles_by_read(Query(args): Query<HashMap<String, i32>>) -> impl IntoResponse {
    let result = servers::list_by_read(args.get("count").unwrap()).await;
    match result {
        Ok(res) => Json(RespVO::<Vec<ArticleReadList>>::from_result(&res)),
        Err(err) => {
            tracing::error!("get_articles_lists: {:?}", err);
            let info = err.to_string();
            let code = StatusCode::NOT_FOUND;
            Json(RespVO::<Vec<ArticleReadList>>::from_error_info(code, &info))
        }
    }
}

//新增文章表
pub async fn add_article(Json(req): Json<AddArticle>) -> impl IntoResponse {
    let article_id = servers::create(req.clone()).await;

    match article_id {
        Ok(id) => {
            if id == 0 {
                let msg = "新增失败!";
                Json(RespVO::<String>::from_error(msg))
            } else {
                //分类
                for category_id in req.categoryIds.iter() {
                    let _ = category::create_article_category(&id, category_id).await;
                }
                //标签
                for tag_id in req.tagIds.iter() {
                    let _ = tag::create_article_tag(&id, tag_id).await;
                }

                let msg = "新增成功!".to_string();
                Json(RespVO::<String>::from_result(&msg))
            }
        }
        Err(err) => {
            tracing::error!("add_article: {:?}", err);
            let info = err.to_string();
            let code = StatusCode::BAD_REQUEST;
            Json(RespVO::<String>::from_error_info(code, &info))
        }
    }
}

//更新文章表
pub async fn update_articles(Json(req): Json<AddArticle>) -> impl IntoResponse {
    let mut info = None;
    //分类
    let old_category = category::get_cate_by_id(&req.id.unwrap()).await;
    match old_category {
        Ok(res) => {
            let old_category_id: Vec<u32> = res.into_iter().map(|item| item.category_id).collect();
            let new_category_id: Vec<u32> = req.categoryIds.clone().into_iter().collect();
            //旧分类
            for old_id in old_category_id.iter() {
                let res = new_category_id.iter().find(|id| *id == old_id);
                if let None = res {
                    let _ = category::delete_article_category(old_id).await;
                }
            }
            //新分类
            for new_id in new_category_id.iter() {
                let res = old_category_id.iter().find(|id| *id == new_id);
                if let None = res {
                    let had_deleted_category = category::get_by_cate_id(new_id).await;
                    match had_deleted_category {
                        Ok(deleted_category) => {
                            if let Some(item) = deleted_category {
                                let _ = category::update_article_category(&item.category_id).await;
                            } else {
                                let _ = category::create_article_category(&req.id.unwrap(), new_id)
                                    .await;
                            }
                        }
                        Err(err) => {
                            tracing::error!("update_articles: {:?}", err);
                            info = Some(err.to_string());
                        }
                    }
                }
            }
        }
        Err(err) => {
            tracing::error!("update_articles: {:?}", err);
            info = Some(err.to_string());
        }
    }
    //标签
    let old_tag = tag::get_tag_by_id(&req.id.unwrap()).await;
    match old_tag {
        Ok(res) => {
            let old_tag_id: Vec<u32> = res.into_iter().map(|item| item.tag_id).collect();
            let new_tag_id: Vec<u32> = req.tagIds.clone().into_iter().collect();
            //旧标签
            for old_id in old_tag_id.iter() {
                let res = new_tag_id.iter().find(|id| *id == old_id);
                if let None = res {
                    let _ = tag::delete_article_tag(old_id).await;
                }
            }
            //新标签
            for new_id in new_tag_id.iter() {
                let res = old_tag_id.iter().find(|id| *id == new_id);
                if let None = res {
                    let had_deleted_tag = tag::get_by_tag_id(new_id).await;
                    match had_deleted_tag {
                        Ok(deleted_tag) => {
                            if let Some(item) = deleted_tag {
                                let _ = tag::update_article_tag(&item.tag_id).await;
                            } else {
                                let _ = tag::create_article_tag(&req.id.unwrap(), new_id).await;
                            }
                        }
                        Err(err) => {
                            tracing::error!("update_articles: {:?}", err);
                            info = Some(err.to_string());
                        }
                    }
                }
            }
        }
        Err(err) => {
            tracing::error!("update_articles: {:?}", err);
            info = Some(err.to_string());
        }
    }

    if let Some(msg) = info {
        let code = StatusCode::BAD_REQUEST;
        Json(RespVO::<String>::from_error_info(code, &msg))
    } else {
        let result = servers::update(req).await;
        match result {
            Ok(_res) => {
                let msg = "更新成功!".to_string();
                Json(RespVO::<String>::from_result(&msg))
            }
            Err(err) => {
                tracing::error!("update_articles: {:?}", err);
                let info = err.to_string();
                let code = StatusCode::BAD_REQUEST;
                Json(RespVO::<String>::from_error_info(code, &info))
            }
        }
    }
}

//更新文章阅读量
pub async fn update_read_num(Json(req): Json<RequestId>) -> impl IntoResponse {
    let result = servers::update_read(req.id).await;
    match result {
        Ok(_res) => {
            let msg = "更新成功!".to_string();
            Json(RespVO::<String>::from_result(&msg))
        }
        Err(err) => {
            tracing::error!("update_articles: {:?}", err);
            let info = err.to_string();
            let code = StatusCode::BAD_REQUEST;
            Json(RespVO::<String>::from_error_info(code, &info))
        }
    }
}

/// Deletes the article whose id is taken from the request path.
///
/// Answers with a success message when the delete statement ran, or logs
/// the error and answers with a `BAD_REQUEST` error body otherwise.
pub async fn delete_articles(Path(id): Path<u32>) -> impl IntoResponse {
    let body = match servers::delete(id).await {
        Err(err) => {
            tracing::error!("delete_articles: {:?}", err);
            let reason = err.to_string();
            RespVO::<String>::from_error_info(StatusCode::BAD_REQUEST, &reason)
        }
        Ok(_) => {
            let done = "删除成功!".to_string();
            RespVO::<String>::from_result(&done)
        }
    };
    Json(body)
}

接着在servers.rs中对数据库进行处理:

use super::dto::{AddArticle, ArticleDetail, ArticleList, ArticleReadList};
use crate::common::types::{PageInfo, RequestId, TotalResponse};
use crate::db;
use sqlx::{self, Error};

/// Fetches one page of articles, each with its category ids and tag ids
/// aggregated through `GROUP_CONCAT`.
///
/// `pageNum` defaults to 1 and `pageSize` to 10. A `pageNum` below 1 is
/// treated as 1, so the computed row offset can never go negative or
/// underflow, and the offset multiplication saturates instead of overflowing.
///
/// # Errors
/// Returns the `sqlx::Error` produced by the query.
///
/// # Panics
/// Panics if `db::get_pool()` does not yield a pool (the `unwrap` fails).
pub async fn list(req: PageInfo) -> Result<Vec<ArticleList>, Error> {
    let page_num = req.pageNum.unwrap_or(1).max(1);
    let page_size = req.pageSize.unwrap_or(10);
    // Number of rows to skip before the requested page.
    let offset = (page_num - 1).saturating_mul(page_size);
    let pool = db::get_pool().unwrap();
    let sql = "select
	 a.id,
	a.article_title,
	a.article_text ,
	a.poster,
	a.summary,
	a.create_time,
	 a.read_num,
	GROUP_CONCAT(distinct c.id) as category_ids,
	GROUP_CONCAT(distinct t.id) as tag_ids
from
	article a
left join auth on
	a.author_id = auth.id
left join article_category a_c on
	a.id = a_c.article_id
left join category c on
	a_c.category_id = c.id
left join article_tag a_t on
	a.id = a_t.article_id
left join tag t on
	a_t.tag_id = t.id
group by
	a.id
order by
	a.create_time 
limit ?,?;";
    let list = sqlx::query_as::<_, ArticleList>(sql)
        .bind(offset)
        .bind(page_size)
        .fetch_all(pool)
        .await?;
    Ok(list)
}

/// Fetches a single article by id, with its category ids and tag ids
/// aggregated through `GROUP_CONCAT`.
///
/// The `deleted = 0` filters on the link tables live in the `ON` clauses of
/// the left joins. Placed in a `WHERE` clause they would discard the
/// NULL-extended rows and make any article without a live category or tag
/// link unfindable. The id filter is a plain `WHERE` so rows are restricted
/// before grouping.
///
/// # Errors
/// Returns the `sqlx::Error` produced by the query, including the case
/// where `fetch_one` finds no row.
///
/// # Panics
/// Panics if `db::get_pool()` does not yield a pool (the `unwrap` fails).
pub async fn item(id: &i32) -> Result<ArticleList, Error> {
    let pool = db::get_pool().unwrap();
    let sql = "select
	a.*,
	GROUP_CONCAT(distinct c.id) as category_ids,
	GROUP_CONCAT(distinct t.id) as tag_ids
from
	article a
left join auth on
	a.author_id = auth.id
left join article_category a_c on
	a.id = a_c.article_id
	and a_c.deleted = 0
left join category c on
	a_c.category_id = c.id
left join article_tag a_t on
	a.id = a_t.article_id
	and a_t.deleted = 0
left join tag t on
	a_t.tag_id = t.id
where
	a.id = ?
group by
	a.id";
    let list = sqlx::query_as::<_, ArticleList>(sql)
        .bind(id)
        .fetch_one(pool)
        .await?;
    Ok(list)
}

/// Fetches one page of the articles linked to the tag whose name equals
/// `req.keyword`.
///
/// `pageNum` defaults to 1 and `pageSize` to 10. A `pageNum` below 1 is
/// treated as 1, so the computed row offset can never go negative or
/// underflow, and the offset multiplication saturates instead of overflowing.
///
/// # Errors
/// Returns the `sqlx::Error` produced by the query.
///
/// # Panics
/// Panics if `db::get_pool()` does not yield a pool (the `unwrap` fails).
pub async fn list_by_tag(req: PageInfo) -> Result<Vec<ArticleList>, Error> {
    let page_num = req.pageNum.unwrap_or(1).max(1);
    let page_size = req.pageSize.unwrap_or(10);
    // Number of rows to skip before the requested page.
    let offset = (page_num - 1).saturating_mul(page_size);
    let pool = db::get_pool().unwrap();
    let sql = "select
	a.id,
	a.article_title,
	a.article_text ,
	a.poster,
	a.summary,
	a.create_time,
	 a.read_num,
	GROUP_CONCAT(distinct c.id) as category_ids,
	GROUP_CONCAT(distinct t.id) as tag_ids
from article a
left join auth on
	a.author_id = auth.id
left join article_category a_c on
	a.id = a_c.article_id
left join category c on
	a_c.category_id = c.id
left join article_tag a_t on
	a.id = a_t.article_id
left join tag t on
	a_t.tag_id = t.id
group by
	a.id
having
	a.id = any (
	select
		article_id
	from
		article_tag
	where
		tag_id = any(
		select
			id
		from
			tag
		where
			tag_name = ?))
order by
	a.create_time
limit ?,?;";
    let list = sqlx::query_as::<_, ArticleList>(sql)
        .bind(req.keyword)
        .bind(offset)
        .bind(page_size)
        .fetch_all(pool)
        .await?;
    Ok(list)
}

/// Fetches one page of the articles linked to the category whose name
/// equals `req.keyword`.
///
/// `pageNum` defaults to 1 and `pageSize` to 10. A `pageNum` below 1 is
/// treated as 1, so the computed row offset can never go negative or
/// underflow, and the offset multiplication saturates instead of overflowing.
///
/// # Errors
/// Returns the `sqlx::Error` produced by the query.
///
/// # Panics
/// Panics if `db::get_pool()` does not yield a pool (the `unwrap` fails).
pub async fn list_by_cate(req: PageInfo) -> Result<Vec<ArticleList>, Error> {
    let page_num = req.pageNum.unwrap_or(1).max(1);
    let page_size = req.pageSize.unwrap_or(10);
    // Number of rows to skip before the requested page.
    let offset = (page_num - 1).saturating_mul(page_size);
    let pool = db::get_pool().unwrap();
    let sql = "SELECT 
	a.id,
	a.article_title,
	a.article_text ,
	a.poster,
	a.summary,
	a.create_time,
	 a.read_num,
	GROUP_CONCAT(distinct c.id) as category_ids,
	GROUP_CONCAT(distinct t.id) as tag_ids
from
	article a
left join auth on
	a.author_id = auth.id
left join article_category a_c on
	a.id = a_c.article_id
left join category c on
	a_c.category_id = c.id
left join article_tag a_t on
	a.id = a_t.article_id
left join tag t on
	a_t.tag_id = t.id
group by
	a.id
having
	a.id = any (
	select
		article_id
	from
		article_category
	where
		category_id = any (
		select
			id
		from
			category
		where
			category_name = ?))
order by
	a.create_time
limit ?,?;";
    let list = sqlx::query_as::<_, ArticleList>(sql)
        .bind(req.keyword)
        .bind(offset)
        .bind(page_size)
        .fetch_all(pool)
        .await?;

    Ok(list)
}

/// Fetches the non-deleted articles with the highest `read_num`, returning
/// at most `count` rows in descending read order.
///
/// # Errors
/// Returns the `sqlx::Error` produced by the query.
///
/// # Panics
/// Panics if `db::get_pool()` does not yield a pool (the `unwrap` fails).
pub async fn list_by_read(count: &i32) -> Result<Vec<ArticleReadList>, Error> {
    const TOP_READ_SQL: &str = "SELECT id, article_title, read_num, poster FROM article WHERE deleted = 0 ORDER BY `read_num` DESC LIMIT ?;";

    let pool = db::get_pool().unwrap();
    let ranking = sqlx::query_as::<_, ArticleReadList>(TOP_READ_SQL).bind(count);
    ranking.fetch_all(pool).await
}

/**
 * 测试接口: 新增
 */
pub async fn create(detail: AddArticle) -> Result<u32, Error> {
    let sql = "INSERT IGNORE INTO article (article_title, article_text, summary, author_id, poster) VALUES (?, ?, ?, ?, ?);";
    let pool = db::get_pool().unwrap();
    let rows_affected = sqlx::query(sql)
        .bind(detail.articleTitle)
        .bind(detail.articleText)
        .bind(detail.summary)
        .bind(detail.authorId)
        .bind(detail.poster)
        .execute(pool)
        .await?
        .rows_affected();
    if rows_affected == 1 {
        let sql = "select max(id) as id from article;";
        let res = sqlx::query_as::<_, RequestId>(sql).fetch_one(pool).await?;
        Ok(res.id)
    } else {
        Ok(0)
    }
}

/**
 * 测试接口: 更新
 */
pub async fn update(detail: AddArticle) -> Result<u64, Error> {
    let sql = "UPDATE article SET article_title = ?, article_text = ?, summary = ?, poster = ? WHERE id = ?";
    let pool = db::get_pool().unwrap();
    let affect_rows = sqlx::query(sql)
        .bind(detail.articleTitle)
        .bind(detail.articleText)
        .bind(detail.summary)
        .bind(detail.poster)
        .bind(detail.id)
        .execute(pool)
        .await?
        .rows_affected();
    Ok(affect_rows)
}

/// Increments `read_num` by one for the article with the given id and
/// returns the number of affected rows.
///
/// # Errors
/// Returns the `sqlx::Error` produced by the update statement.
///
/// # Panics
/// Panics if `db::get_pool()` does not yield a pool (the `unwrap` fails).
pub async fn update_read(id: u32) -> Result<u64, Error> {
    let pool = db::get_pool().unwrap();
    let bump = sqlx::query("UPDATE article SET read_num = read_num + 1 WHERE id = ?").bind(id);
    let outcome = bump.execute(pool).await?;
    Ok(outcome.rows_affected())
}
/// Deletes the article row with the given id and returns the number of
/// affected rows.
///
/// # Errors
/// Returns the `sqlx::Error` produced by the delete statement.
///
/// # Panics
/// Panics if `db::get_pool()` does not yield a pool (the `unwrap` fails).
pub async fn delete(id: u32) -> Result<u64, Error> {
    let pool = db::get_pool().unwrap();
    let removal = sqlx::query("delete from article where id = ?").bind(id);
    let outcome = removal.execute(pool).await?;
    Ok(outcome.rows_affected())
}

/// Counts all rows of the `article` table.
///
/// # Errors
/// Returns the `sqlx::Error` produced by the count query.
///
/// # Panics
/// Panics if `db::get_pool()` does not yield a pool (the `unwrap` fails).
pub async fn total() -> Result<i64, Error> {
    let pool = db::get_pool().unwrap();
    sqlx::query_as::<_, TotalResponse>("SELECT COUNT(*) AS total FROM article;")
        .fetch_one(pool)
        .await
        .map(|row| row.total)
}

/// Counts the articles linked to the category whose name equals
/// `req.keyword`.
///
/// # Errors
/// Returns the `sqlx::Error` produced by the count query.
///
/// # Panics
/// Panics if `db::get_pool()` does not yield a pool (the `unwrap` fails).
pub async fn total_cate(req: PageInfo) -> Result<i64, Error> {
    const COUNT_BY_CATEGORY_SQL: &str = "SELECT COUNT(*) AS total from
	article
where
	article.id = any (
	select
		article_id
	from
		article_category
	where
		category_id = any (
		select
			id
		from
			category
		where
			category_name = ?));";

    let pool = db::get_pool().unwrap();
    let counted = sqlx::query_as::<_, TotalResponse>(COUNT_BY_CATEGORY_SQL).bind(req.keyword);
    counted.fetch_one(pool).await.map(|row| row.total)
}

/// Counts the articles linked to the tag whose name equals `req.keyword`.
///
/// # Errors
/// Returns the `sqlx::Error` produced by the count query.
///
/// # Panics
/// Panics if `db::get_pool()` does not yield a pool (the `unwrap` fails).
pub async fn total_tag(req: PageInfo) -> Result<i64, Error> {
    const COUNT_BY_TAG_SQL: &str = "SELECT COUNT(*) AS total from
	article
where
	article.id = any (
	select
		article_id
	from
		article_tag
	where
		tag_id = any(
		select
			id
		from
			tag
		where
			tag_name = ?));";

    let pool = db::get_pool().unwrap();
    let counted = sqlx::query_as::<_, TotalResponse>(COUNT_BY_TAG_SQL).bind(req.keyword);
    counted.fetch_one(pool).await.map(|row| row.total)
}

然后就是把数据结构提取到dto.rs:

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
/// Article list model.
///
/// Row type decoded by the list, by-tag, by-category and single-article
/// queries in `servers.rs`, and serialized back to the client.
#[derive(Debug, Clone, Deserialize, Serialize, FromRow)]
pub struct ArticleList {
    /// Article primary key (`a.id`).
    pub id: Option<u32>,
    /// Article title (`article_title` column).
    pub article_title: String,
    /// Article body (`article_text` column).
    pub article_text: String,
    /// Article summary (`summary` column).
    pub summary: String,
    /// Poster value from the `poster` column — presumably a cover image
    /// reference; confirm against the schema.
    pub poster: String,
    /// Creation timestamp (`create_time` column).
    pub create_time: Option<DateTime<Utc>>,
    /// Result of `GROUP_CONCAT(distinct c.id)`: the ids of the joined
    /// categories concatenated into one string; `None` when the aggregate is NULL.
    pub category_ids: Option<String>,
    /// Result of `GROUP_CONCAT(distinct t.id)`: the ids of the joined tags
    /// concatenated into one string; `None` when the aggregate is NULL.
    pub tag_ids: Option<String>,
    /// Read counter (`read_num` column).
    pub read_num: Option<u64>,
}

/// Article detail model.
///
/// NOTE(review): imported by `servers.rs` but not referenced by any query
/// visible in this file — confirm whether it is still needed.
#[derive(Debug, Clone, Deserialize, Serialize, FromRow)]
pub struct ArticleDetail {
    /// Article title.
    pub article_title: String,
    /// Article body.
    pub article_text: String,
    /// Article summary.
    pub summary: String,
    /// Poster value — presumably a cover image reference; confirm against the schema.
    pub poster: String,
    /// Id of the author, if any.
    pub author_id: Option<u32>,
}

/// Article read-count model.
///
/// Row type decoded by the "top read" query, which selects
/// `id, article_title, read_num, poster` ordered by `read_num` descending.
#[derive(Debug, Clone, Deserialize, Serialize, FromRow)]
pub struct ArticleReadList {
    /// Article primary key.
    pub id: Option<u32>,
    /// Article title.
    pub article_title: String,
    /// Read counter the ranking is ordered by.
    pub read_num: u64,
    /// Poster value from the `poster` column.
    pub poster: String,
}

/// Request body for creating or updating an article.
///
/// Field names are camelCase because they are deserialized directly from
/// the JSON body without a serde rename attribute.
#[derive(Debug, Clone, Deserialize, Serialize, FromRow)]
pub struct AddArticle {
    /// Article id; read by the update path to select the row, not bound by the insert.
    pub id: Option<u32>,
    /// Article title, written to `article_title`.
    pub articleTitle: String,
    /// Article body, written to `article_text`.
    pub articleText: String,
    /// Article summary, written to `summary`.
    pub summary: String,
    /// Poster value, written to `poster`.
    pub poster: String,
    /// Author id, written to `author_id` on insert only.
    pub authorId: Option<u32>,
    /// Ids of the tags the article should be linked to.
    pub tagIds: Vec<u32>,
    /// Ids of the categories the article should be linked to.
    pub categoryIds: Vec<u32>,
}

最后不要忘记在mod.rs把各个模块进行导出:

mod api;
mod dto;
mod handlers;
mod servers;

pub use api::handler_articles;

就这样,一个完整的博客后台基本成型了。我们这个实战系列也告一段落。这里除了接口的增删改查之外,真正的难点还是在于对axum的理解运用上,在一开始的章节,我们就对axum进行了简单介绍。当然,axum的内容其实远不止这么点东西。后续如果有时间,我会继续探索axum的使用方法。