从零到一: 用Go语言搭建简易RPC框架并实践 (二)

134 阅读7分钟

在上一篇文章中,我们简单实现了一个网络通信服务。在本篇文章中,我们将在服务端实现服务注册处理客户端请求的功能。

前一篇传送门: 从零到一: 用Go语言搭建简易RPC框架并实践 (一)

项目结构

image.png

服务端代码改造

  • server.go

首先实现服务注册代码。假设我们实现了一个排序服务SortServer。这个服务实现了一个Sort方法。并且这个方法满足以下条件 (这里沿用了net/rpc源码中可以作为RPC方法的限制条件):

  1. the method’s type is exported. – 方法所属类型是导出的。
  2. the method is exported. – 方式是导出的。
  3. the method has two arguments, both exported (or builtin) types. – 两个入参,均为导出或内置类型。
  4. the method’s second argument is a pointer. – 第二个入参必须是一个指针。
  5. the method has return type error. – 返回值为 error 类型。

我们就认为这个Sort方法作为SortServer服务对外提供的方法

image.png

服务注册代码如下 image.png

image.png

image.png

首先定义了一个service结构体,用于记录SortServer的相关信息。name字段存储服务名,这里是"SortServer";rcvr记录SortServer结构体指针的反射值(reflect.Value)信息;typ记录SortServer结构体指针的反射类型(reflect.Type)信息;method记录SortServer可作为RPC的函数方法,比如这里的Sort方法。

method哈希表的Value定义为一个methodType结构体。里面定义了调用RPC方法需要的所有信息,包括方法本身method,方法的第一个参数的类型ArgType,方法的第二个参数类型ReplyType。之后利用反射提供的相关方法,可以很轻松的实现对应方法名的函数调用。

先前的Server结构体没有成员变量,本次我们为它新增一个serviceMap字段,用于注册服务。在(*Server).register中,定义了一个registerMethods方法。将服务名为s.name的所有PRC方法,注册到了server.serviceMap当中,从而实现了服务注册功能。

接着修改ServeConn代码如下。之前只是简单解析了客户端传给服务端的消息,现在注册的serviceMap当中,找到客户端想要调用的服务方法,调用函数,传入客户端的函数请求参数,得到函数的执行结果,组装成response写入连接当中,返回给客户端。

image.png

  • server_main.go

实现服务注册功能后,需要修改server_main.go文件,在处理连接请求前将定义的服务注册好,等待客户端调用相关方法。 image.png

完整源码

  • server.go
package gopher_rpc

import (
    "bufio"
    "encoding/json"
    "fmt"
    "io"
    "reflect"
    "strings"
    "sync"
)

type Server struct {
    serviceMap sync.Map // map[string]*service 注册服务时新增的字段
}

func NewServer() *Server {
    return &Server{}
}

var DefaultServer = NewServer()

type Request struct {
    Method string      `json:"method"`
    Args   interface{} `json:"args"`
}

func (server *Server) ServeConn(conn io.ReadWriteCloser) {
    defer conn.Close()
    reader := bufio.NewReader(conn)
    for {
       message, err := reader.ReadString('\n')
       if err != nil {
          if err == io.EOF {
             break
          }
          fmt.Printf("读取数据时出错: %v\n", err)
          return
       }

       var request Request
       err = json.Unmarshal([]byte(message), &request)
       if err != nil {
          fmt.Printf("反序列化 JSON 数据时出错: %v\n", err)
          return
       }
       serviceMethodName := request.Method
       dot := strings.LastIndexByte(serviceMethodName, '.')
       if dot < 0 {
          fmt.Printf("无效的方法名: %s\n", serviceMethodName)
       }
       serviceName := serviceMethodName[:dot]
       methodName := serviceMethodName[dot+1:]

       /* 旧代码把客户端发送的服务名,方法名,参数全部组装成消息返回回去
       // 约定好以\n作为消息结尾
       response := fmt.Sprintf("serviceName:%s, methodName:%s, request.Arg:%#v \n ", serviceName, methodName, request.Args)
       */
       svcInterface, ok := server.serviceMap.Load(serviceName)
       if !ok {
          fmt.Printf("未找到服务: %s\n", serviceName)
          return
       }
       svc, ok := svcInterface.(*service)
       if !ok {
          fmt.Printf("服务类型错误: %s\n", serviceName)
          return
       }
       mtype, ok := svc.method[methodName]
       if !ok {
          fmt.Printf("未找到方法: %s.%s\n", serviceName, methodName)
          return
       }
       // 根据方法的 ArgType 创建具体的 Args 实例
       argv := reflect.New(mtype.ArgType.Elem())

       // 将 request.Args 反序列化到 argv 中
       bs, err := json.Marshal(request.Args)
       if err != nil {
          fmt.Printf("序列化请求参数时出错: %v\n", err)
          return
       }
       err = json.Unmarshal(bs, argv.Interface())
       if err != nil {
          fmt.Printf("反序列化请求参数时出错: %v\n", err)
          return
       }

       replyv := reflect.New(mtype.ReplyType.Elem())
       results := mtype.method.Func.Call([]reflect.Value{svc.rcvr, argv, replyv})
       if len(results) > 0 && !results[0].IsNil() {
          fmt.Printf("调用方法出错: %v\n", results[0].Interface())
          return
       }
       reply := replyv.Elem().Interface()
       // 约定好以\n作为消息结尾
       response := fmt.Sprintf(" %v \n", reply)

       _, err = conn.Write([]byte(response))
       if err != nil {
          fmt.Printf("发送响应时出错: %v\n", err)
          return
       }
    }
}

func ServerConn(conn io.ReadWriteCloser) {
    DefaultServer.ServeConn(conn)
}

// =========服务注册代码=============//
type service struct {
    name   string                 // name of service
    rcvr   reflect.Value          // receiver of methods for the service
    typ    reflect.Type           // type of the receiver
    method map[string]*methodType // registered methods
}

type methodType struct {
    method    reflect.Method
    ArgType   reflect.Type
    ReplyType reflect.Type
}

func Register(rcvr any) {
    DefaultServer.register(rcvr)
}

func (server *Server) register(rcvr any) error {
    s := new(service)
    s.typ = reflect.TypeOf(rcvr)
    s.rcvr = reflect.ValueOf(rcvr)
    s.name = reflect.Indirect(s.rcvr).Type().Name()
    s.method = registerMethods(s.typ)
    server.serviceMap.Store(s.name, s)
    return nil
}

func registerMethods(typ reflect.Type) map[string]*methodType {
    methods := make(map[string]*methodType)
    // for循环遍历方法
    for i := 0; i < typ.NumMethod(); i++ {
       method := typ.Method(i)
       mtype := method.Type
       mname := method.Name
       /*
          对 net/rpc 而言,一个函数需要能够被远程调用,需要满足如下五个条件:
          1. the method’s type is exported. – 方法所属类型是导出的。
          2. the method is exported. – 方式是导出的。
          3. the method has two arguments, both exported (or builtin) types. – 两个入参,均为导出或内置类型。
          4. the method’s second argument is a pointer. – 第二个入参必须是一个指针。
          5. the method has return type error. – 返回值为 error 类型。
          此处沿用这个逻辑,只有满足这些条件的函数才可以被远程调用,记录到methods中
       */
       if method.IsExported() && mtype.NumIn() == 3 && mtype.NumOut() == 1 {
          argType := mtype.In(1) // 获取方法的第一个参数类型
          if argType.Kind() == reflect.Ptr {
             argType = argType.Elem() // 如果时指针类型,获取其指向的实际类型
          }
          if argType.Kind() != reflect.Struct {
             continue // 如果参数类型不是结构体,跳过该方法
          }

          methods[mname] = &methodType{
             method:    method,
             ArgType:   mtype.In(1), // 保持原始类型(可能是指针类型)
             ReplyType: mtype.In(2),
          }

       }
    }
    return methods
}
  • client.go(无修改)
package gopher_rpc

import (
    "bufio"
    "fmt"
    "net"
)

type Client struct {
    conn     net.Conn
    reader   *bufio.Reader
    done     chan struct{}
    errChan  chan error
    Response string
}

// Dail 用于建立与服务器的连接
func Dial(network, address string) (*Client, error) {
    conn, err := net.Dial(network, address)
    if err != nil {
       return nil, fmt.Errorf("无法连接到服务器 %s: %v", address, err)
    }
    client := &Client{
       conn:    conn,
       reader:  bufio.NewReader(conn),
       done:    make(chan struct{}), // 无缓冲channel
       errChan: make(chan error, 1),
    }
    // 开一个协程接受服务端返回
    go client.receive()
    return client, nil
}

func (c *Client) receive() {
    defer close(c.done)
    defer close(c.errChan)
    response, err := c.reader.ReadString('\n')
    if err != nil {
       c.errChan <- fmt.Errorf("接收服务器响应时出错:%v", err.Error())
       return
    }
    c.Response = response
    c.done <- struct{}{}
}

// Call 方法用于向服务端发送消息并等待响应(创建Client时开了个协程去接受服务端响应)
func (c *Client) Call(serviceMethod string) error {
    doneChan := c.Go(serviceMethod)
    select {
    case <-doneChan:
       return nil
    case err := <-c.errChan:
       return err
    }
}

func (c *Client) Go(serviceMethod string) chan struct{} {
    _, err := c.conn.Write([]byte(serviceMethod + "\n"))
    if err != nil {
       c.errChan <- fmt.Errorf("发送消息时出错: %v", err)
       close(c.done)
       close(c.errChan)
       return nil
    }
    return c.done
}

// Close 方法用于关闭客户端连接
func (c *Client) Close() error {
    if c.conn != nil {
       return c.conn.Close()
    }
    return nil
}

服务调用示例

示例一

定义运算服务ArithServer,并实现Add方法。

  • server_main.go
package main

import (
    "fmt"
    "gopher_rpc"
    "net"
)

type ArithServer struct{}

type Args struct {
    Num1 int `json:"num_1"`
    Num2 int `json:"num_2"`
}

func (s *ArithServer) Add(args *Args, reply *int) error {
    *reply = args.Num1 + args.Num2
    return nil
}

func main() {
    gopher_rpc.Register(new(ArithServer)) // 注册服务

    listener, err := net.Listen("tcp", ":8888")
    if err != nil {
       fmt.Printf("监听端口时出错:%v\n", err)
       return
    }
    defer listener.Close()
    fmt.Println("ArithServer is listening on port 8888....")
    for {
       conn, err := listener.Accept()
       if err != nil {
          fmt.Printf("接收连接时出错:%v\n", err)
          continue
       }
       // 开一个go协程,异步处理连接请求
       go gopher_rpc.ServerConn(conn)
    }
}
  • client_main.go
package main

import (
    "encoding/json"
    "fmt"
    "gopher_rpc"
)

type ServiceMethod struct {
    Method string `json:"method"`
    Args   Args   `json:"args"`
}

type Args struct {
    Num1 int `json:"num_1"`
    Num2 int `json:"num_2"`
}

func main() {
    // 连接到服务端
    client, err := gopher_rpc.Dial("tcp", "127.0.0.1:8888")
    if err != nil {
       fmt.Println(err)
       return
    }
    defer client.Close()

    param := &ServiceMethod{
       Method: "ArithServer.Add", // 调用ArithServer服务的Add方法
       Args: Args{
          Num1: 10,
          Num2: 20,
       },
    }
    bs, _ := json.Marshal(param)
    if err = client.Call(string(bs)); err != nil {
       fmt.Println(err)
       return
    }
    fmt.Printf("ArithServer.Add:  %d + %d = %s", param.Args.Num1, param.Args.Num2, client.Response)
}

服务端运行

image.png

客户端发起RPC调用,并打印执行结果

image.png

示例二

定义排序服务SortServer,并实现Sort方法。

  • server_main.go
package main

import (
    "fmt"
    "gopher_rpc"
    "net"
    "sort"
)

type SortServer struct{}

type Args struct {
    Nums []int `json:"nums"`
}

func (s *SortServer) Sort(args *Args, reply *[]int) error {
    sort.Ints(args.Nums)
    *reply = args.Nums
    return nil
}

func main() {
    gopher_rpc.Register(new(SortServer)) // 注册服务

    listener, err := net.Listen("tcp", ":8888")
    if err != nil {
       fmt.Printf("监听端口时出错:%v\n", err)
       return
    }
    defer listener.Close()
    fmt.Println("SortServer is listening on port 8888....")
    for {
       conn, err := listener.Accept()
       if err != nil {
          fmt.Printf("接收连接时出错:%v\n", err)
          continue
       }
       // 开一个go协程,异步处理连接请求
       go gopher_rpc.ServerConn(conn)
    }
}
  • client_main.go
package main

import (
    "encoding/json"
    "fmt"
    "gopher_rpc"
)

type ServiceMethod struct {
    Method string `json:"method"`
    Args   Args   `json:"args"`
}

type Args struct {
    Nums []int `json:"nums"`
}

func main() {
    // 连接到服务端
    client, err := gopher_rpc.Dial("tcp", "127.0.0.1:8888")
    if err != nil {
       fmt.Println(err)
       return
    }
    defer client.Close()

    param := &ServiceMethod{
       Method: "SortServer.Sort", // 调用SortServer服务的Sort方法
       Args: Args{
          Nums: []int{1, 5, 6, 4, 2},
       },
    }
    bs, _ := json.Marshal(param)
    if err = client.Call(string(bs)); err != nil {
       fmt.Println(err)
       return
    }
    fmt.Printf("排序后的结果: %s", client.Response)
}

服务端运行

image.png

客户端发起RPC调用,并打印执行结果

image.png

总结

通过这个系列,我们使用了200多行Go代码,实现了一个简易的RPC框架。当然了,这个简易的RPC框架,从严格意义上来说并不能称作为框架,其实现的服务功能非常脆弱,许多细节实现也比较简单粗暴。当然博主的本意也只是想去除net/rpc源码的细节,发掘RPC的本质,把最原始的RPC服务功能展现给大家。

希望这个系列文章,可以帮助到屏幕前的你更好地理解RPC是怎么工作的。如果对你有所帮助,欢迎收藏+关注,你的支持是我创作的最大动力!