59-消息推送系统详解

0 阅读15分钟

消息推送系统详解

本章导读

消息推送系统是现代互联网应用的核心基础设施,负责将消息实时、可靠地送达给用户。本章将深入剖析长连接管理、消息路由策略、送达保证机制、海量推送优化等核心问题,帮助你掌握IM系统和推送通知系统的设计精髓。

学习目标

  • 目标1:掌握WebSocket长连接管理与消息协议设计
  • 目标2:理解消息路由、ACK机制、离线存储的核心原理
  • 目标3:能够设计高并发、高可用的消息推送服务

前置知识:熟悉Netty框架,了解TCP/WebSocket协议,掌握Redis基础

阅读时长:约 55 分钟

一、知识概述

消息推送系统是现代互联网应用的核心基础设施之一,负责将消息实时、可靠地送达给用户。无论是即时通讯、社交动态、系统通知,还是营销推送,都离不开消息推送系统的支持。

本文将深入剖析消息推送系统的设计要点,包括长连接管理、消息路由策略、送达保证机制、海量推送优化等核心问题,并提供完整的Java实现方案。

消息推送系统的核心挑战

  1. 实时性:消息需要毫秒级送达
  2. 可靠性:确保消息不丢失
  3. 高并发:支持海量用户同时在线
  4. 多端同步:手机、Web、桌面多端消息同步
  5. 离线推送:用户离线时的消息送达

二、知识点详细讲解

2.1 长连接管理

WebSocket连接管理
package com.example.push.connection;

import io.netty.channel.Channel;
import io.netty.channel.ChannelId;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

/**
 * WebSocket connection registry.
 *
 * <p>Keeps three views of the connections registered on this node:
 * a {@link ChannelGroup} with every channel, a userId -> (deviceId -> channel)
 * map for targeted delivery, and a channelId -> identity map used to find the
 * owner of a channel when it closes.
 *
 * <p>All mutations of the per-user device map go through
 * {@code compute}/{@code computeIfPresent} on {@code userChannels}, so adding a
 * device and dropping an empty device map cannot interleave and lose a
 * connection.
 */
@Component
public class ConnectionManager {
    
    // Every channel registered through addConnection (used for broadcast).
    private final ChannelGroup channels = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
    
    // userId -> (deviceId -> channel); one user may be online on several devices.
    private final ConcurrentHashMap<Long, ConcurrentHashMap<String, Channel>> userChannels = 
        new ConcurrentHashMap<>();
    
    // channelId -> identity bound to that channel.
    private final ConcurrentHashMap<ChannelId, UserInfo> channelUsers = 
        new ConcurrentHashMap<>();
    
    /**
     * Registers a channel for the given user and device.
     *
     * <p>If the same userId/deviceId already had a channel, the slot is
     * overwritten with the new channel (the old channel is not closed here).
     * If the channel was previously bound to a different identity, that stale
     * binding is removed first.
     *
     * @param userId   user id, must not be null
     * @param deviceId device id, must not be null
     * @param channel  WebSocket channel, must not be null
     * @throws NullPointerException if any argument is null
     */
    public void addConnection(Long userId, String deviceId, Channel channel) {
        // Fail before touching any map so a bad call cannot leave partial state.
        Objects.requireNonNull(userId, "userId");
        Objects.requireNonNull(deviceId, "deviceId");
        Objects.requireNonNull(channel, "channel");
        
        channels.add(channel);
        
        UserInfo current = new UserInfo(userId, deviceId);
        UserInfo previous = channelUsers.put(channel.id(), current);
        if (previous != null && !previous.equals(current)) {
            // Same channel logged in again under another identity: drop the old slot.
            unbind(previous, channel);
        }
        
        // compute() is atomic per key, so it cannot race with the empty-map
        // cleanup performed in unbind().
        userChannels.compute(userId, (id, devices) -> {
            ConcurrentHashMap<String, Channel> target =
                devices != null ? devices : new ConcurrentHashMap<>();
            target.put(deviceId, channel);
            return target;
        });
        
        System.out.println(String.format(
            "用户上线: userId=%d, deviceId=%s, channelId=%s, 当前在线用户数=%d",
            userId, deviceId, channel.id(), userChannels.size()
        ));
    }
    
    /**
     * Unregisters a channel.
     *
     * <p>The device slot is only cleared when it still points at this exact
     * channel, so the late close of a replaced connection does not evict the
     * newer connection of the same device.
     *
     * @param channel WebSocket channel being closed
     */
    public void removeConnection(Channel channel) {
        UserInfo userInfo = channelUsers.remove(channel.id());
        if (userInfo != null) {
            unbind(userInfo, channel);
            
            System.out.println(String.format(
                "用户下线: userId=%d, deviceId=%s, 当前在线用户数=%d",
                userInfo.getUserId(), userInfo.getDeviceId(), userChannels.size()
            ));
        }
        
        channels.remove(channel);
    }
    
    /**
     * Removes the (deviceId -> channel) entry of {@code info} if it still maps
     * to {@code channel}, and drops the user's device map once it is empty.
     */
    private void unbind(UserInfo info, Channel channel) {
        userChannels.computeIfPresent(info.getUserId(), (id, devices) -> {
            devices.remove(info.getDeviceId(), channel);
            // Returning null removes the user entry atomically.
            return devices.isEmpty() ? null : devices;
        });
    }
    
    /**
     * Returns a snapshot of all channels of a user (one per device).
     *
     * @return a new list; empty when the user has no registered channel
     */
    public List<Channel> getUserChannels(Long userId) {
        ConcurrentHashMap<String, Channel> deviceMap = userChannels.get(userId);
        if (deviceMap == null) {
            return Collections.emptyList();
        }
        return new ArrayList<>(deviceMap.values());
    }
    
    /**
     * Returns the channel of one specific device of a user.
     *
     * @return the channel, or {@code null} when that device is not registered
     */
    public Channel getUserChannel(Long userId, String deviceId) {
        ConcurrentHashMap<String, Channel> deviceMap = userChannels.get(userId);
        return deviceMap == null ? null : deviceMap.get(deviceId);
    }
    
    /**
     * Tells whether the user has at least one registered device on this node.
     */
    public boolean isOnline(Long userId) {
        ConcurrentHashMap<String, Channel> deviceMap = userChannels.get(userId);
        return deviceMap != null && !deviceMap.isEmpty();
    }
    
    /**
     * Returns the number of users with at least one registered device.
     */
    public int getOnlineUserCount() {
        return userChannels.size();
    }
    
    /**
     * Returns the number of channels in the global channel group.
     */
    public int getTotalConnectionCount() {
        return channels.size();
    }
    
    /**
     * Returns the live channel group (used for broadcast).
     */
    public ChannelGroup getAllChannels() {
        return channels;
    }
    
    /**
     * Identity bound to a channel.
     */
    @lombok.Data
    @lombok.AllArgsConstructor
    private static class UserInfo {
        private Long userId;
        private String deviceId;
    }
}
Netty WebSocket服务器
package com.example.push.server;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.handler.timeout.IdleStateHandler;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import java.util.concurrent.TimeUnit;

/**
 * Netty-based WebSocket server, started and stopped with the Spring context.
 *
 * <p>The message handler is taken from the Spring context instead of being
 * created with {@code new}: a hand-built {@link WebSocketMessageHandler} would
 * have its {@code @Autowired} collaborators left null and fail on the first
 * login. The handler is annotated {@code @ChannelHandler.Sharable}, so the one
 * bean instance is added to every channel pipeline.
 */
@Component
public class WebSocketServer {
    
    @Value("${websocket.port:8088}")
    private int port;
    
    // Spring-managed, sharable handler reused by all pipelines.
    private final WebSocketMessageHandler messageHandler;
    
    private EventLoopGroup bossGroup;
    private EventLoopGroup workerGroup;
    private Channel serverChannel;
    
    /**
     * Creates the server around the Spring-managed message handler.
     * NOTE(review): relies on Spring autowiring the single constructor
     * without an annotation (Spring 4.3+) — confirm for the version in use.
     *
     * @param messageHandler sharable handler bean with its dependencies injected
     */
    public WebSocketServer(WebSocketMessageHandler messageHandler) {
        this.messageHandler = messageHandler;
    }
    
    /**
     * Creates the event loops, builds the channel pipeline and binds the port.
     * Blocks until the bind completes.
     *
     * @throws InterruptedException if interrupted while waiting for the bind
     */
    @PostConstruct
    public void start() throws InterruptedException {
        bossGroup = new NioEventLoopGroup(1);
        workerGroup = new NioEventLoopGroup();
        
        boolean started = false;
        try {
            ServerBootstrap bootstrap = new ServerBootstrap();
            bootstrap.group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .option(ChannelOption.SO_BACKLOG, 1024)
                    .childOption(ChannelOption.SO_KEEPALIVE, true)
                    .childOption(ChannelOption.TCP_NODELAY, true)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) {
                            ChannelPipeline pipeline = ch.pipeline();
                            
                            // HTTP codec (the WebSocket handshake starts as HTTP)
                            pipeline.addLast("http-codec", new HttpServerCodec());
                            
                            // Chunked write support
                            pipeline.addLast("http-chunked", new ChunkedWriteHandler());
                            
                            // Aggregates HTTP message parts, up to 64 KiB
                            pipeline.addLast("aggregator", new HttpObjectAggregator(65536));
                            
                            // Fires an idle event after 60 s without inbound data
                            pipeline.addLast("idle-state", new IdleStateHandler(
                                60, 0, 0, TimeUnit.SECONDS));
                            
                            // WebSocket handshake and control frames on path /ws
                            pipeline.addLast("webSocket-protocol", 
                                new WebSocketServerProtocolHandler("/ws"));
                            
                            // Business handler: the shared Spring bean
                            pipeline.addLast("message-handler", messageHandler);
                        }
                    });
            
            // Bind the port
            serverChannel = bootstrap.bind(port).sync().channel();
            started = true;
            System.out.println("WebSocket服务器启动成功,端口: " + port);
        } finally {
            // A failed bind (e.g. port in use) must not leave event loop threads running.
            if (!started) {
                releaseEventLoops();
            }
        }
    }
    
    /**
     * Closes the listening channel and shuts the event loops down.
     */
    @PreDestroy
    public void stop() {
        if (serverChannel != null) {
            serverChannel.close();
        }
        releaseEventLoops();
        System.out.println("WebSocket服务器已关闭");
    }
    
    /**
     * Requests a graceful shutdown of both event loop groups, if created.
     */
    private void releaseEventLoops() {
        if (bossGroup != null) {
            bossGroup.shutdownGracefully();
        }
        if (workerGroup != null) {
            workerGroup.shutdownGracefully();
        }
    }
}
消息处理器
package com.example.push.server;

import com.example.push.connection.ConnectionManager;
import com.example.push.protocol.Message;
import com.example.push.protocol.MessageType;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.timeout.IdleStateEvent;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

/**
 * Inbound handler for WebSocket text frames: parses the JSON protocol message
 * and routes it by type (login, heartbeat, chat, ack).
 *
 * <p>The handler is sharable, so per-connection state (the logged-in user of
 * a channel) is kept in a concurrent map keyed by the channel's long id text
 * and cleared when the channel becomes inactive.
 *
 * <p>NOTE(review): {@code MessageDispatcher} is declared in package
 * {@code com.example.push.router} further down in this document and is not in
 * this listing's import list — add that import when extracting the class.
 */
@Slf4j
@ChannelHandler.Sharable
@Component
public class WebSocketMessageHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {
    
    @Autowired
    private ConnectionManager connectionManager;
    
    @Autowired
    private MessageDispatcher messageDispatcher;
    
    // findAndRegisterModules() picks up Jackson modules present on the classpath;
    // Message carries a LocalDateTime, which plain ObjectMapper cannot round-trip.
    // NOTE(review): requires jackson-datatype-jsr310 on the classpath — confirm.
    private final ObjectMapper objectMapper = new ObjectMapper().findAndRegisterModules();
    
    // channel long-id -> userId established by LOGIN on that channel.
    private final java.util.Map<String, Long> loggedInUsers =
        new java.util.concurrent.ConcurrentHashMap<>();
    
    /**
     * Handles handshake-complete and idle events; anything else is passed on
     * to the next handler instead of being silently dropped.
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
            // WebSocket handshake finished
            log.info("WebSocket连接建立: {}", ctx.channel().id());
            
        } else if (evt instanceof IdleStateEvent) {
            // No inbound traffic within the idle window: treat as heartbeat timeout
            log.warn("心跳超时,关闭连接: {}", ctx.channel().id());
            ctx.close();
            
        } else {
            super.userEventTriggered(ctx, evt);
        }
    }
    
    /**
     * Parses one text frame and dispatches it by message type.
     *
     * <p>Parse failures are reported to the client as a format error;
     * failures while processing a well-formed message are reported separately
     * so the client is not told its payload was malformed.
     */
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame frame) throws Exception {
        String text = frame.text();
        log.debug("收到消息: {}", text);
        
        Message message;
        try {
            message = objectMapper.readValue(text, Message.class);
        } catch (java.io.IOException e) {
            log.warn("消息解析失败", e);
            sendError(ctx, "消息格式错误");
            return;
        }
        
        // The JSON literal "null" parses to null; a missing type cannot be switched on.
        if (message == null || message.getType() == null) {
            sendError(ctx, "消息格式错误");
            return;
        }
        
        try {
            switch (message.getType()) {
                case LOGIN:
                    handleLogin(ctx, message);
                    break;
                case HEARTBEAT:
                    handleHeartbeat(ctx, message);
                    break;
                case CHAT:
                case GROUP_CHAT:
                    // Both are routed by the dispatcher according to the type.
                    handleChat(ctx, message);
                    break;
                case ACK:
                    handleAck(ctx, message);
                    break;
                default:
                    log.warn("未知消息类型: {}", message.getType());
            }
            
        } catch (Exception e) {
            log.error("消息处理异常", e);
            sendError(ctx, "消息处理失败");
        }
    }
    
    /**
     * Registers the connection for the user/device named in the message,
     * confirms the login and pushes stored offline messages.
     */
    private void handleLogin(ChannelHandlerContext ctx, Message message) {
        Long userId = message.getFromUserId();
        String deviceId = message.getDeviceId();
        
        if (userId == null || deviceId == null) {
            sendError(ctx, "登录失败: userId和deviceId不能为空");
            return;
        }
        
        // Register the connection mapping
        connectionManager.addConnection(userId, deviceId, ctx.channel());
        loggedInUsers.put(channelKey(ctx), userId);
        
        // Send the login acknowledgement
        Message response = new Message();
        response.setType(MessageType.LOGIN_ACK);
        response.setContent("登录成功");
        sendMessage(ctx, response);
        
        // Push messages stored while the user was offline
        messageDispatcher.pushOfflineMessages(userId, deviceId);
    }
    
    /**
     * Answers a heartbeat with a heartbeat acknowledgement.
     */
    private void handleHeartbeat(ChannelHandlerContext ctx, Message message) {
        Message response = new Message();
        response.setType(MessageType.HEARTBEAT_ACK);
        sendMessage(ctx, response);
    }
    
    /**
     * Dispatches a chat message on behalf of the user logged in on this
     * channel. The client-supplied sender id is overwritten with the logged-in
     * id so one connection cannot send as another user.
     */
    private void handleChat(ChannelHandlerContext ctx, Message message) {
        Long loginUserId = loggedInUsers.get(channelKey(ctx));
        if (loginUserId == null) {
            sendError(ctx, "请先登录");
            return;
        }
        message.setFromUserId(loginUserId);
        
        messageDispatcher.dispatch(message);
    }
    
    /**
     * Forwards a delivery acknowledgement to the dispatcher.
     */
    private void handleAck(ChannelHandlerContext ctx, Message message) {
        Long messageId = message.getMessageId();
        if (messageId == null) {
            sendError(ctx, "消息格式错误");
            return;
        }
        messageDispatcher.handleAck(messageId);
    }
    
    /**
     * Drops the connection mapping and login state when the channel closes.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        loggedInUsers.remove(channelKey(ctx));
        connectionManager.removeConnection(ctx.channel());
        super.channelInactive(ctx);
    }
    
    /**
     * Logs the error and closes the channel.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        log.error("连接异常", cause);
        ctx.close();
    }
    
    /**
     * Key under which this channel's login state is stored.
     */
    private String channelKey(ChannelHandlerContext ctx) {
        return ctx.channel().id().asLongText();
    }
    
    /**
     * Serializes the message to JSON and writes it as a text frame.
     * Serialization failures are logged, not propagated.
     */
    private void sendMessage(ChannelHandlerContext ctx, Message message) {
        try {
            String json = objectMapper.writeValueAsString(message);
            ctx.writeAndFlush(new TextWebSocketFrame(json));
        } catch (Exception e) {
            log.error("发送消息失败", e);
        }
    }
    
    /**
     * Sends an ERROR message with the given text to the client.
     */
    private void sendError(ChannelHandlerContext ctx, String error) {
        Message message = new Message();
        message.setType(MessageType.ERROR);
        message.setContent(error);
        sendMessage(ctx, message);
    }
}

2.2 消息协议设计

package com.example.push.protocol;

import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.AllArgsConstructor;
import java.time.LocalDateTime;

/**
 * Unified wire-protocol message exchanged between clients and the push server.
 * Accessors and constructors are generated by Lombok.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
public class Message {
    
    /** Message id. */
    private Long messageId;
    
    /** Kind of message. */
    private MessageType type;
    
    /** User id of the sender. */
    private Long fromUserId;
    
    /** User id of the recipient (private chat). */
    private Long toUserId;
    
    /** Group id (group chat). */
    private Long groupId;
    
    /** Device id. */
    private String deviceId;
    
    /** Message payload. */
    private String content;
    
    /** Extension data, as a JSON string. */
    private String extra;
    
    /** Message time. */
    private LocalDateTime timestamp;
    
    /** Delivery status. */
    private MessageStatus status;
    
    /**
     * Builds a private chat message with a freshly generated id, the current
     * time and status {@code SENDING}.
     *
     * @param fromUserId sender
     * @param toUserId   recipient
     * @param content    payload
     * @return the new chat message
     */
    public static Message createChatMessage(Long fromUserId, Long toUserId, String content) {
        Message chat = typed(MessageType.CHAT, SnowflakeIdGenerator.nextId());
        chat.setFromUserId(fromUserId);
        chat.setToUserId(toUserId);
        chat.setContent(content);
        chat.setTimestamp(LocalDateTime.now());
        chat.setStatus(MessageStatus.SENDING);
        return chat;
    }
    
    /**
     * Builds an ACK message that acknowledges the given message id,
     * stamped with the current time.
     *
     * @param messageId id of the message being acknowledged
     * @return the new ACK message
     */
    public static Message createAck(Long messageId) {
        Message ack = typed(MessageType.ACK, messageId);
        ack.setTimestamp(LocalDateTime.now());
        return ack;
    }
    
    /**
     * Creates an otherwise empty message carrying the given type and id.
     */
    private static Message typed(MessageType type, Long messageId) {
        Message blank = new Message();
        blank.setMessageId(messageId);
        blank.setType(type);
        return blank;
    }
}

/**
 * Kinds of protocol messages.
 */
public enum MessageType {

    /** Login request. */
    LOGIN,

    /** Login response. */
    LOGIN_ACK,

    /** Heartbeat. */
    HEARTBEAT,

    /** Heartbeat response. */
    HEARTBEAT_ACK,

    /** Private chat message. */
    CHAT,

    /** Group chat message. */
    GROUP_CHAT,

    /** System message. */
    SYSTEM,

    /** Delivery acknowledgement. */
    ACK,

    /** Error message. */
    ERROR
}

/**
 * Delivery states of a message.
 */
public enum MessageStatus {

    /** Being sent. */
    SENDING,

    /** Delivered. */
    DELIVERED,

    /** Read. */
    READ,

    /** Sending failed. */
    FAILED
}

/**
 * Snowflake-style id generator.
 *
 * <p>An id is the number of milliseconds since a fixed epoch shifted left by
 * 13 bits, OR-ed with a 13-bit per-millisecond sequence. There is no
 * worker/datacenter component in this layout.
 */
public class SnowflakeIdGenerator {
    
    // Custom epoch in milliseconds (2024-01-01T00:00:00+08:00).
    private static final long EPOCH_MILLIS = 1704038400000L;
    // Width of the per-millisecond sequence.
    private static final long SEQUENCE_BITS = 13L;
    // Low SEQUENCE_BITS bits set.
    private static final long SEQUENCE_MASK = ~(-1L << SEQUENCE_BITS);
    
    // Both fields are only touched inside the synchronized nextId().
    private static long counter = 0L;
    private static long previousMillis = -1L;
    
    /**
     * Produces the next id.
     *
     * @return the next id; ids from this class are strictly increasing
     * @throws RuntimeException if the system clock is behind the timestamp
     *         of the previously issued id
     */
    public static synchronized long nextId() {
        long current = System.currentTimeMillis();
        
        if (current < previousMillis) {
            throw new RuntimeException("Clock moved backwards");
        }
        
        if (current != previousMillis) {
            // First id of a new millisecond
            counter = 0L;
        } else {
            counter = (counter + 1) & SEQUENCE_MASK;
            if (counter == 0) {
                // Sequence wrapped within one millisecond: move to the next one
                current = awaitLaterMillis(previousMillis);
            }
        }
        
        previousMillis = current;
        
        long elapsed = current - EPOCH_MILLIS;
        return (elapsed << SEQUENCE_BITS) | counter;
    }
    
    /**
     * Spins until the system clock reads later than {@code reference}.
     */
    private static long awaitLaterMillis(long reference) {
        long candidate;
        do {
            candidate = System.currentTimeMillis();
        } while (candidate <= reference);
        return candidate;
    }
}

2.3 消息路由策略

package com.example.push.router;

import com.example.push.connection.ConnectionManager;
import com.example.push.protocol.Message;
import com.example.push.protocol.MessageType;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.netty.channel.Channel;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;

import java.util.List;
import java.util.concurrent.TimeUnit;

/**
 * 消息分发器
 */
@Slf4j
@Component
public class MessageDispatcher {
    
    @Autowired
    private ConnectionManager connectionManager;
    
    @Autowired
    private StringRedisTemplate redisTemplate;
    
    @Autowired
    private OfflineMessageStore offlineMessageStore;
    
    private final ObjectMapper objectMapper = new ObjectMapper();
    
    // Redis消息队列key
    private static final String MSG_QUEUE_PREFIX = "msg:queue:";
    
    // 消息ACK等待队列
    private static final String ACK_WAITING_QUEUE = "msg:ack:waiting";
    
    /**
     * 分发消息
     */
    public void dispatch(Message message) {
        switch (message.getType()) {
            case CHAT:
                dispatchPrivateMessage(message);
                break;
            case GROUP_CHAT:
                dispatchGroupMessage(message);
                break;
            case SYSTEM:
                dispatchSystemMessage(message);
                break;
            default:
                log.warn("无法分发的消息类型: {}", message.getType());
        }
    }
    
    /**
     * 分发私聊消息
     */
    private void dispatchPrivateMessage(Message message) {
        Long toUserId = message.getToUserId();
        
        // 1. 保存消息(用于多端同步和离线推送)
        saveMessage(message);
        
        // 2. 判断接收方是否在线
        if (connectionManager.isOnline(toUserId)) {
            // 在线推送
            List<Channel> channels = connectionManager.getUserChannels(toUserId);
            
            for (Channel channel : channels) {
                if (channel.isActive()) {
                    sendToChannel(channel, message);
                    
                    // 加入ACK等待队列
                    addAckWaiting(message.getMessageId(), toUserId);
                }
            }
            
            // 3. 推送通知到其他服务器(集群场景)
            publishToCluster(message);
            
        } else {
            // 离线存储
            offlineMessageStore.storeOfflineMessage(toUserId, message);
            
            // 触发离线推送(APNs、FCM等)
            triggerOfflinePush(toUserId, message);
        }
    }
    
    /**
     * 分发群聊消息
     */
    private void dispatchGroupMessage(Message message) {
        Long groupId = message.getGroupId();
        
        // 1. 获取群成员列表
        List<Long> memberIds = getGroupMembers(groupId);
        
        // 2. 保存消息
        saveMessage(message);
        
        // 3. 批量推送给在线成员
        for (Long memberId : memberIds) {
            if (!memberId.equals(message.getFromUserId())) {
                if (connectionManager.isOnline(memberId)) {
                    List<Channel> channels = connectionManager.getUserChannels(memberId);
                    channels.forEach(channel -> sendToChannel(channel, message));
                } else {
                    offlineMessageStore.storeOfflineMessage(memberId, message);
                }
            }
        }
    }
    
    /**
     * 分发系统消息
     */
    private void dispatchSystemMessage(Message message) {
        // 全员广播
        connectionManager.getAllChannels().forEach(channel -> {
            sendToChannel(channel, message);
        });
    }
    
    /**
     * 推送离线消息
     */
    public void pushOfflineMessages(Long userId, String deviceId) {
        List<Message> offlineMessages = offlineMessageStore.getOfflineMessages(userId);
        
        for (Message message : offlineMessages) {
            Channel channel = connectionManager.getUserChannel(userId, deviceId);
            if (channel != null && channel.isActive()) {
                sendToChannel(channel, message);
            }
        }
        
        log.info("推送离线消息: userId={}, count={}", userId, offlineMessages.size());
    }
    
    /**
     * 处理ACK确认
     */
    public void handleAck(Long messageId) {
        // 从等待队列移除
        String key = ACK_WAITING_QUEUE + ":" + messageId;
        redisTemplate.delete(key);
        
        // 更新消息状态
        updateMessageStatus(messageId, "DELIVERED");
        
        log.debug("消息ACK确认: messageId={}", messageId);
    }
    
    /**
     * 发送到Channel
     */
    private void sendToChannel(Channel channel, Message message) {
        try {
            String json = objectMapper.writeValueAsString(message);
            channel.writeAndFlush(new TextWebSocketFrame(json));
        } catch (Exception e) {
            log.error("发送消息失败: channelId={}", channel.id(), e);
        }
    }
    
    /**
     * 保存消息到数据库
     */
    private void saveMessage(Message message) {
        String key = "msg:store:" + message.getMessageId();
        try {
            String json = objectMapper.writeValueAsString(message);
            redisTemplate.opsForValue().set(key, json, 7, TimeUnit.DAYS);
        } catch (Exception e) {
            log.error("保存消息失败", e);
        }
    }
    
    /**
     * 添加到ACK等待队列
     */
    private void addAckWaiting(Long messageId, Long userId) {
        String key = ACK_WAITING_QUEUE + ":" + messageId;
        redisTemplate.opsForValue().set(key, userId.toString(), 1, TimeUnit.MINUTES);
    }
    
    /**
     * 发布到集群
     */
    private void publishToCluster(Message message) {
        try {
            String json = objectMapper.writeValueAsString(message);
            redisTemplate.convertAndSend("push:channel", json);
        } catch (Exception e) {
            log.error("发布到集群失败", e);
        }
    }
    
    /**
     * 获取群成员列表
     */
    private List<Long> getGroupMembers(Long groupId) {
        // 实际实现从数据库或缓存获取
        return List.of(1001L, 1002L, 1003L);
    }
    
    /**
     * 更新消息状态
     */
    private void updateMessageStatus(Long messageId, String status) {
        // 实际实现更新数据库
    }
    
    /**
     * 触发离线推送
     */
    private void triggerOfflinePush(Long userId, Message message) {
        // 调用APNs、FCM等推送服务
        log.info("触发离线推送: userId={}, messageId={}", userId, message.getMessageId());
    }
}

2.4 消息可靠性保证

消息ACK机制
package com.example.push.reliability;

import com.example.push.protocol.Message;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Tracks messages that were sent but not yet acknowledged and re-sends them
 * on a schedule until they are acknowledged or the retry budget is used up.
 */
@Slf4j
@Component
public class MessageReliabilityManager {
    
    // Maximum number of retries per message
    private static final int MAX_RETRY_COUNT = 3;
    
    // Minimum time between two sends of the same message (milliseconds)
    private static final long RETRY_INTERVAL_MS = 5000;
    
    // messageId -> message waiting for an ACK
    private final Map<Long, PendingMessage> pendingMessages = new ConcurrentHashMap<>();
    
    /**
     * Starts tracking a message that was just sent to {@code userId}.
     */
    public void addPendingMessage(Message message, Long userId) {
        PendingMessage entry = new PendingMessage();
        entry.setRetryCount(0);
        entry.setSendTime(System.currentTimeMillis());
        entry.setUserId(userId);
        entry.setMessage(message);
        
        pendingMessages.put(message.getMessageId(), entry);
    }
    
    /**
     * Stops tracking a message once its ACK has arrived.
     */
    public void acknowledgeMessage(Long messageId) {
        PendingMessage acknowledged = pendingMessages.remove(messageId);
        if (acknowledged == null) {
            return;
        }
        long elapsed = System.currentTimeMillis() - acknowledged.getSendTime();
        log.debug("消息已确认: messageId={}, 耗时={}ms", messageId, elapsed);
    }
    
    /**
     * Runs every 5 seconds: re-sends each message whose last send is at least
     * {@code RETRY_INTERVAL_MS} old, and drops messages that already used
     * {@code MAX_RETRY_COUNT} retries.
     */
    @Scheduled(fixedDelay = 5000)
    public void retryPendingMessages() {
        final long now = System.currentTimeMillis();
        
        for (Map.Entry<Long, PendingMessage> entry : pendingMessages.entrySet()) {
            Long messageId = entry.getKey();
            PendingMessage pending = entry.getValue();
            
            boolean due = now - pending.getSendTime() >= RETRY_INTERVAL_MS;
            if (!due) {
                continue;
            }
            
            int attempts = pending.getRetryCount();
            if (attempts >= MAX_RETRY_COUNT) {
                // Retry budget exhausted: give up on this message
                pendingMessages.remove(messageId);
                log.warn("消息重试失败,已丢弃: messageId={}, retryCount={}", 
                    messageId, pending.getRetryCount());
                continue;
            }
            
            // Re-send, then record the attempt
            retryMessage(pending);
            pending.setRetryCount(pending.getRetryCount() + 1);
            pending.setSendTime(now);
        }
    }
    
    /**
     * Re-sends one pending message.
     */
    private void retryMessage(PendingMessage pending) {
        // Placeholder: a real implementation re-sends the message here
        Long messageId = pending.getMessage().getMessageId();
        log.info("重传消息: messageId={}, retryCount={}", messageId, pending.getRetryCount());
    }
    
    /**
     * Returns how many messages are still waiting for an ACK.
     */
    public int getPendingCount() {
        return pendingMessages.size();
    }
    
    /**
     * A message waiting for its ACK.
     */
    @lombok.Data
    private static class PendingMessage {
        private Message message;
        private Long userId;
        private long sendTime;
        private int retryCount;
    }
}
离线消息存储
package com.example.push.store;

import com.example.push.protocol.Message;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;

/**
 * Redis-backed store for messages addressed to offline users.
 *
 * <p>Each user has one sorted set ({@code offline:msg:<userId>}) whose members
 * are message JSON strings scored by the message time in epoch milliseconds.
 * The set is capped at {@code MAX_OFFLINE_MESSAGES} entries (oldest removed
 * first) and its expiry is refreshed to {@code OFFLINE_EXPIRE_DAYS} days on
 * every write.
 */
@Slf4j
@Component
public class OfflineMessageStore {
    
    @Autowired
    private StringRedisTemplate redisTemplate;
    
    // findAndRegisterModules() picks up Jackson modules present on the classpath;
    // Message carries a LocalDateTime, which plain ObjectMapper cannot round-trip.
    // NOTE(review): requires jackson-datatype-jsr310 on the classpath — confirm.
    private final ObjectMapper objectMapper = new ObjectMapper().findAndRegisterModules();
    
    // Key prefix of the per-user offline sorted set
    private static final String OFFLINE_QUEUE_PREFIX = "offline:msg:";
    
    // Maximum number of offline messages kept per user
    private static final int MAX_OFFLINE_MESSAGES = 1000;
    
    // Lifetime of the offline set (days)
    private static final long OFFLINE_EXPIRE_DAYS = 7;
    
    /**
     * Stores one message for an offline user. Failures are logged, not thrown.
     *
     * @param userId  recipient
     * @param message message to store; a missing timestamp is scored with the
     *                current time instead of failing the write
     */
    public void storeOfflineMessage(Long userId, Message message) {
        String key = OFFLINE_QUEUE_PREFIX + userId;
        
        try {
            String json = objectMapper.writeValueAsString(message);
            
            // Sorted set member scored by message time
            redisTemplate.opsForZSet().add(key, json, scoreOf(message));
            
            // Cap the set: drop the lowest-ranked (oldest) overflow entries
            Long size = redisTemplate.opsForZSet().zCard(key);
            if (size != null && size > MAX_OFFLINE_MESSAGES) {
                redisTemplate.opsForZSet().removeRange(key, 0, size - MAX_OFFLINE_MESSAGES - 1);
            }
            
            // Refresh expiry
            redisTemplate.expire(key, OFFLINE_EXPIRE_DAYS, TimeUnit.DAYS);
            
            log.debug("存储离线消息: userId={}, messageId={}", userId, message.getMessageId());
            
        } catch (Exception e) {
            log.error("存储离线消息失败: userId={}", userId, e);
        }
    }
    
    /**
     * Sorted-set score of a message: its timestamp in epoch milliseconds
     * (system default zone), or the current time when it has no timestamp.
     */
    private double scoreOf(Message message) {
        java.time.LocalDateTime timestamp = message.getTimestamp();
        if (timestamp == null) {
            return System.currentTimeMillis();
        }
        return timestamp.atZone(java.time.ZoneId.systemDefault())
            .toInstant().toEpochMilli();
    }
    
    /**
     * Returns the stored offline messages of a user in score order.
     * An entry that cannot be parsed is skipped instead of aborting the read.
     *
     * @return the readable messages; empty on Redis failure or when none exist
     */
    public List<Message> getOfflineMessages(Long userId) {
        String key = OFFLINE_QUEUE_PREFIX + userId;
        List<Message> messages = new ArrayList<>();
        
        try {
            Set<String> jsonSet = redisTemplate.opsForZSet().range(key, 0, -1);
            
            if (jsonSet != null) {
                for (String json : jsonSet) {
                    Message message = parseOrNull(userId, json);
                    if (message != null) {
                        messages.add(message);
                    }
                }
            }
            
        } catch (Exception e) {
            log.error("获取离线消息失败: userId={}", userId, e);
        }
        
        return messages;
    }
    
    /**
     * Removes the stored messages whose ids are in {@code messageIds}.
     * All matches are removed with a single Redis call.
     *
     * @param userId     owner of the offline set
     * @param messageIds ids to remove; null or empty is a no-op
     */
    public void removeOfflineMessages(Long userId, List<Long> messageIds) {
        if (messageIds == null || messageIds.isEmpty()) {
            return;
        }
        
        String key = OFFLINE_QUEUE_PREFIX + userId;
        // Hash lookup instead of a linear List.contains per stored entry
        Set<Long> targets = new java.util.HashSet<>(messageIds);
        
        try {
            Set<String> jsonSet = redisTemplate.opsForZSet().range(key, 0, -1);
            if (jsonSet == null || jsonSet.isEmpty()) {
                return;
            }
            
            List<String> matched = new ArrayList<>();
            for (String json : jsonSet) {
                Message message = parseOrNull(userId, json);
                if (message != null && targets.contains(message.getMessageId())) {
                    matched.add(json);
                }
            }
            
            if (!matched.isEmpty()) {
                redisTemplate.opsForZSet().remove(key, matched.toArray());
            }
            
        } catch (Exception e) {
            log.error("删除离线消息失败: userId={}", userId, e);
        }
    }
    
    /**
     * Parses one stored entry; logs and returns null when it is unreadable.
     */
    private Message parseOrNull(Long userId, String json) {
        try {
            return objectMapper.readValue(json, Message.class);
        } catch (Exception e) {
            log.warn("离线消息解析失败,已跳过: userId={}", userId, e);
            return null;
        }
    }
    
    /**
     * Deletes the whole offline set of a user.
     */
    public void clearOfflineMessages(Long userId) {
        String key = OFFLINE_QUEUE_PREFIX + userId;
        redisTemplate.delete(key);
    }
    
    /**
     * Returns the number of stored offline messages of a user
     * (0 when Redis reports no count).
     */
    public long getOfflineMessageCount(Long userId) {
        String key = OFFLINE_QUEUE_PREFIX + userId;
        Long count = redisTemplate.opsForZSet().zCard(key);
        return count == null ? 0 : count;
    }
}

2.5 海量推送优化

批量推送优化
package com.example.push.batch;

import com.example.push.connection.ConnectionManager;
import com.example.push.protocol.Message;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.netty.channel.Channel;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.List;
import java.util.concurrent.*;

/**
 * Batch push manager for large fan-out pushes.
 *
 * <p>Tasks are queued with {@link #addPushTask}; a single dispatcher thread
 * (started by {@link #startBatchPush}) drains them in batches of up to
 * {@code BATCH_SIZE} and runs each task on a fixed thread pool.
 *
 * <p>{@link #close()} stops the dispatcher thread and the pool.
 * NOTE(review): Spring is expected to call {@code close()} on AutoCloseable
 * beans at context shutdown — confirm for the Spring version in use, or
 * invoke it from an explicit destroy hook.
 */
@Slf4j
@Component
public class BatchPushManager implements AutoCloseable {
    
    @Autowired
    private ConnectionManager connectionManager;
    
    // findAndRegisterModules() picks up Jackson modules present on the classpath;
    // Message carries a LocalDateTime, which plain ObjectMapper cannot serialize.
    // NOTE(review): requires jackson-datatype-jsr310 on the classpath — confirm.
    private final ObjectMapper objectMapper = new ObjectMapper().findAndRegisterModules();
    
    // Numbering for pool thread names ("push-pool-0", "push-pool-1", ...)
    private final java.util.concurrent.atomic.AtomicInteger poolThreadCounter =
        new java.util.concurrent.atomic.AtomicInteger();
    
    // Push worker pool; plain JDK thread factory, no third-party builder needed
    private final ExecutorService pushExecutor = Executors.newFixedThreadPool(
        Runtime.getRuntime().availableProcessors() * 2,
        task -> new Thread(task, "push-pool-" + poolThreadCounter.getAndIncrement())
    );
    
    // Bounded queue of pending push tasks
    private final BlockingQueue<PushTask> pushQueue = new LinkedBlockingQueue<>(100000);
    
    // Guards against starting the dispatcher thread more than once
    private final java.util.concurrent.atomic.AtomicBoolean started =
        new java.util.concurrent.atomic.AtomicBoolean(false);
    
    // Dispatcher thread, kept so close() can interrupt it
    private volatile Thread dispatcherThread;
    
    // Maximum number of tasks processed per batch
    private static final int BATCH_SIZE = 100;
    
    /**
     * Queues a push of {@code message} to {@code userIds}. The task is
     * dropped (with a warning) when the queue is full; calls with no
     * recipients or no message are ignored.
     */
    public void addPushTask(List<Long> userIds, Message message) {
        if (userIds == null || userIds.isEmpty() || message == null) {
            log.warn("忽略无效的推送任务: 接收用户或消息为空");
            return;
        }
        
        PushTask task = new PushTask();
        task.setUserIds(userIds);
        task.setMessage(message);
        task.setCreateTime(System.currentTimeMillis());
        
        if (!pushQueue.offer(task)) {
            log.warn("推送队列已满,丢弃任务");
        }
    }
    
    /**
     * Starts the dispatcher thread. Calling it again after the first call
     * has no effect.
     */
    public void startBatchPush() {
        if (!started.compareAndSet(false, true)) {
            log.warn("批量推送线程已启动,忽略重复调用");
            return;
        }
        
        Thread worker = new Thread(this::runDispatchLoop, "batch-push-thread");
        dispatcherThread = worker;
        worker.start();
    }
    
    /**
     * Dispatcher loop: waits up to 100 ms for a task, drains up to a full
     * batch and processes it; exits when the thread is interrupted.
     */
    private void runDispatchLoop() {
        while (!Thread.currentThread().isInterrupted()) {
            try {
                List<PushTask> batch = new java.util.ArrayList<>(BATCH_SIZE);
                
                // Take one task (bounded wait), then whatever else is queued
                PushTask task = pushQueue.poll(100, TimeUnit.MILLISECONDS);
                if (task != null) {
                    batch.add(task);
                    pushQueue.drainTo(batch, BATCH_SIZE - 1);
                }
                
                if (!batch.isEmpty()) {
                    processBatch(batch);
                }
                
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                break;
            } catch (Exception e) {
                log.error("批量推送异常", e);
            }
        }
    }
    
    /**
     * Runs every task of the batch on the pool and waits up to 10 seconds
     * for them; a timeout is reported as such instead of as completion.
     */
    private void processBatch(List<PushTask> batch) {
        CountDownLatch latch = new CountDownLatch(batch.size());
        
        for (PushTask task : batch) {
            pushExecutor.submit(() -> {
                try {
                    pushToUsers(task.getUserIds(), task.getMessage());
                } catch (Exception e) {
                    log.error("推送任务执行失败", e);
                } finally {
                    latch.countDown();
                }
            });
        }
        
        try {
            // Wait for all pushes (at most 10 seconds)
            boolean finished = latch.await(10, TimeUnit.SECONDS);
            
            long totalUsers = batch.stream()
                .mapToLong(t -> t.getUserIds().size())
                .sum();
            if (finished) {
                log.info("批量推送完成: 任务数={}, 用户数={}", batch.size(), totalUsers);
            } else {
                log.warn("批量推送超时: 任务数={}, 用户数={}, 未完成任务数={}",
                    batch.size(), totalUsers, latch.getCount());
            }
            
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
    
    /**
     * Serializes the message once and writes it to every active channel of
     * every listed user.
     */
    private void pushToUsers(List<Long> userIds, Message message) {
        String json;
        try {
            json = objectMapper.writeValueAsString(message);
        } catch (Exception e) {
            log.error("消息序列化失败", e);
            return;
        }
        
        TextWebSocketFrame frame = new TextWebSocketFrame(json);
        try {
            for (Long userId : userIds) {
                if (userId == null) {
                    continue;
                }
                List<Channel> channels = connectionManager.getUserChannels(userId);
                for (Channel channel : channels) {
                    if (channel.isActive()) {
                        // Each write gets its own retained view of the shared
                        // buffer; the channel releases it after writing.
                        channel.writeAndFlush(frame.retainedDuplicate());
                    }
                }
            }
        } finally {
            // Drop this method's own reference; without it the source frame leaks.
            frame.release();
        }
    }
    
    /**
     * Returns the number of queued push tasks.
     */
    public int getQueueSize() {
        return pushQueue.size();
    }
    
    /**
     * Stops the dispatcher thread and shuts the worker pool down.
     * Already submitted tasks are allowed to finish.
     */
    @Override
    public void close() {
        Thread worker = dispatcherThread;
        if (worker != null) {
            worker.interrupt();
        }
        pushExecutor.shutdown();
    }
    
    /**
     * One queued push: a message and the users it goes to.
     */
    @lombok.Data
    private static class PushTask {
        private List<Long> userIds;
        private Message message;
        private long createTime;
    }
}
推送限流
package com.example.push.limiter;

import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;

import java.util.concurrent.TimeUnit;

/**
 * Redis-backed rate limiter for the push pipeline.
 *
 * <p>Offers a per-user sliding-window limit on received messages and a system-wide
 * token bucket for outgoing pushes.
 *
 * <p>Both checks are built from several independent Redis commands and are therefore not
 * atomic: concurrent callers can pass the check at the same time and slightly exceed the
 * configured limit. If a hard guarantee is needed, move the logic into a single server-side
 * script.
 */
@Slf4j
@Component
public class PushRateLimiter {

    private final StringRedisTemplate redisTemplate;

    /**
     * @param redisTemplate template used for all limiter state
     */
    public PushRateLimiter(StringRedisTemplate redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    /**
     * Per-user receive limit based on a sliding window stored in a sorted set.
     *
     * <p>Every accepted message adds one entry scored with the current time in milliseconds;
     * entries older than the window are removed before counting.
     *
     * @param userId        receiving user
     * @param limit         maximum number of messages accepted inside the window
     * @param windowSeconds window length in seconds
     * @return {@code true} if the message may be delivered, {@code false} if the user has
     *         already reached {@code limit} inside the current window
     */
    public boolean allowReceive(Long userId, int limit, int windowSeconds) {
        String key = "push:limit:user:" + userId;
        long now = System.currentTimeMillis();
        long windowStart = now - windowSeconds * 1000L;

        // Drop entries that have fallen out of the window.
        redisTemplate.opsForZSet().removeRangeByScore(key, 0, windowStart);

        // Count what is left inside the window.
        Long count = redisTemplate.opsForZSet().zCard(key);

        if (count != null && count >= limit) {
            log.warn("用户接收限流: userId={}, count={}", userId, count);
            return false;
        }

        // Sorted-set members are unique. Using only the timestamp would collapse several
        // messages arriving in the same millisecond into one entry and under-count them,
        // so a random suffix is appended to make every receipt a distinct member.
        String member = now + ":" + java.util.UUID.randomUUID();
        redisTemplate.opsForZSet().add(key, member, now);
        // Let idle keys disappear once the window has passed.
        redisTemplate.expire(key, windowSeconds, TimeUnit.SECONDS);

        return true;
    }

    /**
     * System-wide push limit implemented as a token bucket with capacity {@code rate}.
     *
     * <p>Tokens are refilled lazily from the time elapsed since the stored refill timestamp.
     * The timestamp only advances by the time that was actually converted into whole tokens,
     * so partially accrued refill time is carried over to the next call instead of being lost.
     *
     * @param rate pushes allowed per second; also the bucket capacity
     * @return {@code true} if a token was taken, {@code false} if the bucket is empty or
     *         {@code rate} is not positive
     */
    public boolean allowSystemPush(int rate) {
        if (rate <= 0) {
            // A non-positive rate can never grant a token.
            return false;
        }

        String key = "push:limit:system";
        String tokensKey = key + ":tokens";
        String lastRefillKey = key + ":last_refill";

        long now = System.currentTimeMillis();

        String tokensStr = redisTemplate.opsForValue().get(tokensKey);
        String lastRefillStr = redisTemplate.opsForValue().get(lastRefillKey);

        long tokens = tokensStr == null ? rate : Integer.parseInt(tokensStr);
        long lastRefill = lastRefillStr == null ? now : Long.parseLong(lastRefillStr);

        // A refill timestamp in the future (clock skew between nodes) counts as no elapsed time.
        long elapsed = Math.max(0L, now - lastRefill);

        if (elapsed >= 1000L) {
            // A full second or more refills the whole bucket. Handling this case separately
            // also keeps the multiplication below far away from any overflow for stale keys.
            tokens = rate;
            lastRefill = now;
        } else {
            long tokensToAdd = elapsed * rate / 1000L;
            tokens += tokensToAdd;
            if (tokens >= rate) {
                tokens = rate;
                lastRefill = now;
            } else {
                // Advance only by the time consumed by whole tokens (rounded up so the
                // bucket never grants more than the configured rate).
                lastRefill += (tokensToAdd * 1000L + rate - 1) / rate;
            }
        }

        // Try to take one token.
        if (tokens > 0) {
            tokens--;
            redisTemplate.opsForValue().set(tokensKey, String.valueOf(tokens));
            redisTemplate.opsForValue().set(lastRefillKey, String.valueOf(lastRefill));
            return true;
        }

        return false;
    }
}

三、可运行Java代码示例

完整的推送服务实现

package com.example.push;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.scheduling.annotation.EnableScheduling;

/**
 * Entry point of the message push service.
 * Scheduling support is switched on via {@code @EnableScheduling}.
 */
@SpringBootApplication
@EnableScheduling
public class PushServiceApplication {

    /**
     * Boots the Spring application with this class as its primary source.
     *
     * @param args command-line arguments handed through to Spring Boot
     */
    public static void main(String[] args) {
        SpringApplication application = new SpringApplication(PushServiceApplication.class);
        application.run(args);
    }
}
package com.example.push.service;

import com.example.push.batch.BatchPushManager;
import com.example.push.connection.ConnectionManager;
import com.example.push.limiter.PushRateLimiter;
import com.example.push.protocol.Message;
import com.example.push.protocol.MessageType;
import com.example.push.reliability.MessageReliabilityManager;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.util.List;

/**
 * Push service API: entry point for private messages and batch pushes.
 */
@Slf4j
@Service
public class PushService {

    @Autowired
    private ConnectionManager connectionManager;

    @Autowired
    private BatchPushManager batchPushManager;

    @Autowired
    private PushRateLimiter rateLimiter;

    @Autowired
    private MessageReliabilityManager reliabilityManager;

    /**
     * Sends a private chat message to a single user.
     *
     * <p>The receiver is rate-limited to 100 messages per 60 seconds. When the receiver is
     * online the message is handed to each active channel and registered with the
     * reliability manager. When the receiver is offline the message is currently not
     * delivered, because offline storage is not implemented; this is logged as a warning
     * instead of being dropped silently.
     *
     * @param fromUserId sender
     * @param toUserId   receiver
     * @param content    message body
     * @throws RuntimeException if the receiver's rate limit is exceeded
     */
    public void sendPrivateMessage(Long fromUserId, Long toUserId, String content) {
        Message message = Message.createChatMessage(fromUserId, toUserId, content);

        // Rate-limit check on the receiving side.
        if (!rateLimiter.allowReceive(toUserId, 100, 60)) {
            log.warn("消息发送被限流: toUserId={}", toUserId);
            throw new RuntimeException("发送频率过高");
        }

        if (connectionManager.isOnline(toUserId)) {
            // Online: write to every active channel of the receiver.
            connectionManager.getUserChannels(toUserId).forEach(channel -> {
                if (channel.isActive()) {
                    sendToChannel(channel, message);
                }
            });

            // Register with the reliability manager.
            // NOTE(review): this happens after the write; confirm whether it needs to be
            // registered before sending to avoid racing a fast client response.
            reliabilityManager.addPendingMessage(message, toUserId);

        } else {
            // TODO: call the offline message store.
            // Until that exists, make the lost message visible in the logs.
            log.warn("接收方离线且离线存储尚未实现, 消息未投递: toUserId={}", toUserId);
        }
    }

    /**
     * Queues a system message for delivery to many users.
     *
     * <p>A {@code null} or empty recipient list is ignored, so no empty task reaches the
     * batch pusher. {@code title} is accepted for API compatibility but is currently not
     * copied into the message.
     *
     * @param userIds recipients
     * @param title   notification title (unused at the moment)
     * @param content message body
     */
    public void batchPush(List<Long> userIds, String title, String content) {
        if (userIds == null || userIds.isEmpty()) {
            log.debug("批量推送被忽略: 接收用户列表为空");
            return;
        }

        Message message = new Message();
        message.setMessageId(SnowflakeIdGenerator.nextId());
        message.setType(MessageType.SYSTEM);
        message.setContent(content);

        batchPushManager.addPushTask(userIds, message);
    }

    /**
     * Writes a message to one channel.
     *
     * <p>Placeholder: the body is intentionally empty in this example and must be filled in
     * with the real send logic.
     *
     * @param channel target channel
     * @param message message to send
     */
    private void sendToChannel(io.netty.channel.Channel channel, Message message) {
        // Actual send logic goes here.
    }
}
package com.example.push.controller;

import com.example.push.dto.PushRequest;
import com.example.push.service.PushService;
import org.springframework.web.bind.annotation.*;

/**
 * REST endpoints of the push service, mounted under {@code /api/push}.
 */
@RestController
@RequestMapping("/api/push")
public class PushController {

    private final PushService pushService;

    /**
     * @param pushService service that performs the actual pushes
     */
    public PushController(PushService pushService) {
        this.pushService = pushService;
    }

    /**
     * Sends a private chat message described by the request body.
     *
     * @param request carries sender id, receiver id and content
     * @return a success result without payload
     */
    @PostMapping("/private")
    public Result<Void> sendPrivate(@RequestBody PushRequest request) {
        final Long sender = request.getFromUserId();
        final Long receiver = request.getToUserId();
        final String body = request.getContent();

        pushService.sendPrivateMessage(sender, receiver, body);
        return Result.success();
    }

    /**
     * Queues a batch push described by the request body.
     *
     * @param request carries recipient ids, title and content
     * @return a success result without payload
     */
    @PostMapping("/batch")
    public Result<Void> batchPush(@RequestBody PushRequest request) {
        final List<Long> recipients = request.getUserIds();
        final String heading = request.getTitle();
        final String body = request.getContent();

        pushService.batchPush(recipients, heading, body);
        return Result.success();
    }
}

四、实战应用场景

场景一:即时通讯

/**
 * Instant-messaging service: persists chat messages and hands them to the push layer.
 */
@Service
public class ImService {

    @Autowired
    private PushService pushService;

    @Autowired
    private MessageStore messageStore;

    /**
     * Sends a one-to-one chat message.
     *
     * <p>The message is stored first, then pushed to the receiver, and finally the
     * conversation list is updated.
     *
     * @param fromUserId sender
     * @param toUserId   receiver
     * @param content    message body
     * @param msgType    message type code stored on the record
     * @return id of the stored message
     */
    public Long sendSingleChat(Long fromUserId, Long toUserId, String content, Integer msgType) {
        // Step 1: assemble the record.
        ImMessage chat = new ImMessage();
        chat.setMessageId(SnowflakeIdGenerator.nextId());
        chat.setFromUserId(fromUserId);
        chat.setToUserId(toUserId);
        chat.setContent(content);
        chat.setMsgType(msgType);
        chat.setCreateTime(LocalDateTime.now());
        chat.setStatus(0); // 0 = unread

        // Step 2: persist it.
        messageStore.saveMessage(chat);

        // Step 3: push to the receiver.
        pushService.sendPrivateMessage(fromUserId, toUserId, content);

        // Step 4: refresh the conversation list.
        updateConversation(fromUserId, toUserId, chat);

        return chat.getMessageId();
    }

    /**
     * Sends a message to a group.
     *
     * <p>The message is stored once and then pushed in batch to every group member except
     * the sender.
     *
     * @param fromUserId sender
     * @param groupId    target group
     * @param content    message body
     * @return id of the stored message
     */
    public Long sendGroupChat(Long fromUserId, Long groupId, String content) {
        // Step 1: look up the members of the group.
        List<Long> members = getGroupMembers(groupId);

        // Step 2: assemble and persist the record.
        ImMessage groupMessage = new ImMessage();
        groupMessage.setMessageId(SnowflakeIdGenerator.nextId());
        groupMessage.setFromUserId(fromUserId);
        groupMessage.setGroupId(groupId);
        groupMessage.setContent(content);
        groupMessage.setCreateTime(LocalDateTime.now());

        messageStore.saveMessage(groupMessage);

        // Step 3: everyone but the sender receives the push.
        List<Long> recipients = members.stream()
            .filter(memberId -> !memberId.equals(fromUserId))
            .collect(Collectors.toList());

        pushService.batchPush(recipients, null, content);

        return groupMessage.getMessageId();
    }
}

场景二:系统通知

/**
 * System notification service: stores a notification, pushes it through the push service
 * and falls back to mobile vendor channels when the user is not online.
 */
@Service
public class NotificationService {

    @Autowired
    private PushService pushService;

    @Autowired
    private UserDeviceService userDeviceService;

    /**
     * Creates, stores and delivers a system notification.
     *
     * @param userId  receiving user
     * @param title   notification title
     * @param content notification body
     * @param type    notification category
     */
    public void sendNotification(Long userId, String title, String content,
                                 NotificationType type) {
        Notification record = new Notification();
        record.setUserId(userId);
        record.setTitle(title);
        record.setContent(content);
        record.setType(type);
        record.setCreateTime(LocalDateTime.now());
        record.setReadStatus(0);

        // Persist before any delivery attempt.
        saveNotification(record);

        // Push through the push service; sender id 0 is used for system messages.
        pushService.sendPrivateMessage(0L, userId, content);

        // Users who are online need no vendor push.
        boolean online = isUserOnline(userId);
        if (online) {
            return;
        }

        // Offline: go through APNs / FCM.
        pushToMobile(userId, title, content);
    }

    /**
     * Sends the notification to every registered mobile device of the user.
     * iOS devices go through APNs, Android devices through FCM; other platforms are skipped.
     */
    private void pushToMobile(Long userId, String title, String content) {
        List<UserDevice> registered = userDeviceService.getUserDevices(userId);

        for (UserDevice target : registered) {
            Platform platform = target.getPlatform();

            if (platform == Platform.IOS) {
                apnsPush(target.getDeviceToken(), title, content);
                continue;
            }

            if (platform == Platform.ANDROID) {
                fcmPush(target.getDeviceToken(), title, content);
            }
        }
    }

    /**
     * APNs push (not implemented yet).
     */
    private void apnsPush(String deviceToken, String title, String content) {
        // TODO: implement APNs push
    }

    /**
     * FCM push (not implemented yet).
     */
    private void fcmPush(String deviceToken, String title, String content) {
        // TODO: implement FCM push
    }
}

五、总结与最佳实践

5.1 架构设计要点

| 设计维度 | 推荐方案 | 说明 |
| --- | --- | --- |
| 连接管理 | Netty + 自定义协议 | 高性能、易扩展 |
| 消息路由 | Redis Pub/Sub | 支持集群消息分发 |
| 可靠性保证 | ACK机制 + 重传 | 确保消息送达 |
| 离线存储 | Redis有序集合 | 按时间排序、自动过期 |
| 性能优化 | 批量推送 + 限流 | 提升吞吐量 |

5.2 性能优化建议

  1. 连接池优化:合理配置Netty线程池、连接参数
  2. 消息压缩:大消息体使用gzip压缩
  3. 批量操作:消息批量推送、批量存储
  4. 异步处理:消息处理异步化、非阻塞IO

5.3 可靠性保障

/**
 * Spring configuration for message delivery reliability.
 */
@Configuration
public class ReliabilityConfig {

    /**
     * Builds the retry template used for message retries: at most 3 attempts with
     * exponential back-off starting at 1000 ms, doubling each time, capped at 10000 ms.
     *
     * @return the configured retry template
     */
    @Bean
    public RetryTemplate retryTemplate() {
        // Back-off: 1000 ms initial interval, multiplier 2.0, capped at 10000 ms.
        ExponentialBackOffPolicy backOff = new ExponentialBackOffPolicy();
        backOff.setInitialInterval(1000);
        backOff.setMultiplier(2.0);
        backOff.setMaxInterval(10000);

        // Attempts: give up after the third try.
        SimpleRetryPolicy attempts = new SimpleRetryPolicy();
        attempts.setMaxAttempts(3);

        RetryTemplate retry = new RetryTemplate();
        retry.setBackOffPolicy(backOff);
        retry.setRetryPolicy(attempts);
        return retry;
    }
}

5.4 监控指标

  • 连接数:在线用户数、总连接数
  • 消息量:发送QPS、接收QPS
  • 延迟:消息送达延迟、平均延迟
  • 可靠性:ACK确认率、重传率

5.5 扩展性考虑

  1. 水平扩展:支持多服务器部署
  2. 消息队列:集成Kafka/RocketMQ削峰
  3. 多协议支持:WebSocket、TCP、MQTT
  4. 国际化:多时区、多语言支持

六、思考与练习

思考题

  1. 基础题:WebSocket与HTTP长轮询相比有哪些优势?在什么场景下应该选择WebSocket?
  2. 进阶题:如何保证消息的可靠送达?ACK机制和重传策略如何配合?
  3. 实战题:设计一个支持百万级同时在线的IM系统架构,包括连接管理、消息路由、集群扩展。

编程练习

练习:实现一个简易即时通讯系统,要求:

  1. 基于Netty实现WebSocket服务端
  2. 支持单聊、群聊消息
  3. 实现消息ACK机制和离线消息存储
  4. 支持心跳检测和断线重连

章节关联

  • 前置章节:《短链系统设计详解》
  • 后续章节:《订单系统设计详解》
  • 扩展阅读:《Netty实战》、WebSocket RFC 6455规范

📝 下一章预告

下一章将学习订单系统的设计,深入探讨订单状态机、幂等设计、分布式事务、库存扣减等电商核心问题。


本章完