重构即时IM系统8:StateServer(上)

29 阅读12分钟

在网关层我们通过Epoll实现了对于用户长连接FD的监听,大大减少了协程数的消耗,还通过gateWay将消息转发给StateServer进行处理,这样网关层就只负责维护长连接FD和用户ID与长连接FD的映射关系,这部分业务逻辑十分稳定,几乎不需要变更,所以网关层不会因为业务的变更而频繁重启,导致网关上的所有长连接被断开。

StateServer职责

我们将业务部分转移到了StateServer这个服务中,网关负责监听用户FD,当缓冲区中有数据时就会唤醒Epoll然后把缓冲区中的数据按照websocket的格式解析Header和Payload,然后直接将Payload通过grpc转发给StateServer进行处理,这样gateWay只负责透传,StateServer负责逻辑处理的架构。我们需要定义二者之间的通信协议,也就是Protobuf定义。

StateServer需要实现一个协议解析模块,这个模块可以解析gateWay发送过来的消息是什么类型,然后根据这个类型进行处理逻辑的路由,这样的话二者的消息结构就清晰了。gateWay解析ws的Header,根据Header的长度字段从缓冲区中读取Payload,然后通过grpc将这个Payload转发给StateServer,这个Payload本身就是由msgType和msg组成的,这部分由StateServer的协议解析模块进行维护,最后路由到业务层进行处理。

由此可见,StateServer的职责主要是协议解析与分发,之前还说过StateServer还需要维护哪个用户在哪台网关机器上,这样才可以进行业务消息的转发,所以StateServer还维护用户在线状态(User A -> Gateway 1)。

此外,StateServer还有一个至关重要的职责就是维护心跳,在长连接架构中心跳是必不可少的。用户在使用APP时会经常遇到断网(手动关闭,进入电梯,隧道)的情况,在这些场景中,客户端根本没有办法发送TCP的FIN包给服务端,根据 TCP 协议特性,如果没有数据传输,服务端会一直认为这个连接是健康的。这就导致服务器上会堆积成千上万个无效的“僵尸连接”,白白占用着内存和文件描述符。StateServer通过心跳来实施健康检测,如果在一定事件内没有收到用户的心跳,StateServr就可以断定该用户以及掉线,从而主动断开连接,清理用户相关部分。如果不即时清理的话,如果有用户给离线用户发送消息,StateServer以为离线用户还在线,就会给对应用户发送消息,然后超时,再次发送浪费网络带宽。

协议部分

之前我们定义了state.proto文件,这个文件主要描述了gateWay和StateServer之间的交互协议,主要字段有UserID,识别是哪个用户发送的,payload就是我们说的负载,这个是通过解析websocket的负载来的,还有gateway_id用于告知StateServer这个用户在哪台网关机器上,便于回调。

然而state.proto定义的只是gateWay和StateServer之间的通信流程,其中关键的payload字段目前还是一串未知的二进制数据,为了让StateServer能读懂这串数据,我们需要定义第二层协议,也就是端到端的业务消息协议protocal.proto

设计这个消息协议的时候我采用了两层嵌套的方式,如下:

// CommandType 定义消息类型
enum CommandType {
    UNKNOWN = 0;
    LOGIN = 1;      // 登录
    HEARTBEAT = 2;  // 心跳
    MESSAGE = 3;    // 聊天消息 (信令)
}

// Command 是最外层的封包结构 (Envelope)
message Command {
    CommandType type = 1; // 消息类型
    bytes data = 2;       // 具体的业务消息 Payload (序列化后的二进制)
}

// LoginCommand 登录消息
message LoginCommand {
    string token = 1;    // 鉴权 Token
    bytes extra = 2;     // 扩展字段
}

// HeartbeatCommand 心跳消息
message HeartbeatCommand {
    int64 timestamp = 1; // 客户端时间戳
}

// MessageCommand 聊天消息 (信令)
message MessageCommand {
    string uuid = 1;        // 消息唯一ID
    int32 type = 2;         // 消息类型 (文本/图片等)
    string content = 3;     // 消息内容
    string receiver_id = 4; // 接收者ID (群聊则是群ID)
    int32 conversation_type = 5; // 会话类型 (1=单聊, 2=群聊)
    bytes extra = 6;        // 扩展字段
}

外层包装是Command,这是所有业务消息的统一入口,其中type字段表示这是一个什么消息,data字段就是具体的业务数据,根据type的不同反序列化为不同的内层结构体,这种设计让StateServer可以先解析外层,拿到消息类型,再通过路由分发去解析内层,实现了极佳的扩展性。

内层消息就是后面三个具体的结构体,这里我先实现这三个,有需要的话再在后面补充。内层消息就是客户端实际发送的消息,不过在发送时被打包成了Command,序列化之后通过WebSocket发送给GateWay,GateWay将其作为Payload封装进之前说过的state.proto中的请求体中,透传给StateServer;最终由StateServer层层拆包,还原出原始的业务意图。

StateServer改造

之前我们的StateServer实现非常暴力,就是直接拿gateWay给你的payload去构造responsePayload := []byte(fmt.Sprintf("StateServer Echo: %s", string(req.Payload)))然后把这个返回给gateWay,但是现在就不是简单的回显逻辑了,需要理解payload并且将其分发给正确的执行单元。

在gateWay发来的请求中,payload是一串未知的二进制,StateServer首先会根据protocol.Command进行反序列化,因为客户端发送时将请求包装为了Command,这一步可以拿到消息类型和消息体。如果序列化失败的话说明数据包格式错误,这里直接丢弃。如果成功的话我们就可以拿到消息的Type,此时可以执行switch逻辑去路由到对应的处理函数中,如果未来要增加业务的话也是在这里修改。

进入具体的处理函数之后,例如handleLogin,我们已经明确知道Command.Data里装的一定是 LoginCommand 的二进制数据,此时进行第二次反序列化,还原出具体的业务对象,然后执行真正的业务逻辑,具体代码如下:

func (s *Service) ReceiveMessage(ctx context.Context, req *pb.ReceiveMessageRequest) (*pb.ReceiveMessageResponse, error) {
    // 反序列化外层 Command
    var cmd protocol.Command
    if err := proto.Unmarshal(req.Payload, &cmd); err != nil {
        log.Printf("[StateServer] Unmarshal Command error: %v", err)
        return nil, err
    }

    log.Printf("[StateServer] Received Command: Type=%v, Gateway=%s, User=%s", cmd.Type, req.GatewayId, req.Uid)

    // 路由分发
    switch cmd.Type {
        case protocol.CommandType_LOGIN:
        return s.handleLogin(ctx, &cmd, req)
        case protocol.CommandType_HEARTBEAT:
        return s.handleHeartbeat(ctx, &cmd, req)
        case protocol.CommandType_MESSAGE:
        return s.handleMessage(ctx, &cmd, req)
        default:
        log.Printf("[StateServer] Unknown command type: %v", cmd.Type)
        return nil, fmt.Errorf("unknown command type")
    }
}

至于handleLogin这种处理函数我们后面说到业务逻辑的时候再讲。

客户端适配

之前的测试客户端发送的是纯文本字符串,现在必须发送proto协议序列化之后的二进制流,首先根据用户需求创建一个具体的业务对象,然后将其序列化,这是内层业务对象的初始化。然后再创建一个Command对象,将刚才的二进制流填入Data字段,然后根据业务类型设置Type,将这个 Command 对象再次序列化,通过 WebSocket 的 BinaryMessage 模式发送出去,下面是客户端关键代码:


func sendLogin(c *websocket.Conn) {
    loginPayload := &protocol.LoginCommand{
        Token: "test-token-123",
    }
    data, _ := proto.Marshal(loginPayload)

    cmd := &protocol.Command{
        Type: protocol.CommandType_LOGIN,
        Data: data,
    }

    sendProto(c, cmd)
}

func sendHeartbeat(c *websocket.Conn) {
    hbPayload := &protocol.HeartbeatCommand{
        Timestamp: time.Now().Unix(),
    }
    data, _ := proto.Marshal(hbPayload)

    cmd := &protocol.Command{
        Type: protocol.CommandType_HEARTBEAT,
        Data: data,
    }
    sendProto(c, cmd)
}
// ......


func sendProto(c *websocket.Conn, cmd *protocol.Command) {
    bytes, err := proto.Marshal(cmd)
    if err != nil {
        log.Println("marshal error:", err)
        return
    }

    err = c.WriteMessage(websocket.BinaryMessage, bytes)
    if err != nil {
        log.Println("write error:", err)
    } else {
        log.Printf("sent command: type=%v", cmd.Type)
    }
}

登录实现

客户端给gateWay发送携带token的Login数据包,网关直接把payload通过grpc转发给StateServer,StateServer发现是Login类型,就会路由到handleLogin方法。在方法内通过proto解析出对应业务相关结构体,通过Auth服务或者查库验证token的合法性,验证成功之后就拿到UserID,然后把这个UserID通过grpc返回给gateWay,这相当于StateServer对GateWay下达指令,把刚才发送消息的那个连接绑定对应的UserID,便于路由。

上面说的返回用户ID给gateWay是为了让gateWay知道这个FD对应的是哪个用户,为了实现转发逻辑,我们还需要知道这个用户连接的是哪台网关机器,所以我们需要在StateServer的处理函数中将用户ID和网关机器ID作为KV注册到redis中。

下面是相关变更代码:

// state.proto
message ReceiveMessageResponse {
    bytes response_payload = 1;
    // 新增两个字段
    string bound_user_id = 2;
    bool disconnect = 3;
}

其中bound_user_id表示这个连接现在绑定哪一个用户ID,disconnect字段是StateServer通知gateWay断开连接的,比如发现token无效,黑名单用户,异地登录互踢这种。gateWay收到响应后,发现disconnect为真的话就会关闭连接,不再推送返回的消息体。(或者你可以选择先把payload推送完再关,目前的逻辑是直接关)。

StateServer的注册处理函数:

func (s *Service) handleLogin(ctx context.Context, cmd *protocol.Command, req *pb.ReceiveMessageRequest) (*pb.ReceiveMessageResponse, error) {
    var loginCmd protocol.LoginCommand
    if err := proto.Unmarshal(cmd.Data, &loginCmd); err != nil {
        return nil, err
    }

    log.Printf("[StateServer] Handle Login: Token=%s", loginCmd.Token)
    // TODO: 调用 Auth 服务验证 Token,绑定 UserID
    // 这里简单模拟成功
    userID := "u_" + loginCmd.Token // Simple mock

    key := fmt.Sprintf("UserID:%s", userID)
    if err := s.rdb.Set(ctx, key, req.GatewayId, 5*time.Minute).Err(); err != nil {
        log.Printf("[StateServer] Failed to save session: %v", err)
        return nil, err
    }

    return &pb.ReceiveMessageResponse{
        ResponsePayload: []byte("Login Success"),
        BoundUserId:     userID,
        Disconnect:      false,
    }, nil
}

网关层新增代码:

// Check if StateServer authorized a user
if resp.BoundUserId != "" {
    log.Printf("Client %d bound to user %s", c.FD, resp.BoundUserId)
    c.Uid = resp.BoundUserId
    s.BindUser(resp.BoundUserId, c)
}

// Check if StateServer requested disconnect
if resp.Disconnect {
    log.Printf("Client %d disconnected by StateServer", c.FD)
    s.closeClient(c, ep)
    return
}

心跳实现

客户端给GateWay发送携带心跳数据的Heartbeat数据包,网关直接把 payload 通过 grpc 转发给 StateServer(同时网关会刷新本地连接的超时时间以防误踢)。StateServer路由到对应处理函数,在函数内,StateServer 会根据请求中携带的 UserID,刷新 Redis 中该用户 ID 与网关机器 ID 映射关系的过期时间(TTL),实现 Session 的保活。最后,StateServer给gateWay返回ack,gateWay再将其通过webSocket推送给客户端,让客户端确认连接依然存活。

为了实现ack,我们在Command结构体中新增了codemsg字段,code如果为0就表示响应成功,而且code的错误码可以留在以后做拓展,msg字段表示提示消息,可能是错误消息,可能是执行成功的说明。

心跳定时器是由客户端维护的,因为由gateWay或者StateServer为每个客户端维护一个定时器开销太大,所以选择让客户端维护自己的心跳定时器,发送请求等待响应。

下面看gateWay中更新的代码:

func (s *Server) keepAliveLoop(ctx context.Context) {
    // Check every 10 seconds
    ticker := time.NewTicker(10 * time.Second)
    defer ticker.Stop()

    for {
        select {
            case <-ctx.Done():
            return
            case <-ticker.C:
            now := time.Now().Unix()

            for _, ep := range s.Epolls {
                var timeoutClients []*Client

                ep.lock.RLock()
                for _, c := range ep.connections {
                    if now-c.lastHeartbeat > 300 { // 300s timeout
                        timeoutClients = append(timeoutClients, c)
                    }
                }
                ep.lock.RUnlock()

                // Close timeout clients
                for _, c := range timeoutClients {
                    log.Printf("Client %d timed out (last active: %ds ago)", c.FD, now-c.lastHeartbeat)
                    s.closeClient(c, ep)
                }
            }
        }
    }
}

这个函数在gateWay调用Start启动的时候就开启一个协程执行keepAliveLoop,核心职责是定期踢掉那些占用连接而且没响应的僵尸用户。主要就是遍历所有注册到epoll中的连接,获取连接的客户端,检查现在时间距离该客户端上次收到心跳的时间是否超过5分钟,超过的话就把这个连接计入待踢出队列,后面统一close。

这里之所以使用每10s检查一次是因为这个对时间精度的容忍度高,心跳主要是为了释放长时间无响应的资源所以晚一两秒其实没有什么区别,如果追求精度的话可以为每个client维护一个定时器,但是这样开销太大了,优化的话可以采用时间轮算法,在精度和性能角度做取舍。

由gateWay透传给StateServer之后,StateServer识别到这是一个心跳,路由到心跳处理函数:

func (s *Service) handleHeartbeat(ctx context.Context, cmd *protocol.Command, req *pb.ReceiveMessageRequest) (*pb.ReceiveMessageResponse, error) {
    if req.Uid != "" {
        key := fmt.Sprintf("session:%s", req.Uid)
        if err := s.rdb.Expire(ctx, key, 10*time.Minute).Err(); err != nil {
            log.Printf("[StateServer] RenewSession failed: %v", err)
        }
    }

    // 构造响应
    respCmd := &protocol.Command{
        Type: protocol.CommandType_HEARTBEAT,
        Code: 0,
    }
    respBytes, _ := proto.Marshal(respCmd)

    return &pb.ReceiveMessageResponse{
        ResponsePayload: respBytes,
    }, nil
}

刷新用户和网关的map关系,然后构造回复返回给客户端,之后客户端需要解析请求才知道这是一个心跳的ack,然后进行相应处理。

测试

测试结果和代码如下:

image-20260124110430503


func main() {
	u := url.URL{Scheme: "ws", Host: "localhost:8002", Path: "/"}
	log.Printf("Connecting to %s", u.String())

	// 建立 WebSocket 连接
	c, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
	if err != nil {
		log.Fatal("dial:", err)
	}
	defer c.Close()

	done := make(chan struct{})

	// 启动接收协程 (Receive Loop)
	go func() {
		defer close(done)
		for {
			_, message, err := c.ReadMessage()
			if err != nil {
				log.Println("read:", err)
				return
			}
			// 解析 Command
			var cmd protocol.Command
			if err := proto.Unmarshal(message, &cmd); err != nil {
				log.Printf("unmarshal error: %v", err)
				continue
			}

			if cmd.Code != 0 {
				log.Printf("recv Error: Code=%d, Msg=%s", cmd.Code, cmd.Msg)
				continue
			}

			switch cmd.Type {
			case protocol.CommandType_LOGIN:
				log.Printf("Login Success!")
			case protocol.CommandType_HEARTBEAT:
				log.Printf("recv Pong")
			case protocol.CommandType_MESSAGE:
				// 如果 Data 不为空,说明可能是一个推送下来的新消息,而不仅仅是 ACK
				if len(cmd.Data) > 0 {
					var msgCmd protocol.MessageCommand
					if err := proto.Unmarshal(cmd.Data, &msgCmd); err == nil {
						log.Printf("recv Message: From=%s, Content=%s", "System", msgCmd.Content)
					} else {
						log.Printf("recv Message ACK (Data len=%d)", len(cmd.Data))
					}
				} else {
					log.Printf("recv Message ACK: %s", cmd.Msg)
				}
			}
		}
	}()

	// 发送消息测试 (Send Loop)
	ticker := time.NewTicker(2 * time.Second)
	defer ticker.Stop()

	interrupt := make(chan os.Signal, 1)
	signal.Notify(interrupt, os.Interrupt)

	// 先发送登录包
	sendLogin(c)

	count := 0
	for {
		select {
		case <-done:
			return
		case <-ticker.C:
			// 轮流发送心跳和聊天消息
			count++
			if count%2 == 0 {
				sendHeartbeat(c)
			} else {
				sendMessage(c)
			}

		case <-interrupt:
			// 优雅退出
			log.Println("interrupt")
			err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
			if err != nil {
				log.Println("write close:", err)
				return
			}
			select {
			case <-done:
			case <-time.After(time.Second):
			}
			return
		}
	}
}

func sendLogin(c *websocket.Conn) {
	// 构造 LoginCommand
	loginPayload := &protocol.LoginCommand{
		Token: "test-token-123",
	}
	data, _ := proto.Marshal(loginPayload)

	// 构造外层 Command
	cmd := &protocol.Command{
		Type: protocol.CommandType_LOGIN,
		Data: data,
	}

	sendProto(c, cmd)
}

func sendHeartbeat(c *websocket.Conn) {
	hbPayload := &protocol.HeartbeatCommand{
		Timestamp: time.Now().Unix(),
	}
	data, _ := proto.Marshal(hbPayload)

	cmd := &protocol.Command{
		Type: protocol.CommandType_HEARTBEAT,
		Data: data,
	}
	sendProto(c, cmd)
}

func sendMessage(c *websocket.Conn) {
	msgPayload := &protocol.MessageCommand{
		Uuid:       "msg-" + time.Now().String(),
		Type:       1,
		Content:    "Hello StateServer! " + time.Now().Format(time.TimeOnly),
		ReceiverId: "1002",
	}
	data, _ := proto.Marshal(msgPayload)

	cmd := &protocol.Command{
		Type: protocol.CommandType_MESSAGE,
		Data: data,
	}
	sendProto(c, cmd)
}

func sendProto(c *websocket.Conn, cmd *protocol.Command) {
	// 序列化
	bytes, err := proto.Marshal(cmd)
	if err != nil {
		log.Println("marshal error:", err)
		return
	}

	// 发送二进制帧 (BinaryMessage)
	err = c.WriteMessage(websocket.BinaryMessage, bytes)
	if err != nil {
		log.Println("write error:", err)
	} else {
		log.Printf("sent command: type=%v", cmd.Type)
	}
}