用Rust实现一个简单的KV Server: day2

436 阅读5分钟

开启掘金成长之旅!这是我参与「掘金日新计划 · 12 月更文挑战」的第11天,点击查看活动详情

在上一章,我们用tokio实现了客户端和服务器的基本框架并设置了json格式的配置文件。

现在,我们参考Redis的命令:GETSETPUBLISHSUBSCRIBE,使用Protobuf来实现客户端与服务器之间的通信协议层。为了处理Protobuf,我们加入了post库。同时加入了dotenv库用于读取环境变量、 tracing库用于日志处理和bytes处理字节数据。

现在Cargo.toml应该是这样的:

[package]
name = "kv_server"
version = "0.1.0"
edition = "2021"

[dependencies]
anyhow = { version = "1.0.65" }
tokio = { version = "1.21.2", features = ["full"] }
serde = { version = "1.0.144", features = ["derive"] }
serde_json = { version = "1.0.85" }
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
bytes = "1.2.1"
prost = "0.11"
dotenv = "0.15.0"

[build-dependencies]
prost-build = "0.11"

[[bin]]
name = "server"
path = "src/bin/kv_server.rs"

[[bin]]
name = "client"
path = "src/bin/kv_client.rs"

Protobuf

在项目根目录下新建cmd.proto,这里就比较简单了,定义好一些数据结构。加入如下代码:

syntax = "proto3";

package cmd;

// 命令请求
message CmdRequest {
    oneof req_data {
        Get get = 1;
        Set set = 2;
        Publish publish = 3;
        Subscribe subscribe = 4;
        Unsubscribe unsubscribe = 5;
    }
}

// 服务器的响应
message CmdResponse {
    uint32 status = 1;
    string message = 2;
    bytes value = 3;
}

// 请求值命令
message Get {
    string key = 1;
}

// 存储值命令
message Set {
    string key = 1;
    bytes value = 2;
    uint32 expire = 3;
}

// 向Topic发布值命令
message Publish {
    string topic = 1;
    bytes value = 2;
}

// 订阅Topic命令
message Subscribe {
    string topic = 1;
}

// 取消订阅命令
message Unsubscribe {
    string topic = 1;
    uint32 id = 2;
}

在src目录下创建pb目录,在根目录下创建build.rs文件,加入如下代码:

fn main() {
    let mut conf = prost_build::Config::new();
    conf.bytes(&["."]);
    conf.type_attribute(".", "#[derive(PartialOrd)]");
    conf.out_dir("src/pb")
        .compile_protos(&["cmd.proto"], &["."])
        .expect("生成proto失败");
}

这里必须先创建pb文件目录,要不然打包proto就会报错:无法找到(pb)文件路径创建了pb目录,然后打包项目就会自动生成cmd.rs文件:

生成的cmd.rs文件,内容应该是这样的,其实就是把上面定义的数据结构转换为rust的struct

/// 命令请求
#[derive(PartialOrd)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct CmdRequest {
    #[prost(oneof="cmd_request::ReqData", tags="1, 2, 3, 4, 5")]
    pub req_data: ::core::option::Option<cmd_request::ReqData>,
}
/// Nested message and enum types in `CmdRequest`.
pub mod cmd_request {
    #[derive(PartialOrd)]
    #[derive(Clone, PartialEq, ::prost::Oneof)]
    pub enum ReqData {
        #[prost(message, tag="1")]
        Get(super::Get),
        #[prost(message, tag="2")]
        Set(super::Set),
        #[prost(message, tag="3")]
        Publish(super::Publish),
        #[prost(message, tag="4")]
        Subscribe(super::Subscribe),
        #[prost(message, tag="5")]
        Unsubscribe(super::Unsubscribe),
    }
}
/// 服务器的响应
#[derive(PartialOrd)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct CmdResponse {
    #[prost(uint32, tag="1")]
    pub status: u32,
    #[prost(string, tag="2")]
    pub message: ::prost::alloc::string::String,
    #[prost(bytes="bytes", tag="3")]
    pub value: ::prost::bytes::Bytes,
}
/// 请求值命令
#[derive(PartialOrd)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Get {
    #[prost(string, tag="1")]
    pub key: ::prost::alloc::string::String,
}
/// 存储值命令
#[derive(PartialOrd)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Set {
    #[prost(string, tag="1")]
    pub key: ::prost::alloc::string::String,
    #[prost(bytes="bytes", tag="2")]
    pub value: ::prost::bytes::Bytes,
    #[prost(uint32, tag="3")]
    pub expire: u32,
}
/// 向Topic发布值命令
#[derive(PartialOrd)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Publish {
    #[prost(string, tag="1")]
    pub topic: ::prost::alloc::string::String,
    #[prost(bytes="bytes", tag="2")]
    pub value: ::prost::bytes::Bytes,
}
/// 订阅Topic命令
#[derive(PartialOrd)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Subscribe {
    #[prost(string, tag="1")]
    pub topic: ::prost::alloc::string::String,
}
/// 取消订阅命令
#[derive(PartialOrd)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Unsubscribe {
    #[prost(string, tag="1")]
    pub topic: ::prost::alloc::string::String,
    #[prost(uint32, tag="2")]
    pub id: u32,
}

接着在src/pb目录下创建mod.rs文件。将结构体引入并实现命令行方法:

use bytes::Bytes;
use std::error::Error;
pub mod cmd;

use cmd::{
    cmd_request::ReqData, CmdRequest, CmdResponse, Get, Publish, Set, Subscribe, Unsubscribe,
};

impl CmdRequest {
    // GET命令
    pub fn get(key: impl Into<String>) -> Self {
        Self {
            req_data: Some(ReqData::Get(Get { key: key.into() })),
        }
    }

    // SET命令
    pub fn set(key: impl Into<String>, value: Bytes, expire: u32) -> Self {
        Self {
            req_data: Some(ReqData::Set(Set {
                key: key.into(),
                value,
                expire,
            })),
        }
    }

    // PUBLISH命令
    pub fn publish(topic: impl Into<String>, value: Bytes) -> Self {
        Self {
            req_data: Some(ReqData::Publish(Publish {
                topic: topic.into(),
                value,
            })),
        }
    }

    // 订阅命令
    pub fn subscribe(topic: impl Into<String>) -> Self {
        Self {
            req_data: Some(ReqData::Subscribe(Subscribe {
                topic: topic.into(),
            })),
        }
    }

    // 解除订阅命令
    pub fn unsubscribe(topic: impl Into<String>, id: u32) -> Self {
        Self {
            req_data: Some(ReqData::Unsubscribe(Unsubscribe {
                topic: topic.into(),
                id,
            })),
        }
    }
}

impl CmdResponse {
    pub fn new(status: u32, message: String, value: Bytes) -> Self {
        Self {
            status,
            message,
            value,
        }
    }
}

工具类

接着在src目录下创建utils.rs文件,目前先增加日志初始化函数:

use tracing_subscriber::{fmt, layer::SubscriberExt, registry, util::SubscriberInitExt, EnvFilter};

//日志记录
pub fn start_tracing() {
    //获取日志等级环境变量
    let env_filter = EnvFilter::try_from_env("RUST_LOG").unwrap_or_else(|_| EnvFilter::new("info"));
    // 输出到控制台中
    let formatting_layer = fmt::layer().pretty().with_writer(std::io::stderr);
    // 注册
    registry().with(env_filter).with(formatting_layer).init();
}

同时,新增.env文件,增加环境变量:RUST_LOG=debug

在 src/lib.rs 中,引入pb模块和utils模块:

mod pb;
pub use pb::cmd::*;

mod utils;
pub use utils::*;

客户端 & 服务器

我们使用tokio-util库的Frame里的LengthDelimitedCodec(根据长度进行编解码)对protobuf协议进行封包解包。在Cargo.toml里加入tokio-util依赖:

[dependencies]
......
futures = "0.3.24"
tokio-util = { version = "0.7.4", features = ["codec"] }
......

修改src/bin/kv_server.rs代码如下,注释都比较清晰了,就不多赘述了:

use anyhow::Result;
use bytes::Bytes;
use futures::{SinkExt, StreamExt};
use kv_server::{start_tracing, CmdRequest, CmdResponse, ServerConfig};
use prost::Message;
use std::error::Error;
use tokio::net::TcpListener;
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use tracing::info;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    //初始化环境变量
    dotenv::dotenv().ok();
    //启动日志
    start_tracing();
    //读取配置
    let server_conf = ServerConfig::load("conf/server.json")?;
    let addr = server_conf.listen_address.address;
    //连接ip地址
    let listener = TcpListener::bind(&addr).await?;
    println!("Listening on {} ......", addr);
    //循环监听端口
    loop {
        let (stream, addr) = listener.accept().await?;
        println!("Client: {:?} connected", addr);
        //新建Tokio进程处理
        tokio::spawn(async move {
            // 使用Frame的LengthDelimitedCodec进行编解码操作
            let mut stream = Framed::new(stream, LengthDelimitedCodec::new());
            while let Some(Ok(mut buf)) = stream.next().await {
                // 对客户端发来的protobuf请求命令进行拆包
                let cmd_req = CmdRequest::decode(&buf[..]).unwrap();
                info!("Receive a command: {:?}", cmd_req);

                buf.clear();

                // 对protobuf的请求响应进行封包,然后发送给客户端。
                let cmd_res = CmdResponse::new(200, "success".to_string(), Bytes::default());

                cmd_res.encode(&mut buf).unwrap();
                stream.send(buf.freeze()).await.unwrap();
                info!("Client {:?} disconnected", addr);
            }
        });
    }
}

同样,修改src/bin/kv_client.rs代码如下:

use anyhow::Result;
use bytes::BytesMut;
use futures::{SinkExt, StreamExt};
use kv_server::{start_tracing, ClientConfig, CmdRequest, CmdResponse};
use prost::Message;
use std::error::Error;
use tokio::net::TcpStream;
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use tracing::info;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    //初始化环境变量
    dotenv::dotenv().ok();
    //启动日志
    start_tracing();
    //读取配置
    let client_config = ClientConfig::load("conf/client.json")?;
    let addr = client_config.client_address.server_addr;
    let stream = TcpStream::connect(addr).await?;

    // 使用Frame的LengthDelimitedCodec进行编解码操作
    let mut stream = Framed::new(stream, LengthDelimitedCodec::new());
    let mut buf = BytesMut::new();

    // 创建GET命令
    let cmd_get = CmdRequest::get("mykey");
    //对信息进行编码
    cmd_get.encode(&mut buf).unwrap();
    // 发送GET命令
    stream.send(buf.freeze()).await.unwrap();
    info!("Send info successed!");
    // 接收服务器返回的响应
    while let Some(Ok(buf)) = stream.next().await {
        //对响应数据解码
        let cmd_res = CmdResponse::decode(&buf[..]).unwrap();
        info!("Receive a response: {:?}", cmd_res);
    }
    Ok(())
}

运行结果

我们打开二个终端,分别输入以下命令:

cargo run --bin server
cargo run --bin client

服务器执行结果:

Listening on 127.0.0.1:3000 ......
Client: 127.0.0.1:54464 connected
  2022-10-07T13:38:00.486338Z  INFO server: Receive a command: CmdRequest { req_data: Some(Get(Get { key: "mykey" })) }
    at src/bin/kv_server.rs:34

  2022-10-07T13:38:00.486572Z  INFO server: Client 127.0.0.1:54464 disconnected
    at src/bin/kv_server.rs:43

客户端执行结果:

 2022-10-07T13:38:00.486151Z  INFO client: Send info successed!
    at src/bin/kv_client.rs:31

  2022-10-07T13:38:00.486592Z  INFO client: Receive a response: CmdResponse { status: 200, message: "success", value: b"" }
    at src/bin/kv_client.rs:35

服务器和客户端都正常处理了收到的请求和响应。