用GO写一个RPC框架 s04 (编写服务端核心)

616 阅读3分钟

前言

通过上两篇的学习 我们已经了解了 服务端本地服务的注册, 服务端配置,协议 现在我们开始写服务端的核心逻辑

github.com/dollarkille…

默认配置

我们先看下默认的配置

func defaultOptions() *Options {
	return &Options{
		Protocol:     transport.TCP, // default TCP
		Uri:          "0.0.0.0:8397",
		UseHttp:      false,
		readTimeout:  time.Minute * 3, // 心跳包 默认 3min
		writeTimeout: time.Second * 30,
		ctx:          context.Background(), // ctx 是控制服务退出的
		options: map[string]interface{}{
			"TCPKeepAlivePeriod": time.Minute * 3,
		},
		processChanSize: 1000,    
		Trace:           false,
		RSAPublicKey: []byte(`-----BEGIN PUBLIC KEY-----
-----END PUBLIC KEY-----`),
		RSAPrivateKey: []byte(`-----BEGIN RSA PRIVATE KEY-----
-----END RSA PRIVATE KEY-----`),
		Discovery: &discovery.SimplePeerToPeer{},
	}
}

run

服务注册完毕之后 调用Run方法 启动服务

func (s *Server) Run(options ...Option) error {
        // 初始化 服务端配置
	for _, fn := range options {
		fn(s.options)
	}

	var err error
        // 更具配置传入的protocol 获取到 网络插件 (KCP UDP TCP) 我们等下细讲
	s.options.nl, err = transport.Transport.Gen(s.options.Protocol, s.options.Uri)
	if err != nil {
		return err
	}

	log.Printf("LightRPC: %s  %s \n", s.options.Protocol, s.options.Uri)

        // 这里是服务注册 我们这里先跳过  
	if s.options.Discovery != nil {
                // 读取服务配置文件
		sIdb, err := ioutil.ReadFile("./light.conf")
		if err != nil {
                        // 如果没有 就生成 分布式ID
			id, err := utils.DistributedID()
			if err != nil {
				return err
			}
			sIdb = []byte(id)
		}
		// 进行服务注册
		sId := string(sIdb)
		for k := range s.serviceMap {   // 进行服务注册 
			err := s.options.Discovery.Registry(k, s.options.registryAddr, s.options.weights, s.options.Protocol, s.options.MaximumLoad, &sId)
			if err != nil {
				return err
			}
			log.Printf("Discovery Registry: %s addr: %s SUCCESS", k, s.options.registryAddr)
		}

		ioutil.WriteFile("./light.conf", sIdb, 00666)
	}
        
        // 启动服务
	return s.run()
}



func (s *Server) run() error {
loop:
	for {
		select {
		case <-s.options.ctx.Done():  // 检查是否需要退出服务
			break loop
		default:
			accept, err := s.options.nl.Accept() // 获取一个链接
			if err != nil {
				log.Println(err)
				continue
			}
			if s.options.Trace {
				log.Println("connect: ", accept.RemoteAddr())
			}

			go s.process(accept) // 开一个协程去处理 该 链接
		}

	}

	return nil
}

我们先回顾一下 上章讲的 握手逻辑

  1. 建立链接 通过非对称加密 传输 aes 密钥给服务端 (携带token)
  2. 服务端 验证 token 并记录 aes 密钥 后面与客户端交互 都采用对称加密

具体处理 链接 process (重点!!!)

func (s *Server) process(conn net.Conn) {

	defer func() {
		// 网络不可靠
		if err := recover(); err != nil {
			utils.PrintStack()
			log.Println("Recover Err: ", err)
		}
	}()

        // 每进来一个请求这里就ADD
	s.options.Discovery.Add(1)
	defer func() {
		s.options.Discovery.Less(1) // 处理完 请求就退出
		// 退出 回收句柄
		err := conn.Close()  
		if err != nil {
			log.Println(err)
			return
		}

		if s.options.Trace {
			log.Println("close connect: ", conn.RemoteAddr())
		}
	}()

        // 这里定义一个xChannel 用于分离 请求和返回
	xChannel := utils.NewXChannel(s.options.processChanSize)

	// 握手
	handshake := protocol.Handshake{}
	err := handshake.Handshake(conn)
	if err != nil {
		return
	}
            
        // 非对称加密  解密 AES KEY
	aesKey, err := cryptology.RsaDecrypt(handshake.Key, s.options.RSAPrivateKey)
	if err != nil {
		encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(err.Error()))
		conn.Write(encodeHandshake)
		return
	}

        // 检测 AES KEY 是否正确
	if len(aesKey) != 32 && len(aesKey) != 16 {
		encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte("aes key != 32 && key != 16"))
		conn.Write(encodeHandshake)
		return
	}
        
        // 解密 TOKEN
	token, err := cryptology.RsaDecrypt(handshake.Token, s.options.RSAPrivateKey)
	if err != nil {
		encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(err.Error()))
		conn.Write(encodeHandshake)
		return
	}
        // 对TOKEN进行校验  
	if s.options.AuthFunc != nil {
		err := s.options.AuthFunc(light.DefaultCtx(), string(token))
		if err != nil {
			encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(err.Error()))
			conn.Write(encodeHandshake)
			return
		}
	}

	// limit 限流
	if s.options.Discovery.Limit() {
		// 熔断
		encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(pkg.ErrCircuitBreaker.Error()))
		conn.Write(encodeHandshake)
		log.Println(s.options.Discovery.Limit())
		return
	}
        
        // 如果握手没有问题 则返回握手成功
	encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(""))
	_, err = conn.Write(encodeHandshake)
	if err != nil {
		return
	}
        
	// send
	go func() {
	loop:
		for {
			select {
                        // 这就是刚刚的xChannel 对读写进行分离
			case msg, ex := <-xChannel.Ch: 
				if !ex {
					if s.options.Trace {
						log.Printf("ip: %s  close send server", conn.RemoteAddr())
					}
					break loop
				}
				now := time.Now()
				if s.options.writeTimeout > 0 {
					conn.SetWriteDeadline(now.Add(s.options.writeTimeout))
				}
				// send message
				_, err := conn.Write(msg)
				if err != nil {
					if s.options.Trace {
						log.Printf("ip: %s err: %s", conn.RemoteAddr(), err)
					}
					break loop
				}
			}
		}
	}()

	defer func() {
		xChannel.Close()
	}()
loop:
	for { // 具体消息获取
		now := time.Now()
		if s.options.readTimeout > 0 {
			conn.SetReadDeadline(now.Add(s.options.readTimeout))
		}

		proto := protocol.NewProtocol()
		msg, err := proto.IODecode(conn) // 获取一个消息
		if err != nil {
			if err == io.EOF {
				if s.options.Trace {
					log.Printf("ip: %s close", conn.RemoteAddr())
				}
				break loop
			}

			// 遇到错误关闭链接
			if s.options.Trace {
				log.Printf("ip: %s err: %s", conn.RemoteAddr(), err)
			}
			break loop
		}

		go s.processResponse(xChannel, msg, conn.RemoteAddr().String(), aesKey)
	}
}

具体处理 (重点!!!)

注意此RPC传输消息都是编码过的 要进行转码

  • 第一层 为压缩编码
  • 第二层 为加密编码
  • 第三层 为序列化
func (s *Server) processResponse(xChannel *utils.XChannel, msg *protocol.Message, addr string, aesKey []byte) {
	var err error
	s.options.Discovery.Add(1)
	defer func() {
		s.options.Discovery.Less(1)
		if err != nil {
			if s.options.Trace {
				log.Println("ProcessResponse Error: ", err, "  ID: ", addr)
			}
			xChannel.Close()
		}
	}()

	// heartBeat 判断
	if msg.Header.RespType == byte(protocol.HeartBeat) {
		// 心跳返回
		if s.options.Trace {
			log.Println("HeartBeat: ", addr)
		}

		// 4. 打包
		_, message, err := protocol.EncodeMessage(msg.MagicNumber, []byte(msg.ServiceName), []byte(msg.ServiceMethod), []byte(""), byte(protocol.HeartBeat), msg.Header.CompressorType, msg.Header.SerializationType, []byte(""))
		if err != nil {
			return
		}
		// 5. 回写
		err = xChannel.Send(message)
		if err != nil {
			return
		}

		return
	}

	// 限流
	if s.options.Discovery.Limit() {
		serialization, _ := codes.SerializationManager.Get(codes.MsgPack)
		metaData := make(map[string]string)
		metaData["RespError"] = pkg.ErrCircuitBreaker.Error()
		meta, err := serialization.Encode(metaData)
		if err != nil {
			return
		}
		decrypt, err := cryptology.AESDecrypt(aesKey, meta)
		if err != nil {
			return
		}
		_, message, err := protocol.EncodeMessage(msg.MagicNumber, []byte(msg.ServiceName), []byte(msg.ServiceMethod), decrypt, byte(protocol.Response), byte(codes.RawData), byte(codes.MsgPack), []byte(""))
		if err != nil {
			return
		}
		// 5. 回写
		err = xChannel.Send(message)
		if err != nil {
			return
		}

		log.Println(s.options.Discovery.Limit())
		log.Println("限流/////////////")

		return
	}

	// 1. 解压缩
	compressor, ex := codes.CompressorManager.Get(codes.CompressorType(msg.Header.CompressorType))
	if !ex {
		err = errors.New("compressor 404")
		return
	}
	msg.MetaData, err = compressor.Unzip(msg.MetaData)
	if err != nil {
		return
	}

	msg.Payload, err = compressor.Unzip(msg.Payload)
	if err != nil {
		return
	}
	// 2. 解密
	msg.MetaData, err = cryptology.AESDecrypt(aesKey, msg.MetaData)
	if err != nil {
		return
	}

	msg.Payload, err = cryptology.AESDecrypt(aesKey, msg.Payload)
	if err != nil {
		return
	}

	// 3. 反序列化
	serialization, ex := codes.SerializationManager.Get(codes.SerializationType(msg.Header.SerializationType))
	if !ex {
		err = errors.New("serialization 404")
		return
	}

	metaData := make(map[string]string)
	err = serialization.Decode(msg.MetaData, &metaData)
	if err != nil {
		return
	}

        // 初始化context
	ctx := light.DefaultCtx()
	ctx.SetMetaData(metaData)

	// 1.3 auth
	if s.options.AuthFunc != nil {
		auth := metaData["Light_AUTH"]
		err := s.options.AuthFunc(ctx, auth)
		if err != nil {
			ctx.SetValue("RespError", err.Error())
			var metaDataByte []byte
			metaDataByte, _ = serialization.Encode(ctx.GetMetaData())
			metaDataByte, _ = cryptology.AESEncrypt(aesKey, metaDataByte)
			metaDataByte, _ = compressor.Zip(metaDataByte)
			// 4. 打包
			_, message, err := protocol.EncodeMessage(msg.MagicNumber, []byte(msg.ServiceName), []byte(msg.ServiceMethod), metaDataByte, byte(protocol.Response), msg.Header.CompressorType, msg.Header.SerializationType, []byte(""))
			if err != nil {
				return
			}
			// 5. 回写
			err = xChannel.Send(message)
			if err != nil {
				return
			}
			return
		}
	}

        // 找到具体调用的服务
	ser, ex := s.serviceMap[msg.ServiceName]
	if !ex {
		err = errors.New("service does not exist")
		return
	}

        // 找到具体调用的方法
	method, ex := ser.methodType[msg.ServiceMethod]
	if !ex {
		err = errors.New("method does not exist")
		return
	}

        // 初始化 req, resp
	req := utils.RefNew(method.RequestType)
	resp := utils.RefNew(method.ResponseType)

	err = serialization.Decode(msg.Payload, req)
	if err != nil {
		return
	}

        // 定义ctx paht 为   服务名称.服务方法
	path := fmt.Sprintf("%s.%s", msg.ServiceName, msg.ServiceMethod)
	ctx.SetPath(path)

	// 前置middleware
	if len(s.beforeMiddleware) != 0 {
		for idx := range s.beforeMiddleware {
			err := s.beforeMiddleware[idx](ctx, req, resp)
			if err != nil {
				return
			}
		}
	}
	funcs, ex := s.beforeMiddlewarePath[path]
	if ex {
		if len(funcs) != 0 {
			for idx := range funcs {
				err := funcs[idx](ctx, req, resp)
				if err != nil {
					return
				}
			}
		}
	}

	// 核心调用
	callErr := ser.call(ctx, method, reflect.ValueOf(req), reflect.ValueOf(resp))
	if callErr != nil {
		ctx.SetValue("RespError", callErr.Error())
	}

	// 后置middleware
	if len(s.afterMiddleware) != 0 {
		for idx := range s.afterMiddleware {
			err := s.afterMiddleware[idx](ctx, req, resp)
			if err != nil {
				return
			}
		}
	}
	funcs, ex = s.afterMiddlewarePath[path]
	if ex {
		if len(funcs) != 0 {
			for idx := range funcs {
				err := funcs[idx](ctx, req, resp)
				if err != nil {
					return
				}
			}
		}
	}
	// response

	// 1. 序列化
	var respBody []byte
	respBody, err = serialization.Encode(resp)

	var metaDataByte []byte
	metaDataByte, _ = serialization.Encode(ctx.GetMetaData())
	// 2. 加密
	metaDataByte, err = cryptology.AESEncrypt(aesKey, metaDataByte)
	if err != nil {
		return
	}
	respBody, err = cryptology.AESEncrypt(aesKey, respBody)
	if err != nil {
		return
	}
	// 3. 压缩
	metaDataByte, err = compressor.Zip(metaDataByte)
	if err != nil {
		return
	}
	respBody, err = compressor.Zip(respBody)
	if err != nil {
		return
	}
	// 4. 打包
	_, message, err := protocol.EncodeMessage(msg.MagicNumber, []byte(msg.ServiceName), []byte(msg.ServiceMethod), metaDataByte, byte(protocol.Response), msg.Header.CompressorType, msg.Header.SerializationType, respBody)
	if err != nil {
		return
	}
	// 5. 回写
	err = xChannel.Send(message)
	if err != nil {
		return
	}
}

调用具体方法

func (s *service) call(ctx *light.Context, mType *methodType, request, response reflect.Value) (err error) {
	// recover 捕获堆栈消息
	defer func() {
		if r := recover(); r != nil {
			buf := make([]byte, 4096)
			n := runtime.Stack(buf, false)
			buf = buf[:n]

			err = fmt.Errorf("[painc service internal error]: %v, method: %s, argv: %+v, stack: %s",
				r, mType.method.Name, request.Interface(), buf)
			log.Println(err)
		}
	}()

	fn := mType.method.Func
	returnValue := fn.Call([]reflect.Value{s.refVal, reflect.ValueOf(ctx), request, response})
	errInterface := returnValue[0].Interface()
	if errInterface != nil {
		return errInterface.(error)
	}

	return nil
}

这里就完成了服务端的基础逻辑了

专栏: juejin.cn/column/6986…