简单RPC实践 | 青训营笔记

132 阅读4分钟

这是我参与「第五届青训营 」笔记创作活动的第12天

简单RPC实践

之前写过一篇RPC框架的学习笔记,这篇文章不再赘述概念,直接来看一个简单RPC的实现代码

源码

网络通信层

简单使用net包提供的tcp建立网络通信、 服务端:

给服务端规定一个地址端口

	addr := "localhost:3212"
	srv := server.NewServer(addr)

在初始化服务端的时候监听

	l, err := net.Listen("tcp", s.addr)
	if err != nil {
		log.Printf("listen on %s err: %v\n", s.addr, err)
		return
	}

客户端:

由于是简单rpc,客户端直接tcp连接,并在之后用这个连接与服务端通信


	// start client
	conn, err := net.Dial("tcp", addr)
	if err != nil {
		panic(err)
	}
	cli := client.NewClient(conn)

协议层

这里采用变长协议:以定长加不定长的部分组成,其中定长的部分需要描述不定长的内容长度

定义定长部分:

const headerLen = 4

写入:

	buf := make([]byte, headerLen+len(data))
	binary.BigEndian.PutUint32(buf[:headerLen], uint32(len(data)))
	copy(buf[headerLen:], data)

写入后调用网络通信层传输:

	_, err := t.conn.Write(buf)
	if err != nil {
		return err
	}

网络通信层接收数据:

	header := make([]byte, headerLen)
	_, err := io.ReadFull(t.conn, header)
	if err != nil {
		return nil, err
	}

读取:

	dataLen := binary.BigEndian.Uint32(header)
	data := make([]byte, dataLen)
	_, err = io.ReadFull(t.conn, data)
	if err != nil {
		return nil, err
	}

解编码层

这层主要是将真正的数据与二进制编码进行转换

真正的数据即我们定义的RPC数据包

// RPCdata represents the serializing format of structured data
type RPCdata struct {
	Name string        // name of the function
	Args []interface{} // request's or response's body expect error.
	Err  string        // Error any executing remote server
}

编码:

// Encode The RPCdata in binary format which can
// be sent over the network.
func Encode(data RPCdata) ([]byte, error) {
	var buf bytes.Buffer
	encoder := gob.NewEncoder(&buf)
	if err := encoder.Encode(data); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

解码:

// Encode The RPCdata in binary format which can
// be sent over the network.
func Encode(data RPCdata) ([]byte, error) {
	var buf bytes.Buffer
	encoder := gob.NewEncoder(&buf)
	if err := encoder.Encode(data); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

服务端

属性

服务端主要包含运行的地址端口以及其可被调用的方法

// RPCServer ...
type RPCServer struct {
	addr  string
	funcs map[string]reflect.Value
}

注册

我们要把这样一个方法注册进服务端

func QueryUser(id int) (User, error)

...

srv.Register("QueryUser", QueryUser)

也就是存入反射的值存入服务端方法字典

func (s *RPCServer) Register(fnName string, fFunc interface{}) {
	if _, ok := s.funcs[fnName]; ok {
		return
	}

	s.funcs[fnName] = reflect.ValueOf(fFunc)
}

运行

主要是死循环监听客户端发来的数据,如果得到数据,经过层层解码,得到我们定义的RPC数据包后调用执行方法,再把返回值层层编码发给客户端

	for {
		conn, err := l.Accept()
		if err != nil {
			log.Printf("accept err: %v\n", err)
			continue
		}
		go func() {
			connTransport := transport.NewTransport(conn)
			for {
				// read request
				req, err := connTransport.Read()
				if err != nil {
					if err != io.EOF {
						log.Printf("read err: %v\n", err)
						return
					}
				}

				// decode the data and pass it to execute
				decReq, err := dataserial.Decode(req)
				if err != nil {
					log.Printf("Error Decoding the Payload err: %v\n", err)
					return
				}
				// get the executed result.
				resP := s.Execute(decReq)
				// encode the data back
				b, err := dataserial.Encode(resP)
				if err != nil {
					log.Printf("Error Encoding the Payload for response err: %v\n", err)
					return
				}
				// send response to client
				err = connTransport.Send(b)
				if err != nil {
					log.Printf("transport write err: %v\n", err)
				}
			}
		}()
	}

执行

根据RPC数据包中的方法名,从字典中取出该方法反射的值,在调用反射的Call函数输入RPC数据包的参数,再把得到的结果和异常打包为数据包返回

// Execute the given function if present
func (s *RPCServer) Execute(req dataserial.RPCdata) dataserial.RPCdata {
	// get method by name
	f, ok := s.funcs[req.Name]
	if !ok {
		// since method is not present
		e := fmt.Sprintf("func %s not Registered", req.Name)
		log.Println(e)
		return dataserial.RPCdata{Name: req.Name, Args: nil, Err: e}
	}

	log.Printf("func %s is called\n", req.Name)
	// unpack request arguments
	inArgs := make([]reflect.Value, len(req.Args))
	for i := range req.Args {
		inArgs[i] = reflect.ValueOf(req.Args[i])
	}

	// invoke requested method
	out := f.Call(inArgs)
	// now since we have followed the function signature style where last argument will be an error
	// so we will pack the response arguments expect error.
	resArgs := make([]interface{}, len(out)-1)
	for i := 0; i < len(out)-1; i++ {
		// Interface returns the constant value stored in v as an interface{}.
		resArgs[i] = out[i].Interface()
	}

	// pack error argument
	var er string
	if _, ok := out[len(out)-1].Interface().(error); ok {
		// convert the error into error string value
		er = out[len(out)-1].Interface().(error).Error()
	}
	return dataserial.RPCdata{Name: req.Name, Args: resArgs, Err: er}
}

客户端

属性

主要保存着与服务端的连接

// Client struct
type Client struct {
	conn net.Conn
}

调用

把函数指针作为参数传入,之后调用这个函数即相当于远程调用,并且看起来和本地调用一样一样的

	var Query func(int) (User, error)
	cli.CallRPC("QueryUser", &Query)
    
    u, err := Query(1)
	if err != nil {
		panic(err)
	}

在CallRPC中,是将真正调用的f函数赋给了传入的函数指针

	container := reflect.ValueOf(fPtr).Elem()
    f := func(req []reflect.Value) []reflect.Value{
        ...
    }
    container.Set(reflect.MakeFunc(container.Type(), f))

在真正调用的f函数中,主要就是将接收的参数层层打包,发给服务端,再把接收到的回应层层解开,返回参数

	f := func(req []reflect.Value) []reflect.Value {
		cReqTransport := transport.NewTransport(c.conn)
		errorHandler := func(err error) []reflect.Value {
			outArgs := make([]reflect.Value, container.Type().NumOut())
			for i := 0; i < len(outArgs)-1; i++ {
				outArgs[i] = reflect.Zero(container.Type().Out(i))
			}
			outArgs[len(outArgs)-1] = reflect.ValueOf(&err).Elem()
			return outArgs
		}

		// Process input parameters
		inArgs := make([]interface{}, 0, len(req))
		for _, arg := range req {
			inArgs = append(inArgs, arg.Interface())
		}

		// ReqRPC
		reqRPC := dataserial.RPCdata{Name: rpcName, Args: inArgs}
		b, err := dataserial.Encode(reqRPC)
		if err != nil {
			panic(err)
		}
		err = cReqTransport.Send(b)
		if err != nil {
			return errorHandler(err)
		}
		// receive response from server
		rsp, err := cReqTransport.Read()
		if err != nil { // local network error or decode error
			return errorHandler(err)
		}
		rspDecode, _ := dataserial.Decode(rsp)
		if rspDecode.Err != "" { // remote server error
			return errorHandler(errors.New(rspDecode.Err))
		}

		if len(rspDecode.Args) == 0 {
			rspDecode.Args = make([]interface{}, container.Type().NumOut())
		}
		// unpack response arguments
		numOut := container.Type().NumOut()
		outArgs := make([]reflect.Value, numOut)
		for i := 0; i < numOut; i++ {
			if i != numOut-1 { // unpack arguments (except error)
				if rspDecode.Args[i] == nil { // if argument is nil (gob will ignore "Zero" in transmission), set "Zero" value
					outArgs[i] = reflect.Zero(container.Type().Out(i))
				} else {
					outArgs[i] = reflect.ValueOf(rspDecode.Args[i])
				}
			} else { // unpack error argument
				outArgs[i] = reflect.Zero(container.Type().Out(i))
			}
		}

		return outArgs
	}