用GO写一个RPC框架 s05 (客户端编写)

804 阅读1分钟

前言

前面几章我们完成了服务端的编写，现在开始编写客户端。

github.com/dollarkille…

Client

// Client is the RPC client entry point. It only holds the client-wide
// configuration; connections to a particular service are created with
// NewConnect.
type Client struct {
	options *Options
}

// NewClient builds a Client that resolves services through discover.
//
// The configuration starts from defaultOptions with Discovery set to
// discover; the supplied options are then applied in order, so any of them
// may override a default (including Discovery itself).
func NewClient(discover discovery.Discovery, options ...Option) *Client {
	opts := defaultOptions()
	opts.Discovery = discover

	for _, apply := range options {
		apply(opts)
	}

	return &Client{options: opts}
}

Option 配置项

// Options holds the client-wide configuration shared by every Connect and
// every pooled connection created from a Client.
type Options struct {
	Discovery         discovery.Discovery                 // service discovery plugin
	loadBalancing     load_banlancing.LoadBalancing       // load-balancing plugin
	serializationType codes.SerializationType             // serialization codec
	compressorType    codes.CompressorType                // compression codec

	pool         int                                      // connection pool size
	cryptology   cryptology.Cryptology // NOTE(review): not read by the code shown here; AES helpers are called directly
	rsaPublicKey []byte // server RSA public key; encrypts the AUTH token and AES key during the handshake
	writeTimeout time.Duration // deadline applied before writing requests and heartbeats
	readTimeout  time.Duration // read deadline set before reading each incoming frame
	heartBeat    time.Duration // wait between heartbeat frames
	Trace        bool // enables extra log output
	AUTH         string                                   // AUTH TOKEN
}

// defaultOptions returns the baseline client configuration: a pool of
// 4*NumCPU connections (never fewer than 20), MsgPack serialization, Snappy
// compression, the polling load balancer, AES, a 1-minute write timeout and
// heartbeat interval, and a 3-minute read timeout.
func defaultOptions() *Options {
	poolSize := 20
	if perCPU := runtime.NumCPU() * 4; perCPU > poolSize {
		poolSize = perCPU
	}

	opts := &Options{
		pool:              poolSize,
		serializationType: codes.MsgPack,
		compressorType:    codes.Snappy,
		loadBalancing:     load_banlancing.NewPolling(),
		cryptology:        cryptology.AES,
		writeTimeout:      time.Minute,
		readTimeout:       3 * time.Minute,
		heartBeat:         time.Minute,
		Trace:             false,
		AUTH:              "",
	}

	// NOTE(review): placeholder PEM with an empty body; a real server public
	// key presumably has to be supplied through an Option.
	opts.rsaPublicKey = []byte(`
-----BEGIN PUBLIC KEY-----
-----END PUBLIC KEY-----`)

	return opts
}

具体的每个连接（Connect）

// Connect is a handle to one remote service (serverName), backed by a pool of
// connections built from the owning Client's options.
type Connect struct {
	Client     *Client
	pool       *connectPool
	close      chan struct{} // shutdown signal; NOTE(review): nothing visible here ever closes it
	serverName string
}

// NewConnect creates a Connect for the service named serverName and fills its
// connection pool, which resolves the service through the client's Discovery
// and dials the configured number of connections up front.
//
// On failure it returns a nil *Connect together with the error, so callers
// never receive a handle whose pool is missing or only partly filled.
func (c *Client) NewConnect(serverName string) (conn *Connect, err error) {
	connect := &Connect{
		Client:     c,
		serverName: serverName,
		close:      make(chan struct{}),
	}

	pool, err := initPool(connect)
	if err != nil {
		return nil, err
	}
	connect.pool = pool

	return connect, nil
}

初始化连接池

// initPool allocates the connection pool for c, buffered to the configured
// pool size, and fills it. The pool is returned together with any error
// produced while filling it.
func initPool(c *Connect) (*connectPool, error) {
	size := c.Client.options.pool
	cp := &connectPool{connect: c, pool: make(chan LightClient, size)}

	err := cp.initPool()
	return cp, err
}

// initPool resolves the service's hosts through service discovery, seeds the
// load-balancing plugin with them, and dials options.pool connections into
// the pool.
//
// If any dial fails, the connections created so far are closed, so their
// sockets and background goroutines do not leak. The error is then returned.
func (c *connectPool) initPool() error {
	opts := c.connect.Client.options
	serverName := c.connect.serverName

	// Ask service discovery which hosts serve this service.
	hosts, err := opts.Discovery.Discovery(serverName)
	if err != nil {
		return err
	}
	if len(hosts) == 0 {
		return errors.New(fmt.Sprintf("%s server 404", serverName))
	}

	// Seed the load-balancing plugin with the discovered hosts.
	opts.loadBalancing.InitBalancing(hosts)

	// Fill the pool, remembering what was created so it can be torn down.
	var created []*BaseClient
	for i := 0; i < opts.pool; i++ {
		client, err := newBaseClient(serverName, opts)
		if err != nil {
			// Closing the socket makes each client's reader goroutine fail and
			// close its stop channel, which also ends its heartbeat goroutine.
			for _, bc := range created {
				bc.conn.Close()
			}
			return errors.WithStack(err)
		}
		created = append(created, client)
		c.pool <- client
	}

	return nil
}

// Get takes a connection out of the pool. It blocks until one is free, or
// until ctx is done, in which case it returns a "pool get timeout" error.
func (c *connectPool) Get(ctx context.Context) (LightClient, error) {
	select {
	case client := <-c.pool:
		return client, nil
	case <-ctx.Done():
		return nil, errors.New("pool get timeout")
	}
}

// Put returns client to the pool once the caller is done with it.
//
// A healthy client (Error() == nil) goes straight back into the pool. A client
// that reports an error is discarded instead:
//   - its socket is closed, when it is a *BaseClient;
//   - a background goroutine retries once per second, re-running service
//     discovery, re-seeding the load balancer and dialing a new connection,
//     until a replacement can be put into the pool.
//
// The retry goroutine also exits if the owning Connect's close channel is
// closed, so it cannot outlive a shut-down Connect.
func (c *connectPool) Put(client LightClient) {
	if client.Error() == nil {
		c.pool <- client
		return
	}

	// The broken connection is never reused, so release its socket. If its
	// reader goroutine is still running, the read fails and it closes b.close,
	// which in turn stops the heartbeat goroutine.
	if bc, ok := client.(*BaseClient); ok {
		bc.conn.Close()
	}

	// Replace the broken client with a new connection in the background.
	go func() {
		fmt.Println("The server starts to restore")
		opts := c.connect.Client.options
		serverName := c.connect.serverName

		for {
			select {
			case <-c.connect.close:
				return
			case <-time.After(time.Second):
			}

			hosts, err := opts.Discovery.Discovery(serverName)
			if err != nil {
				log.Println(err)
				continue
			}
			if len(hosts) == 0 {
				log.Println(errors.New(fmt.Sprintf("%s server 404", serverName)))
				continue
			}

			opts.loadBalancing.InitBalancing(hosts)
			baseClient, err := newBaseClient(serverName, opts)
			if err != nil {
				log.Println(err)
				continue
			}

			c.pool <- baseClient
			fmt.Println("Service recovery success")
			return
		}
	}()
}

Connect 调用具体服务

// Call invokes serviceMethod on the remote service. It works as follows:
//  1. Wait up to 6 seconds for a free pooled connection.
//  2. Store the client's AUTH token in ctx under "Light_AUTH".
//  3. Perform the call.
//
// The connection is always handed back with Put afterwards; Put replaces it
// if it reports an error.
func (c *Connect) Call(ctx *light.Context, serviceMethod string, request interface{}, response interface{}) error {
	// The timeout only bounds the wait for a pooled connection. cancel must
	// run to release the derived context's timer.
	poolCtx, cancel := context.WithTimeout(context.TODO(), time.Second*6)
	defer cancel()

	client, err := c.pool.Get(poolCtx)
	if err != nil {
		return errors.WithStack(err)
	}
	// Hand the connection back when done.
	defer c.pool.Put(client)

	// Attach the auth token.
	ctx.SetValue("Light_AUTH", c.Client.options.AUTH)

	if err := client.Call(ctx, serviceMethod, request, response); err != nil {
		return errors.WithStack(err)
	}

	return nil
}

调用核心（重点）

复习 s03 协议设计

/**
	协议设计
	起始符 :  版本号 :  crc32校验 :   magicNumberSize:    serverNameSize :   serverMethodSize :  metaDataSize : payloadSize:  respType :   compressorType :    serializationType :    magicNumber :  serverName :   serverMethod :  metaData :  payload
        0x05  :  0x01  :     4     :        4         :         4         :         4          :       4       :      4     :      1    :          1       :           1          :        xxx     :       xxx   :        xxx     :    xxx    :    xxx
*/

注意：每一个请求都有一个 magicNumber，它就是该请求的请求 ID。

单个连接定义

// BaseClient is a single connection to one service instance. Requests are
// written under writeMu. Responses are read by processMessageManager and
// matched to waiting callers through respInterMap, keyed by magic number.
type BaseClient struct {
	conn       net.Conn
	options    *Options
	serverName string

	aesKey        []byte // per-connection AES key, sent RSA-encrypted during the handshake
	serialization codes.Serialization
	compressor    codes.Compressor

	respInterMap map[string]*respMessage // pending requests keyed by magic number
	respInterRM  sync.RWMutex     // guards respInterMap
	writeMu      sync.Mutex   // serializes writes to conn

	err   error          // last connection-level error; NOTE(review): written from several goroutines without a lock
	close chan struct{}  // closed by the reader when a read fails; stops heartBeat
}

// respMessage tracks one pending request. It holds three things:
//   - response: the value the response payload is decoded into;
//   - ctx: the request context that receives the response metadata;
//   - respChan: the channel on which the outcome is delivered, either an
//     error or a close on success.
type respMessage struct {
	response interface{}
	ctx      *light.Context
	respChan chan error
}

初始化单个连接

// newBaseClient dials a service instance chosen by the load balancer and
// performs the handshake.
//
// The handshake works like this:
//   - The AUTH token and a freshly generated AES key are RSA-encrypted with
//     options.rsaPublicKey.
//   - Both are sent in a handshake frame.
//   - The server's handshake reply is read.
//
// On success, the heartbeat and response-reader goroutines are started.
//
// Codec lookup and encryption happen before dialing, so configuration errors
// cannot leak an open socket. Every failure after dialing closes it.
func newBaseClient(serverName string, options *Options) (*BaseClient, error) {
	serialization, ok := codes.SerializationManager.Get(options.serializationType)
	if !ok {
		return nil, pkg.ErrSerialization404
	}
	compressor, ok := codes.CompressorManager.Get(options.compressorType)
	if !ok {
		return nil, pkg.ErrCompressor404
	}

	// Handshake payload: encrypted auth token and encrypted AES key.
	encryptedAuth, err := cryptology.RsaEncrypt([]byte(options.AUTH), options.rsaPublicKey)
	if err != nil {
		return nil, err
	}
	// Per-connection AES key: a UUID string with its dashes removed.
	aesKey := []byte(strings.ReplaceAll(uuid.New().String(), "-", ""))
	encryptedKey, err := cryptology.RsaEncrypt(aesKey, options.rsaPublicKey)
	if err != nil {
		return nil, err
	}

	// Pick an instance through the load balancer and dial it.
	service, err := options.loadBalancing.GetService()
	if err != nil {
		return nil, err
	}
	con, err := transport.Client.Gen(service.Protocol, service.Addr)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	if _, err := con.Write(protocol.EncodeHandshake(encryptedKey, encryptedAuth, []byte(""))); err != nil {
		con.Close()
		return nil, err
	}

	hsk := &protocol.Handshake{}
	if err := hsk.Handshake(con); err != nil {
		con.Close()
		return nil, err
	}
	// A non-empty Error field carries the server's handshake error message.
	if len(hsk.Error) > 0 {
		con.Close()
		return nil, errors.New(string(hsk.Error))
	}

	bc := &BaseClient{
		serverName:    serverName,
		conn:          con,
		options:       options,
		serialization: serialization,
		compressor:    compressor,
		respInterMap:  map[string]*respMessage{},
		aesKey:        aesKey,
		close:         make(chan struct{}),
	}

	go bc.heartBeat()             // periodic heartbeat frames
	go bc.processMessageManager() // reads and dispatches responses

	return bc, nil
}

heartBeat 心跳服务

// heartBeat sends a HeartBeat frame (magic "x", empty sections) after every
// options.heartBeat interval until b.close is closed.
//
// If a frame cannot be encoded, the error is logged and the next interval is
// awaited. If it cannot be written, the error is stored in b.err and the
// goroutine exits.
func (b *BaseClient) heartBeat() {
	defer func() {
		fmt.Println("heartBeat Close")
	}()

	for {
		select {
		case <-b.close:
			return
		case <-time.After(b.options.heartBeat):
		}

		_, frame, err := protocol.EncodeMessage("x", []byte(""), []byte(""), []byte(""), byte(protocol.HeartBeat), byte(b.options.compressorType), byte(b.options.serializationType), []byte(""))
		if err != nil {
			log.Println(err)
			continue
		}

		// The deadline is set while holding writeMu. Otherwise a concurrent
		// call() could replace it between SetDeadline and Write, or the other
		// way round. SetDeadline already covers the write deadline.
		// NOTE(review): it also moves the read deadline that processMessage
		// manages; confirm this is intended.
		b.writeMu.Lock()
		b.conn.SetDeadline(time.Now().Add(b.options.writeTimeout))
		_, err = b.conn.Write(frame)
		b.writeMu.Unlock()
		if err != nil {
			b.err = err
			return
		}
	}
}

processMessageManager：返回消息的处理服务（注意这里可以并发地处理）

// processMessageManager is the connection's reader loop. It handles each
// result from processMessage as follows:
//   - Empty magic, no error: ignored (heartbeat or expired response).
//   - Empty magic, with error: reading failed, and the loop ends.
//   - Pending request: the outcome is delivered on its channel. An error is
//     sent; on success the channel is closed.
func (b *BaseClient) processMessageManager() {
	defer fmt.Println("processMessageManager Close")

	for {
		magic, respChan, err := b.processMessage()
		if magic == "" {
			if err != nil {
				return
			}
			continue
		}
		if respChan == nil {
			continue
		}

		if err != nil {
			respChan <- err
		} else {
			close(respChan)
		}
	}
}

// processMessage reads one frame from the connection and resolves it against
// the pending-request table.
//
// Results:
//   - ("", nil, err): reading the frame failed. b.err is set, b.close is
//     closed, and the caller should stop reading.
//   - ("", nil, nil): a heartbeat, or a response whose magic has no pending
//     entry; there is nothing to deliver.
//   - (magic, respChan, err): a response for a pending request. err is one of:
//     nil (the payload was decoded into the request's response); the remote
//     "RespError" metadata value; or a local decompress, decrypt or decode
//     failure.
//
// Per-response failures are reported to the waiting request with its magic.
// They must not come back with an empty magic: that would end the reader loop
// and leave the caller waiting on a connection that no longer reads.
func (b *BaseClient) processMessage() (magic string, respChan chan error, err error) {
	b.conn.SetReadDeadline(time.Now().Add(b.options.readTimeout))

	msg, err := protocol.NewProtocol().IODecode(b.conn)
	if err != nil {
		b.err = err
		close(b.close)
		return "", nil, err
	}

	// Heartbeat replies carry nothing to deliver.
	if msg.Header.RespType == byte(protocol.HeartBeat) {
		if b.options.Trace {
			log.Println("is HeartBeat")
		}
		return "", nil, nil
	}

	b.respInterRM.RLock()
	message, ok := b.respInterMap[msg.MagicNumber]
	b.respInterRM.RUnlock()
	if !ok {
		// No pending entry for this magic: the response is no longer wanted.
		if b.options.Trace {
			log.Println("Not Ex", msg.MagicNumber)
		}
		return "", nil, nil
	}
	magic, respChan = msg.MagicNumber, message.respChan

	comp, ok := codes.CompressorManager.Get(codes.CompressorType(msg.Header.CompressorType))
	if !ok {
		return magic, respChan, errors.New(fmt.Sprintf("compressor %d 404", msg.Header.CompressorType))
	}

	// 1. Decompress.
	if msg.MetaData, err = comp.Unzip(msg.MetaData); err != nil {
		return magic, respChan, err
	}
	if msg.Payload, err = comp.Unzip(msg.Payload); err != nil {
		return magic, respChan, err
	}

	// 2. Decrypt. Empty sections are left empty rather than passed to
	// AESDecrypt. A failure on non-empty data is reported, instead of being
	// hidden by checking the length of the (failed) result.
	if len(msg.MetaData) > 0 {
		if msg.MetaData, err = cryptology.AESDecrypt(b.aesKey, msg.MetaData); err != nil {
			return magic, respChan, err
		}
	}
	if len(msg.Payload) > 0 {
		if msg.Payload, err = cryptology.AESDecrypt(b.aesKey, msg.Payload); err != nil {
			return magic, respChan, err
		}
	}

	// 3. Deserialize the metadata into the request context. A non-empty
	// "RespError" value becomes the returned error.
	mtData := make(map[string]string)
	if err = b.serialization.Decode(msg.MetaData, &mtData); err != nil {
		return magic, respChan, err
	}
	message.ctx.SetMetaData(mtData)

	if remoteErr := message.ctx.Value("RespError"); remoteErr != "" {
		return magic, respChan, errors.New(remoteErr)
	}

	return magic, respChan, b.serialization.Decode(msg.Payload, message.response)
}

服务调用

// call encodes one request, sends it on this connection, and registers
// respChan so processMessageManager can deliver the response. It returns the
// request's magic number.
//
// Pipeline:
//  1. Serialize the ctx metadata and the request.
//  2. AES-encrypt both.
//  3. Compress both, but only when the encrypted metadata size lies strictly
//     between compressorMin and compressorMax. Otherwise the frame is marked
//     RawData.
//  4. Frame the message.
//
// Deadline: when options.writeTimeout > 0, the connection deadline is set to
// ctx.GetTimeout() if that is positive, otherwise to options.writeTimeout.
//
// If the write fails, the error is stored in b.err and the pending entry is
// removed again, so it does not stay in respInterMap forever.
func (b *BaseClient) call(ctx *light.Context, serviceMethod string, request interface{}, response interface{}, respChan chan error) (magic string, err error) {
	// 1. Build the request.
	// 1.1 Serialize.
	metaDataBytes, err := b.serialization.Encode(ctx.GetMetaData())
	if err != nil {
		return "", err
	}
	requestBytes, err := b.serialization.Encode(request)
	if err != nil {
		return "", err
	}

	// 1.2 Encrypt with the per-connection AES key.
	if metaDataBytes, err = cryptology.AESEncrypt(b.aesKey, metaDataBytes); err != nil {
		return "", err
	}
	if requestBytes, err = cryptology.AESEncrypt(b.aesKey, requestBytes); err != nil {
		return "", err
	}

	// 1.3 Compress. NOTE(review): the size check looks only at the metadata,
	// although the request payload is compressed too; confirm this is intended.
	compressorType := b.options.compressorType
	if len(metaDataBytes) > compressorMin && len(metaDataBytes) < compressorMax {
		if metaDataBytes, err = b.compressor.Zip(metaDataBytes); err != nil {
			return "", err
		}
		if requestBytes, err = b.compressor.Zip(requestBytes); err != nil {
			return "", err
		}
	} else {
		compressorType = codes.RawData
	}

	// 1.4 Frame the message. EncodeMessage gets an empty magic and returns the
	// one used for this frame.
	magic, message, err := protocol.EncodeMessage("", []byte(b.serverName), []byte(serviceMethod), metaDataBytes, byte(protocol.Request), byte(compressorType), byte(b.options.serializationType), requestBytes)
	if err != nil {
		return "", err
	}

	// 2. Register before writing, so a fast response always finds its entry.
	b.respInterRM.Lock()
	b.respInterMap[magic] = &respMessage{
		response: response,
		ctx:      ctx,
		respChan: respChan,
	}
	b.respInterRM.Unlock()

	// 3. Send. The deadline is set while holding writeMu, so a concurrent
	// heartbeat cannot replace it before this write. SetDeadline already
	// covers the write deadline.
	// NOTE(review): SetDeadline also moves the read deadline used by
	// processMessage; confirm this is intended.
	b.writeMu.Lock()
	if b.options.writeTimeout > 0 {
		timeout := ctx.GetTimeout() // a per-request timeout wins over the default
		if timeout <= 0 {
			timeout = b.options.writeTimeout
		}
		b.conn.SetDeadline(time.Now().Add(timeout))
	}
	_, err = b.conn.Write(message)
	b.writeMu.Unlock()

	if err != nil {
		if b.options.Trace {
			log.Println(err)
		}
		b.err = err
		// The caller gets no magic back, so nobody could ever remove this
		// entry; drop it here.
		b.respInterRM.Lock()
		delete(b.respInterMap, magic)
		b.respInterRM.Unlock()
		return "", errors.WithStack(err)
	}

	return magic, nil
}

专栏地址: juejin.cn/column/6986…