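Message Push Systems in Depth
Chapter Overview
A message push system is core infrastructure for modern Internet applications. It delivers messages to users in real time and reliably. This chapter examines long-lived connection management, message routing strategies, delivery guarantees, and optimizations for pushing at massive scale, giving you the essentials of designing IM systems and push notification systems.
Learning objectives:
- Objective 1: Master WebSocket long-lived connection management and message protocol design
- Objective 2: Understand the core principles of message routing, the ACK mechanism, and offline storage
- Objective 3: Be able to design a highly concurrent, highly available message push service
Prerequisites: familiarity with the Netty framework, an understanding of the TCP and WebSocket protocols, and basic Redis knowledge
Reading time: about 55 minutes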
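I. Overview
Instant messaging, social feeds, system notifications, and marketing campaigns all depend on a message push system to get messages to users in real time and reliably.
This chapter walks through the key design points of such a system, including long-lived connection management, message routing strategies, delivery guarantees, and optimizations for massive-scale push, and provides a reference implementation in Java.
Core challenges of a message push system
- Real-time delivery: messages should arrive within milliseconds
- Reliability: no message may be lost
- High concurrency: a massive number of users can be online at the same time
- Multi-device sync: messages stay in sync across mobile, web, and desktop clients
- Offline push: messages still reach users who are offline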
II. Key Concepts in Detail
2.1 Long-Lived Connection Management
WebSocket connection management
package com.example.push.connection;
import io.netty.channel.Channel;
import io.netty.channel.ChannelId;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
/**
* WebSocket connection manager
* Tracks the long-lived connections of all clients on this node
*/
@Component
public class ConnectionManager {
// Group of all logged-in channels (used for broadcast); closed channels are removed automatically
private final ChannelGroup channels = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
// userId -> (deviceId -> Channel), supports login from multiple devices
private final ConcurrentHashMap<Long, ConcurrentHashMap<String, Channel>> userChannels =
new ConcurrentHashMap<>();
// ChannelId -> user info
private final ConcurrentHashMap<ChannelId, UserInfo> channelUsers =
new ConcurrentHashMap<>();
/**
* Register a connection
* @param userId user ID
* @param deviceId device ID
* @param channel WebSocket channel
*/
public void addConnection(Long userId, String deviceId, Channel channel) {
// Add to the global channel group
channels.add(channel);
// Channel -> user mapping
channelUsers.put(channel.id(), new UserInfo(userId, deviceId));
// User -> device -> channel mapping; compute() keeps this atomic with respect to removeConnection
Channel[] stale = new Channel[1];
userChannels.compute(userId, (k, devices) -> {
ConcurrentHashMap<String, Channel> map = devices != null ? devices : new ConcurrentHashMap<>();
stale[0] = map.put(deviceId, channel);
return map;
});
// The same device logged in again: close the stale connection
if (stale[0] != null && stale[0] != channel) {
stale[0].close();
}
System.out.println(String.format(
"User online: userId=%d, deviceId=%s, channelId=%s, online users=%d",
userId, deviceId, channel.id(), userChannels.size()
));
}
/**
* Remove a connection
* @param channel WebSocket channel
*/
public void removeConnection(Channel channel) {
UserInfo userInfo = channelUsers.remove(channel.id());
if (userInfo != null) {
Long userId = userInfo.getUserId();
String deviceId = userInfo.getDeviceId();
// Remove the mapping only if it still points at this channel, so a stale channel that closes
// late cannot evict a newer connection of the same device; drop the user once no devices remain
userChannels.computeIfPresent(userId, (k, devices) -> {
devices.remove(deviceId, channel);
return devices.isEmpty() ? null : devices;
});
System.out.println(String.format(
"User offline: userId=%d, deviceId=%s, online users=%d",
userId, deviceId, userChannels.size()
));
}
channels.remove(channel);
}
/**
* Get all connections of a user (multiple devices)
*/
public List<Channel> getUserChannels(Long userId) {
ConcurrentHashMap<String, Channel> deviceMap = userChannels.get(userId);
if (deviceMap == null) {
return Collections.emptyList();
}
return new ArrayList<>(deviceMap.values());
}
/**
* Get the connection of one of the user's devices
*/
public Channel getUserChannel(Long userId, String deviceId) {
ConcurrentHashMap<String, Channel> deviceMap = userChannels.get(userId);
return deviceMap == null ? null : deviceMap.get(deviceId);
}
/**
* Whether the user is online (on this node)
*/
public boolean isOnline(Long userId) {
ConcurrentHashMap<String, Channel> deviceMap = userChannels.get(userId);
return deviceMap != null && !deviceMap.isEmpty();
}
/**
* Number of online users
*/
public int getOnlineUserCount() {
return userChannels.size();
}
/**
* Total number of connections
*/
public int getTotalConnectionCount() {
return channels.size();
}
/**
* All connections (for broadcast)
*/
public ChannelGroup getAllChannels() {
return channels;
}
/**
* User info (inner class)
*/
@lombok.Data
@lombok.AllArgsConstructor
private static class UserInfo {
private Long userId;
private String deviceId;
}
}
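Why the removal should be conditional: when the same device reconnects, the new channel can be registered before the old channel's close callback runs. The pure-JDK sketch below models the userId → deviceId → connection map, with plain `Object` handles standing in for Netty channels (an assumption made so it runs without Netty). `DeviceRegistry` and its methods are hypothetical names.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** Hypothetical model of the two-level connection map; Object stands in for a Channel. */
class DeviceRegistry {
    private final Map<Long, Map<String, Object>> userDevices = new ConcurrentHashMap<>();

    /** Registers a connection and returns the stale one for the same device, if any. */
    Object register(long userId, String deviceId, Object conn) {
        Object[] stale = new Object[1];
        userDevices.compute(userId, (k, devices) -> {
            Map<String, Object> map = devices != null ? devices : new ConcurrentHashMap<>();
            stale[0] = map.put(deviceId, conn);
            return map;
        });
        return stale[0];
    }

    /** Removes the mapping only if it still points at conn; drops the user once no devices remain. */
    void unregister(long userId, String deviceId, Object conn) {
        userDevices.computeIfPresent(userId, (k, devices) -> {
            devices.remove(deviceId, conn);
            return devices.isEmpty() ? null : devices;
        });
    }

    Object lookup(long userId, String deviceId) {
        Map<String, Object> devices = userDevices.get(userId);
        return devices == null ? null : devices.get(deviceId);
    }

    boolean isOnline(long userId) {
        return userDevices.containsKey(userId);
    }
}
```

With an unconditional `remove(deviceId)`, the late close of the old handle would evict the new one and the user would look offline while still connected.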
Netty WebSocket server
package com.example.push.server;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.handler.timeout.IdleStateHandler;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import java.util.concurrent.TimeUnit;
/**
* WebSocket server
*/
@Component
public class WebSocketServer {
@Value("${websocket.port:8088}")
private int port;
// Inject the Spring-managed handler: creating it with `new` would leave its @Autowired fields null.
// One instance can be added to every pipeline because the handler is @Sharable
private final WebSocketMessageHandler messageHandler;
private EventLoopGroup bossGroup;
private EventLoopGroup workerGroup;
private Channel serverChannel;
public WebSocketServer(WebSocketMessageHandler messageHandler) {
this.messageHandler = messageHandler;
}
@PostConstruct
public void start() throws InterruptedException {
// One boss thread accepts connections; worker threads (default 2 x CPU cores) handle I/O
bossGroup = new NioEventLoopGroup(1);
workerGroup = new NioEventLoopGroup();
ServerBootstrap bootstrap = new ServerBootstrap();
bootstrap.group(bossGroup, workerGroup)
.channel(NioServerSocketChannel.class)
.option(ChannelOption.SO_BACKLOG, 1024)
.childOption(ChannelOption.SO_KEEPALIVE, true)
.childOption(ChannelOption.TCP_NODELAY, true)
.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) {
ChannelPipeline pipeline = ch.pipeline();
// HTTP codec (the WebSocket handshake is an HTTP upgrade request)
pipeline.addLast("http-codec", new HttpServerCodec());
// Chunked write support
pipeline.addLast("http-chunked", new ChunkedWriteHandler());
// Aggregate HTTP messages into a FullHttpRequest (max 64 KB)
pipeline.addLast("aggregator", new HttpObjectAggregator(65536));
// Heartbeat detection: fire an IdleStateEvent if nothing is read for 60 s
pipeline.addLast("idle-state", new IdleStateHandler(
60, 0, 0, TimeUnit.SECONDS));
// WebSocket protocol handler: handshake on /ws, ping/pong and close frames
pipeline.addLast("webSocket-protocol",
new WebSocketServerProtocolHandler("/ws"));
// Custom business handler (shared singleton)
pipeline.addLast("message-handler", messageHandler);
}
});
// Bind the port
serverChannel = bootstrap.bind(port).sync().channel();
System.out.println("WebSocket server started on port " + port);
}
@PreDestroy
public void stop() {
if (serverChannel != null) {
serverChannel.close();
}
if (bossGroup != null) {
bossGroup.shutdownGracefully();
}
if (workerGroup != null) {
workerGroup.shutdownGracefully();
}
System.out.println("WebSocket server stopped");
}
}
Message handler
package com.example.push.server;
import com.example.push.connection.ConnectionManager;
import com.example.push.protocol.Message;
import com.example.push.protocol.MessageType;
import com.example.push.router.MessageDispatcher;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.timeout.IdleStateEvent;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
/**
* WebSocket message handler
* @Sharable: one Spring-managed instance serves all channels, so it must hold no per-connection state
*/
@Slf4j
@ChannelHandler.Sharable
@Component
public class WebSocketMessageHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {
@Autowired
private ConnectionManager connectionManager;
@Autowired
private MessageDispatcher messageDispatcher;
// findAndRegisterModules() registers jackson-datatype-jsr310 (must be on the classpath) for LocalDateTime
private final ObjectMapper objectMapper = new ObjectMapper().findAndRegisterModules();
/**
* User events: handshake completion and heartbeat timeout
*/
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
// WebSocket handshake completed; the client still has to send LOGIN
log.info("WebSocket connection established: {}", ctx.channel().id());
} else if (evt instanceof IdleStateEvent) {
// Nothing (not even a heartbeat) received within the idle timeout
log.warn("Heartbeat timeout, closing connection: {}", ctx.channel().id());
ctx.close();
} else {
// Pass other events down the pipeline
super.userEventTriggered(ctx, evt);
}
}
/**
* Receive a message
*/
@Override
protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame frame) throws Exception {
String text = frame.text();
log.debug("Received message: {}", text);
try {
Message message = objectMapper.readValue(text, Message.class);
switch (message.getType()) {
case LOGIN:
handleLogin(ctx, message);
break;
case HEARTBEAT:
handleHeartbeat(ctx, message);
break;
case CHAT:
handleChat(ctx, message);
break;
case ACK:
handleAck(ctx, message);
break;
default:
log.warn("Unknown message type: {}", message.getType());
}
} catch (Exception e) {
log.error("Failed to handle message", e);
sendError(ctx, "Malformed message");
}
}
/**
* Handle login
*/
private void handleLogin(ChannelHandlerContext ctx, Message message) {
// Note: userId is taken from the client as-is; production code must authenticate first (e.g. verify a token)
Long userId = message.getFromUserId();
String deviceId = message.getDeviceId();
// Register the connection mapping
connectionManager.addConnection(userId, deviceId, ctx.channel());
// Send the login success response
Message response = new Message();
response.setType(MessageType.LOGIN_ACK);
response.setContent("Login succeeded");
sendMessage(ctx, response);
// Push messages received while offline
messageDispatcher.pushOfflineMessages(userId, deviceId);
}
/**
* Handle heartbeat
*/
private void handleHeartbeat(ChannelHandlerContext ctx, Message message) {
Message response = new Message();
response.setType(MessageType.HEARTBEAT_ACK);
sendMessage(ctx, response);
}
/**
* Handle chat message
*/
private void handleChat(ChannelHandlerContext ctx, Message message) {
// Dispatch the message
messageDispatcher.dispatch(message);
}
/**
* Handle ACK
*/
private void handleAck(ChannelHandlerContext ctx, Message message) {
Long messageId = message.getMessageId();
messageDispatcher.handleAck(messageId);
}
/**
* Connection closed
*/
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
connectionManager.removeConnection(ctx.channel());
super.channelInactive(ctx);
}
/**
* Exception handling
*/
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
log.error("Connection error", cause);
ctx.close();
}
/**
* Send a message
*/
private void sendMessage(ChannelHandlerContext ctx, Message message) {
try {
String json = objectMapper.writeValueAsString(message);
ctx.writeAndFlush(new TextWebSocketFrame(json));
} catch (Exception e) {
log.error("Failed to send message", e);
}
}
/**
* Send an error message
*/
private void sendError(ChannelHandlerContext ctx, String error) {
Message message = new Message();
message.setType(MessageType.ERROR);
message.setContent(error);
sendMessage(ctx, message);
}
}
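The server closes a connection after 60 s without reads, so clients must send HEARTBEAT well inside that window (every 25 to 30 s, for instance, tolerates one lost heartbeat) and reconnect after a drop. The chapter does not specify the client; a common approach, sketched here as an assumption, is capped exponential backoff with "full jitter" so that thousands of clients do not reconnect in lockstep after a server restart. `ReconnectBackoff` is hypothetical, and the random source is injected to make the delays testable.

```java
import java.util.Random;

/** Hypothetical client-side reconnect policy: capped exponential backoff with full jitter. */
class ReconnectBackoff {
    private final long baseMillis;
    private final long maxMillis;
    private final Random random;
    private int attempt = 0;

    ReconnectBackoff(long baseMillis, long maxMillis, Random random) {
        this.baseMillis = baseMillis;
        this.maxMillis = maxMillis;
        this.random = random;
    }

    /** Upper bound for attempt n: base * 2^n, capped at maxMillis. */
    long ceilingFor(int n) {
        // Limit the shift so the value cannot overflow before the cap is applied
        long exp = baseMillis << Math.min(n, 30);
        return exp <= 0 ? maxMillis : Math.min(maxMillis, exp);
    }

    /** Delay before the next reconnect attempt, uniform in [0, ceiling]. */
    long nextDelayMillis() {
        long ceiling = ceilingFor(attempt);
        attempt++;
        // Full jitter spreads the reconnects of many clients over the whole window
        return (long) (random.nextDouble() * (ceiling + 1));
    }

    /** Call after LOGIN_ACK so the next outage starts again from the base delay. */
    void reset() {
        attempt = 0;
    }
}
```

A client would sleep for `nextDelayMillis()` after each failed connect or LOGIN, and call `reset()` once the server answers with LOGIN_ACK.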
2.2 Message Protocol Design
package com.example.push.protocol;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.AllArgsConstructor;
import java.time.LocalDateTime;
/**
* Unified message protocol
*/
@Data
@NoArgsConstructor
@AllArgsConstructor
public class Message {
/**
* Message ID (globally unique)
*/
private Long messageId;
/**
* Message type
*/
private MessageType type;
/**
* Sender user ID
*/
private Long fromUserId;
/**
* Recipient user ID (private chat)
*/
private Long toUserId;
/**
* Group ID (group chat)
*/
private Long groupId;
/**
* Device ID
*/
private String deviceId;
/**
* Message content
*/
private String content;
/**
* Extension field (JSON)
*/
private String extra;
/**
* Message timestamp
*/
private LocalDateTime timestamp;
/**
* Message status
*/
private MessageStatus status;
/**
* Create a chat message
*/
public static Message createChatMessage(Long fromUserId, Long toUserId, String content) {
Message message = new Message();
message.setMessageId(SnowflakeIdGenerator.nextId());
message.setType(MessageType.CHAT);
message.setFromUserId(fromUserId);
message.setToUserId(toUserId);
message.setContent(content);
message.setTimestamp(LocalDateTime.now());
message.setStatus(MessageStatus.SENDING);
return message;
}
/**
* Create an ACK message
*/
public static Message createAck(Long messageId) {
Message message = new Message();
message.setType(MessageType.ACK);
message.setMessageId(messageId);
message.setTimestamp(LocalDateTime.now());
return message;
}
}
// ---- MessageType.java (each public top-level type lives in its own file) ----
/**
* Message type enum
*/
public enum MessageType {
LOGIN, // login
LOGIN_ACK, // login response
HEARTBEAT, // heartbeat
HEARTBEAT_ACK, // heartbeat response
CHAT, // private chat message
GROUP_CHAT, // group chat message
SYSTEM, // system message
ACK, // acknowledgement
ERROR // error message
}
// ---- MessageStatus.java ----
/**
* Message status enum
*/
public enum MessageStatus {
SENDING, // being sent
DELIVERED, // delivered
READ, // read
FAILED // failed to send
}
// ---- SnowflakeIdGenerator.java ----
/**
* Snowflake ID generator: 41-bit timestamp | 10-bit worker ID | 12-bit sequence
*/
public class SnowflakeIdGenerator {
// Custom epoch: 2024-01-01 00:00:00 UTC+8
private static final long twepoch = 1704038400000L;
private static final long workerIdBits = 10L;
private static final long sequenceBits = 12L;
private static final long maxWorkerId = ~(-1L << workerIdBits);
private static final long sequenceMask = ~(-1L << sequenceBits);
// The worker ID must be unique per node, otherwise nodes generate colliding IDs.
// Reading it from the system property "push.workerId" is an assumption of this example
private static final long workerId = Long.getLong("push.workerId", 0L) & maxWorkerId;
private static long sequence = 0L;
private static long lastTimestamp = -1L;
public static synchronized long nextId() {
long timestamp = System.currentTimeMillis();
if (timestamp < lastTimestamp) {
// Clock moved backwards: refuse rather than risk duplicate IDs
throw new IllegalStateException("Clock moved backwards");
}
if (timestamp == lastTimestamp) {
// Same millisecond: bump the sequence; on overflow, spin until the next millisecond
sequence = (sequence + 1) & sequenceMask;
if (sequence == 0) {
timestamp = tilNextMillis(lastTimestamp);
}
} else {
sequence = 0L;
}
lastTimestamp = timestamp;
return ((timestamp - twepoch) << (workerIdBits + sequenceBits))
| (workerId << sequenceBits)
| sequence;
}
private static long tilNextMillis(long lastTimestamp) {
long timestamp = System.currentTimeMillis();
while (timestamp <= lastTimestamp) {
timestamp = System.currentTimeMillis();
}
return timestamp;
}
}
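To see how a Snowflake-style ID packs its fields, the sketch below composes and decomposes an ID using the widely used layout of a 41-bit timestamp, a 10-bit worker ID, and a 12-bit sequence. The layout choice and the class name `SnowflakeLayout` are assumptions for illustration; the epoch reuses the value 1704038400000L from the generator above.

```java
/** Hypothetical helper illustrating the 41/10/12 Snowflake bit layout. */
final class SnowflakeLayout {
    static final long EPOCH = 1704038400000L;                   // 2024-01-01 00:00:00 UTC+8
    static final int WORKER_BITS = 10;
    static final int SEQUENCE_BITS = 12;
    static final long MAX_WORKER = (1L << WORKER_BITS) - 1;     // 1023
    static final long MAX_SEQUENCE = (1L << SEQUENCE_BITS) - 1; // 4095

    private SnowflakeLayout() {}

    /** Packs (timestamp, worker, sequence) into one 64-bit ID. */
    static long compose(long timestampMillis, long workerId, long sequence) {
        if (workerId < 0 || workerId > MAX_WORKER) throw new IllegalArgumentException("workerId");
        if (sequence < 0 || sequence > MAX_SEQUENCE) throw new IllegalArgumentException("sequence");
        return ((timestampMillis - EPOCH) << (WORKER_BITS + SEQUENCE_BITS))
                | (workerId << SEQUENCE_BITS)
                | sequence;
    }

    static long timestampOf(long id) {
        return (id >>> (WORKER_BITS + SEQUENCE_BITS)) + EPOCH;
    }

    static long workerOf(long id) {
        return (id >>> SEQUENCE_BITS) & MAX_WORKER;
    }

    static long sequenceOf(long id) {
        return id & MAX_SEQUENCE;
    }
}
```

Because the timestamp occupies the high bits, an ID minted in a later millisecond always compares greater, whatever its worker and sequence, which is why such IDs can double as a coarse ordering key. Forty-one bits of milliseconds last roughly 69 years past the epoch, and each worker can mint 4096 IDs per millisecond.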
2.3 Message Routing Strategy
package com.example.push.router;
import com.example.push.connection.ConnectionManager;
import com.example.push.protocol.Message;
import com.example.push.protocol.MessageStatus;
import com.example.push.store.OfflineMessageStore;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.netty.channel.Channel;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
/**
* Message dispatcher
*/
@Slf4j
@Component
public class MessageDispatcher {
@Autowired
private ConnectionManager connectionManager;
@Autowired
private StringRedisTemplate redisTemplate;
@Autowired
private OfflineMessageStore offlineMessageStore;
// findAndRegisterModules() registers jackson-datatype-jsr310 (must be on the classpath) for LocalDateTime
private final ObjectMapper objectMapper = new ObjectMapper().findAndRegisterModules();
// Redis message queue key prefix
private static final String MSG_QUEUE_PREFIX = "msg:queue:";
// Key prefix of the ACK waiting list
private static final String ACK_WAITING_QUEUE = "msg:ack:waiting";
// Global route key prefix (hypothetical: each node sets route:user:{userId} on login, deletes it on logout)
private static final String ROUTE_PREFIX = "route:user:";
/**
* Dispatch a message
*/
public void dispatch(Message message) {
switch (message.getType()) {
case CHAT:
dispatchPrivateMessage(message);
break;
case GROUP_CHAT:
dispatchGroupMessage(message);
break;
case SYSTEM:
dispatchSystemMessage(message);
break;
default:
log.warn("Cannot dispatch message type: {}", message.getType());
}
}
/**
* Dispatch a private (one-to-one) message
*/
private void dispatchPrivateMessage(Message message) {
Long toUserId = message.getToUserId();
// 1. Persist the message (for multi-device sync and offline delivery)
saveMessage(message);
if (connectionManager.isOnline(toUserId)) {
// 2a. Receiver is connected to this node: push to every active device
boolean sent = false;
for (Channel channel : connectionManager.getUserChannels(toUserId)) {
if (channel.isActive()) {
sendToChannel(channel, message);
sent = true;
}
}
if (sent) {
// Wait for the client ACK (one entry per message)
addAckWaiting(message.getMessageId(), toUserId);
} else {
offlineMessageStore.storeOfflineMessage(toUserId, message);
}
} else if (isOnlineInCluster(toUserId)) {
// 2b. Receiver is connected to another node: forward through the cluster channel
publishToCluster(message);
} else {
// 2c. Receiver is offline: store the message and trigger a vendor push (APNs, FCM, etc.)
offlineMessageStore.storeOfflineMessage(toUserId, message);
triggerOfflinePush(toUserId, message);
}
}
/**
* Dispatch a group message
*/
private void dispatchGroupMessage(Message message) {
Long groupId = message.getGroupId();
// 1. Load the group members
List<Long> memberIds = getGroupMembers(groupId);
// 2. Persist the message
saveMessage(message);
// 3. Push to online members (except the sender); store for offline members
for (Long memberId : memberIds) {
if (!memberId.equals(message.getFromUserId())) {
if (connectionManager.isOnline(memberId)) {
List<Channel> channels = connectionManager.getUserChannels(memberId);
channels.forEach(channel -> sendToChannel(channel, message));
} else {
offlineMessageStore.storeOfflineMessage(memberId, message);
}
}
}
}
/**
* Dispatch a system message
*/
private void dispatchSystemMessage(Message message) {
// Broadcast to everyone: serialize once; ChannelGroup.writeAndFlush duplicates the frame per channel
try {
String json = objectMapper.writeValueAsString(message);
connectionManager.getAllChannels().writeAndFlush(new TextWebSocketFrame(json));
} catch (Exception e) {
log.error("Broadcast failed", e);
}
}
/**
* Push offline messages to a device that just logged in
*/
public void pushOfflineMessages(Long userId, String deviceId) {
Channel channel = connectionManager.getUserChannel(userId, deviceId);
if (channel == null || !channel.isActive()) {
return;
}
List<Message> offlineMessages = offlineMessageStore.getOfflineMessages(userId);
List<Long> pushedIds = new ArrayList<>();
for (Message message : offlineMessages) {
sendToChannel(channel, message);
pushedIds.add(message.getMessageId());
}
// Simplification: delete once written. A stricter design deletes only after the client ACKs,
// and multi-device sync usually relies on a per-device sync cursor instead
offlineMessageStore.removeOfflineMessages(userId, pushedIds);
log.info("Pushed offline messages: userId={}, count={}", userId, pushedIds.size());
}
/**
* Handle an ACK
*/
public void handleAck(Long messageId) {
// Remove from the waiting list
String key = ACK_WAITING_QUEUE + ":" + messageId;
redisTemplate.delete(key);
// Update the message status
updateMessageStatus(messageId, MessageStatus.DELIVERED);
log.debug("Message acknowledged: messageId={}", messageId);
}
/**
* Send to a channel
*/
private void sendToChannel(Channel channel, Message message) {
try {
String json = objectMapper.writeValueAsString(message);
channel.writeAndFlush(new TextWebSocketFrame(json));
} catch (Exception e) {
log.error("Failed to send message: channelId={}", channel.id(), e);
}
}
/**
* Persist the message (Redis with a 7-day TTL here; a database in production)
*/
private void saveMessage(Message message) {
String key = "msg:store:" + message.getMessageId();
try {
String json = objectMapper.writeValueAsString(message);
redisTemplate.opsForValue().set(key, json, 7, TimeUnit.DAYS);
} catch (Exception e) {
log.error("Failed to save message", e);
}
}
/**
* Add to the ACK waiting list (entry expires after 1 minute)
*/
private void addAckWaiting(Long messageId, Long userId) {
String key = ACK_WAITING_QUEUE + ":" + messageId;
redisTemplate.opsForValue().set(key, userId.toString(), 1, TimeUnit.MINUTES);
}
/**
* Whether the user is connected to any node of the cluster (hypothetical route key, see ROUTE_PREFIX)
*/
private boolean isOnlineInCluster(Long userId) {
return Boolean.TRUE.equals(redisTemplate.hasKey(ROUTE_PREFIX + userId));
}
/**
* Publish to the cluster: every node subscribes to push:channel and delivers to its local connections
*/
private void publishToCluster(Message message) {
try {
String json = objectMapper.writeValueAsString(message);
redisTemplate.convertAndSend("push:channel", json);
} catch (Exception e) {
log.error("Failed to publish to cluster", e);
}
}
/**
* Load the group member list
*/
private List<Long> getGroupMembers(Long groupId) {
// A real implementation reads from a database or cache; hard-coded for the example
return List.of(1001L, 1002L, 1003L);
}
/**
* Update the message status
*/
private void updateMessageStatus(Long messageId, MessageStatus status) {
// A real implementation updates the database
}
/**
* Trigger an offline push
*/
private void triggerOfflinePush(Long userId, Message message) {
// Call a vendor push service such as APNs or FCM
log.info("Triggered offline push: userId={}, messageId={}", userId, message.getMessageId());
}
}
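In a cluster, `connectionManager.isOnline` only sees connections on the current node, so routing needs a global view of which nodes hold each user. A common pattern, sketched here as an assumption, is a route table (userId → set of node IDs) that each node updates on login and logout. In production it would live in Redis; this self-contained sketch uses an in-memory map. `RouteTable` and `Route` are hypothetical names.

```java
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/** Hypothetical global route table: userId -> IDs of the nodes holding that user's connections. */
class RouteTable {
    enum Route { LOCAL, REMOTE, OFFLINE }

    private final Map<Long, Set<String>> userNodes = new ConcurrentHashMap<>();

    /** Called by a node after a successful LOGIN. */
    void online(long userId, String nodeId) {
        userNodes.compute(userId, (k, nodes) -> {
            Set<String> set = nodes != null ? nodes : ConcurrentHashMap.newKeySet();
            set.add(nodeId);
            return set;
        });
    }

    /** Called by a node when its last connection for the user closes. */
    void offline(long userId, String nodeId) {
        userNodes.computeIfPresent(userId, (k, nodes) -> {
            nodes.remove(nodeId);
            return nodes.isEmpty() ? null : nodes;
        });
    }

    /** How the node selfNodeId should deliver a message addressed to userId. */
    Route route(long userId, String selfNodeId) {
        Set<String> nodes = userNodes.get(userId);
        if (nodes == null || nodes.isEmpty()) {
            return Route.OFFLINE;   // store offline + vendor push (APNs/FCM)
        }
        return nodes.contains(selfNodeId) ? Route.LOCAL : Route.REMOTE;
    }

    /** Other nodes that must also receive the message (users with devices on several nodes). */
    Set<String> remoteNodes(long userId, String selfNodeId) {
        Set<String> result = new HashSet<>(userNodes.getOrDefault(userId, Set.of()));
        result.remove(selfNodeId);
        return result;
    }
}
```

With Redis Pub/Sub on one shared channel, as in `publishToCluster`, every node receives every forwarded message; one option is to publish to per-node channels named after the entries of `remoteNodes`, so only the nodes that actually hold the user get the traffic.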
2.4 Message Reliability Guarantees
Message ACK mechanism
package com.example.push.reliability;
import com.example.push.protocol.Message;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* Message reliability manager
* Implements acknowledgement and retransmission of messages
*/
@Slf4j
@Component
public class MessageReliabilityManager {
// Messages waiting for an ACK, keyed by messageId
private final Map<Long, PendingMessage> pendingMessages = new ConcurrentHashMap<>();
// Maximum number of retransmissions
private static final int MAX_RETRY_COUNT = 3;
// Retransmission interval (ms)
private static final long RETRY_INTERVAL_MS = 5000;
/**
* Register a message awaiting acknowledgement
*/
public void addPendingMessage(Message message, Long userId) {
PendingMessage pending = new PendingMessage();
pending.setMessage(message);
pending.setUserId(userId);
pending.setSendTime(System.currentTimeMillis());
pending.setRetryCount(0);
pendingMessages.put(message.getMessageId(), pending);
}
/**
* Confirm delivery
*/
public void acknowledgeMessage(Long messageId) {
PendingMessage pending = pendingMessages.remove(messageId);
if (pending != null) {
log.debug("Message acknowledged: messageId={}, elapsed={}ms",
messageId, System.currentTimeMillis() - pending.getSendTime());
}
}
/**
* Periodically retransmit unacknowledged messages
* (ConcurrentHashMap tolerates removal while iterating with forEach)
*/
@Scheduled(fixedDelay = RETRY_INTERVAL_MS)
public void retryPendingMessages() {
long now = System.currentTimeMillis();
pendingMessages.forEach((messageId, pending) -> {
if (now - pending.getSendTime() >= RETRY_INTERVAL_MS) {
if (pending.getRetryCount() >= MAX_RETRY_COUNT) {
// Retries exhausted: remove and log (production code would fall back to offline storage)
pendingMessages.remove(messageId);
log.warn("Retransmission failed, message dropped: messageId={}, retryCount={}",
messageId, pending.getRetryCount());
return;
}
// Retransmit
retryMessage(pending);
// Update the retry count and send time
pending.setRetryCount(pending.getRetryCount() + 1);
pending.setSendTime(now);
}
});
}
/**
* Retransmit a message
*/
private void retryMessage(PendingMessage pending) {
// Actual implementation: send the message to the user's channels again
log.info("Retransmitting message: messageId={}, retryCount={}",
pending.getMessage().getMessageId(), pending.getRetryCount());
}
/**
* Number of messages awaiting acknowledgement
*/
public int getPendingCount() {
return pendingMessages.size();
}
/**
* Pending message entity
*/
@lombok.Data
private static class PendingMessage {
private Message message;
private Long userId;
private long sendTime;
private int retryCount;
}
}
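`retryPendingMessages` scans the whole pending map every 5 s, so each scan costs time proportional to the number of in-flight messages. An alternative, named here plainly as a swap-in and not the chapter's method, is `java.util.concurrent.DelayQueue`: each pending message is queued with its deadline, and the consumer only ever sees entries that have actually timed out. The sketch keeps the same retry-count and interval idea. `RetryScheduler` and `PendingEntry` are hypothetical; the clock is injected and `poll()` is used instead of a blocking `take()` so the example can be driven deterministically.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.DelayQueue;
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;
import java.util.function.LongConsumer;
import java.util.function.LongSupplier;

/** Hypothetical ACK-timeout scheduler built on DelayQueue, with an injectable millisecond clock. */
class RetryScheduler {
    final class PendingEntry implements Delayed {
        final long messageId;
        final int attempt;           // retransmissions already performed
        final long deadlineMillis;

        PendingEntry(long messageId, int attempt, long deadlineMillis) {
            this.messageId = messageId;
            this.attempt = attempt;
            this.deadlineMillis = deadlineMillis;
        }

        @Override
        public long getDelay(TimeUnit unit) {
            return unit.convert(deadlineMillis - clock.getAsLong(), TimeUnit.MILLISECONDS);
        }

        @Override
        public int compareTo(Delayed other) {
            return Long.compare(deadlineMillis, ((PendingEntry) other).deadlineMillis);
        }
    }

    private final DelayQueue<PendingEntry> queue = new DelayQueue<>();
    private final Set<Long> acked = ConcurrentHashMap.newKeySet();
    private final LongSupplier clock;
    private final long intervalMillis;
    private final int maxRetries;

    RetryScheduler(LongSupplier clock, long intervalMillis, int maxRetries) {
        this.clock = clock;
        this.intervalMillis = intervalMillis;
        this.maxRetries = maxRetries;
    }

    /** Called right after a message is written to the client. */
    void sent(long messageId) {
        queue.add(new PendingEntry(messageId, 0, clock.getAsLong() + intervalMillis));
    }

    /** Called when the client's ACK arrives; the queued entry is discarded lazily when it expires. */
    void ack(long messageId) {
        acked.add(messageId);
    }

    /** Handles every expired entry; returns how many messages were retransmitted. */
    int processExpired(LongConsumer retransmit, LongConsumer onGiveUp) {
        int retried = 0;
        PendingEntry e;
        while ((e = queue.poll()) != null) {      // poll() returns only entries whose delay has elapsed
            if (acked.remove(e.messageId)) {
                continue;                          // acknowledged in time
            }
            if (e.attempt >= maxRetries) {
                onGiveUp.accept(e.messageId);      // e.g. fall back to offline storage
            } else {
                retransmit.accept(e.messageId);
                queue.add(new PendingEntry(e.messageId, e.attempt + 1, clock.getAsLong() + intervalMillis));
                retried++;
            }
        }
        return retried;
    }
}
```

In a server, a single thread would call `take()` in a loop instead of `poll()`, sleeping until the next deadline rather than waking up on a fixed schedule.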
Offline message storage
package com.example.push.store;
import com.example.push.protocol.Message;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
/**
* Offline message store
*/
@Slf4j
@Component
public class OfflineMessageStore {
@Autowired
private StringRedisTemplate redisTemplate;
// findAndRegisterModules() registers jackson-datatype-jsr310 (must be on the classpath) for LocalDateTime
private final ObjectMapper objectMapper = new ObjectMapper().findAndRegisterModules();
// Key prefix of the per-user offline queue
private static final String OFFLINE_QUEUE_PREFIX = "offline:msg:";
// Maximum number of offline messages kept per user
private static final int MAX_OFFLINE_MESSAGES = 1000;
// Offline message TTL (days)
private static final long OFFLINE_EXPIRE_DAYS = 7;
/**
* Store an offline message
*/
public void storeOfflineMessage(Long userId, Message message) {
String key = OFFLINE_QUEUE_PREFIX + userId;
try {
String json = objectMapper.writeValueAsString(message);
// Add to a sorted set scored by timestamp (fall back to the current time if it is missing)
LocalDateTime time = message.getTimestamp() != null ? message.getTimestamp() : LocalDateTime.now();
double score = time.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli();
redisTemplate.opsForZSet().add(key, json, score);
// Cap the queue: drop the oldest entries beyond MAX_OFFLINE_MESSAGES
Long size = redisTemplate.opsForZSet().zCard(key);
if (size != null && size > MAX_OFFLINE_MESSAGES) {
redisTemplate.opsForZSet().removeRange(key, 0, size - MAX_OFFLINE_MESSAGES - 1);
}
// Refresh the TTL
redisTemplate.expire(key, OFFLINE_EXPIRE_DAYS, TimeUnit.DAYS);
log.debug("Stored offline message: userId={}, messageId={}", userId, message.getMessageId());
} catch (Exception e) {
log.error("Failed to store offline message: userId={}", userId, e);
}
}
/**
* Get all offline messages, oldest first
*/
public List<Message> getOfflineMessages(Long userId) {
String key = OFFLINE_QUEUE_PREFIX + userId;
List<Message> messages = new ArrayList<>();
try {
Set<String> jsonSet = redisTemplate.opsForZSet().range(key, 0, -1);
if (jsonSet != null) {
for (String json : jsonSet) {
Message message = objectMapper.readValue(json, Message.class);
messages.add(message);
}
}
} catch (Exception e) {
log.error("Failed to load offline messages: userId={}", userId, e);
}
return messages;
}
/**
* Delete the given offline messages (e.g. once delivered)
*/
public void removeOfflineMessages(Long userId, List<Long> messageIds) {
String key = OFFLINE_QUEUE_PREFIX + userId;
try {
Set<String> jsonSet = redisTemplate.opsForZSet().range(key, 0, -1);
if (jsonSet != null) {
for (String json : jsonSet) {
Message message = objectMapper.readValue(json, Message.class);
if (messageIds.contains(message.getMessageId())) {
redisTemplate.opsForZSet().remove(key, json);
}
}
}
} catch (Exception e) {
log.error("Failed to delete offline messages: userId={}", userId, e);
}
}
/**
* Clear all offline messages
*/
public void clearOfflineMessages(Long userId) {
String key = OFFLINE_QUEUE_PREFIX + userId;
redisTemplate.delete(key);
}
/**
* Number of offline messages
*/
public long getOfflineMessageCount(Long userId) {
String key = OFFLINE_QUEUE_PREFIX + userId;
Long count = redisTemplate.opsForZSet().zCard(key);
return count == null ? 0 : count;
}
}
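ACK retransmission and offline replay together give at-least-once delivery, so the same messageId can reach a client more than once (for example, a retransmitted copy followed by the offline copy pushed at the next login). Receivers therefore usually deduplicate by messageId. A minimal sketch, assuming a bounded in-memory window is acceptable, uses the `removeEldestEntry` hook of an access-ordered `LinkedHashMap` as an LRU set. `MessageDeduplicator` is a hypothetical name.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical receiver-side deduplicator that remembers the most recent messageIds. */
class MessageDeduplicator {
    private final Map<Long, Boolean> seen;

    MessageDeduplicator(int capacity) {
        // accessOrder = true keeps the map in least-recently-used order
        this.seen = new LinkedHashMap<Long, Boolean>(16, 0.75f, true) {
            @Override
            protected boolean removeEldestEntry(Map.Entry<Long, Boolean> eldest) {
                return size() > capacity;   // evict the least recently seen id
            }
        };
    }

    /** Returns true the first time a messageId is seen, false for duplicates inside the window. */
    synchronized boolean firstDelivery(long messageId) {
        return seen.put(messageId, Boolean.TRUE) == null;
    }
}
```

The client should still ACK a duplicate it discards; otherwise the server keeps retransmitting it.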
2.5 Massive-Scale Push Optimization
Batch push optimization
package com.example.push.batch;
import com.example.push.connection.ConnectionManager;
import com.example.push.protocol.Message;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.util.concurrent.ThreadFactoryBuilder; // Guava
import io.netty.channel.Channel;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;
/**
* Batch push manager
* Improves throughput of large-scale pushes
*/
@Slf4j
@Component
public class BatchPushManager {
@Autowired
private ConnectionManager connectionManager;
// findAndRegisterModules() registers jackson-datatype-jsr310 (must be on the classpath) for LocalDateTime
private final ObjectMapper objectMapper = new ObjectMapper().findAndRegisterModules();
// Push worker pool
private final ExecutorService pushExecutor = Executors.newFixedThreadPool(
Runtime.getRuntime().availableProcessors() * 2,
new ThreadFactoryBuilder().setNameFormat("push-pool-%d").build()
);
// Bounded queue of pending push tasks
private final BlockingQueue<PushTask> pushQueue = new LinkedBlockingQueue<>(100000);
// Maximum number of tasks taken per batch
private static final int BATCH_SIZE = 100;
/**
* Enqueue a push task
*/
public void addPushTask(List<Long> userIds, Message message) {
PushTask task = new PushTask();
task.setUserIds(userIds);
task.setMessage(message);
task.setCreateTime(System.currentTimeMillis());
if (!pushQueue.offer(task)) {
log.warn("Push queue is full, task dropped");
}
}
/**
* Start the batch push thread once the bean has been constructed
*/
@PostConstruct
public void startBatchPush() {
Thread thread = new Thread(() -> {
while (true) {
try {
List<PushTask> batch = new ArrayList<>(BATCH_SIZE);
// Wait up to 100 ms for the first task, then drain up to BATCH_SIZE - 1 more
PushTask task = pushQueue.poll(100, TimeUnit.MILLISECONDS);
if (task != null) {
batch.add(task);
pushQueue.drainTo(batch, BATCH_SIZE - 1);
}
if (!batch.isEmpty()) {
processBatch(batch);
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
break;
} catch (Exception e) {
log.error("Batch push error", e);
}
}
}, "batch-push-thread");
// Daemon thread, so it does not block JVM shutdown
thread.setDaemon(true);
thread.start();
}
/**
* Process one batch of tasks in parallel
*/
private void processBatch(List<PushTask> batch) {
CountDownLatch latch = new CountDownLatch(batch.size());
for (PushTask task : batch) {
pushExecutor.submit(() -> {
try {
pushToUsers(task.getUserIds(), task.getMessage());
} catch (Exception e) {
log.error("Push task failed", e);
} finally {
latch.countDown();
}
});
}
try {
// Wait for all pushes to finish (at most 10 s)
latch.await(10, TimeUnit.SECONDS);
long totalUsers = batch.stream()
.mapToLong(t -> t.getUserIds().size())
.sum();
log.info("Batch push finished: tasks={}, users={}", batch.size(), totalUsers);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
/**
* Push to multiple users
*/
private void pushToUsers(List<Long> userIds, Message message) {
String json;
try {
json = objectMapper.writeValueAsString(message);
} catch (Exception e) {
log.error("Failed to serialize message", e);
return;
}
// Serialize once; each channel gets a retained duplicate that shares the same buffer
TextWebSocketFrame frame = new TextWebSocketFrame(json);
try {
for (Long userId : userIds) {
for (Channel channel : connectionManager.getUserChannels(userId)) {
if (channel.isActive()) {
// Netty releases each duplicate after writing it
channel.writeAndFlush(frame.retainedDuplicate());
}
}
}
} finally {
// Release our own reference to the original frame
frame.release();
}
}
/**
* Current push queue size
*/
public int getQueueSize() {
return pushQueue.size();
}
/**
* Push task
*/
@lombok.Data
private static class PushTask {
private List<Long> userIds;
private Message message;
private long createTime;
}
}
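`addPushTask` accepts whatever list the caller passes, so a single campaign to millions of users would become one huge task handled by one pool thread. A simple mitigation, assumed here and not part of the code above, is to split the recipient list into fixed-size chunks before queuing so the pool can work on them in parallel. `UserIdPartitioner.partition` is a hypothetical helper.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper that splits a recipient list into chunks of at most chunkSize ids. */
final class UserIdPartitioner {
    private UserIdPartitioner() {}

    static List<List<Long>> partition(List<Long> userIds, int chunkSize) {
        if (chunkSize <= 0) {
            throw new IllegalArgumentException("chunkSize must be positive");
        }
        List<List<Long>> chunks = new ArrayList<>();
        for (int from = 0; from < userIds.size(); from += chunkSize) {
            int to = Math.min(from + chunkSize, userIds.size());
            // Copy each slice: chunks are queued and consumed later on other threads
            chunks.add(new ArrayList<>(userIds.subList(from, to)));
        }
        return chunks;
    }
}
```

For example, `UserIdPartitioner.partition(userIds, 1000).forEach(chunk -> batchPushManager.addPushTask(chunk, message))` turns one campaign into many queue entries; 1000 is only an illustrative chunk size and should be tuned against the queue capacity.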
Push rate limiting
package com.example.push.limiter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
/**
* Push rate limiter
* Note: each method issues several separate Redis commands, so concurrent callers can race;
* production code usually moves this logic into a Lua script that Redis executes atomically
*/
@Slf4j
@Component
public class PushRateLimiter {
private final StringRedisTemplate redisTemplate;
public PushRateLimiter(StringRedisTemplate redisTemplate) {
this.redisTemplate = redisTemplate;
}
/**
* Per-user receive limit (sliding window log)
* @param userId user ID
* @param limit maximum number of messages within the window
* @param windowSeconds window length in seconds
*/
public boolean allowReceive(Long userId, int limit, int windowSeconds) {
String key = "push:limit:user:" + userId;
long now = System.currentTimeMillis();
long windowStart = now - windowSeconds * 1000L;
// Drop records that fell out of the window
redisTemplate.opsForZSet().removeRangeByScore(key, 0, windowStart);
// Count records inside the window
Long count = redisTemplate.opsForZSet().zCard(key);
if (count != null && count >= limit) {
log.warn("User receive rate limited: userId={}, count={}", userId, count);
return false;
}
// Record this message; the UUID suffix keeps members unique within the same millisecond
redisTemplate.opsForZSet().add(key, now + ":" + UUID.randomUUID(), now);
redisTemplate.expire(key, windowSeconds, TimeUnit.SECONDS);
return true;
}
/**
* System-wide push limit (token bucket, capacity = rate)
* @param rate pushes allowed per second
*/
public boolean allowSystemPush(int rate) {
String key = "push:limit:system";
String tokensKey = key + ":tokens";
String lastRefillKey = key + ":last_refill";
long now = System.currentTimeMillis();
String tokensStr = redisTemplate.opsForValue().get(tokensKey);
String lastRefillStr = redisTemplate.opsForValue().get(lastRefillKey);
long tokens = tokensStr == null ? rate : Long.parseLong(tokensStr);
long lastRefill = lastRefillStr == null ? now : Long.parseLong(lastRefillStr);
// Tokens earned since the last refill
long tokensToAdd = Math.max(0, now - lastRefill) * rate / 1000;
if (tokens + tokensToAdd >= rate) {
// Bucket full: excess tokens are discarded
tokens = rate;
lastRefill = now;
} else if (tokensToAdd > 0) {
// Advance lastRefill only by the time converted into whole tokens,
// so frequent calls do not throw away partial refills
tokens += tokensToAdd;
lastRefill += tokensToAdd * 1000 / rate;
}
// Try to take a token
if (tokens > 0) {
tokens--;
redisTemplate.opsForValue().set(tokensKey, String.valueOf(tokens));
redisTemplate.opsForValue().set(lastRefillKey, String.valueOf(lastRefill));
return true;
}
return false;
}
}
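To show the token-bucket arithmetic on its own, here is a single-process sketch with an injectable millisecond clock. It is not a replacement for the Redis version: `TokenBucket` is a hypothetical class, and it tracks fractional tokens as a `double` instead of the whole-token bookkeeping used above.

```java
import java.util.function.LongSupplier;

/** Hypothetical in-process token bucket: holds up to capacity tokens, refilled at ratePerSecond. */
class TokenBucket {
    private final double capacity;
    private final double ratePerMilli;
    private final LongSupplier clock;
    private double tokens;
    private long lastRefill;

    TokenBucket(int ratePerSecond, int capacity, LongSupplier clock) {
        this.capacity = capacity;
        this.ratePerMilli = ratePerSecond / 1000.0;
        this.clock = clock;
        this.tokens = capacity;                 // start full, like the Redis version
        this.lastRefill = clock.getAsLong();
    }

    synchronized boolean tryAcquire() {
        long now = clock.getAsLong();
        // Refill in proportion to elapsed time, never beyond capacity; fractions carry over
        tokens = Math.min(capacity, tokens + (now - lastRefill) * ratePerMilli);
        lastRefill = now;
        if (tokens >= 1.0) {
            tokens -= 1.0;
            return true;
        }
        return false;
    }
}
```

Because this bucket lives in one JVM, a cluster of N nodes admits up to N times the configured rate; a cluster-wide limit still needs shared state such as the Redis version, ideally executed atomically as a Lua script.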
III. Runnable Java Code Examples
Wiring the push service together
package com.example.push;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.scheduling.annotation.EnableScheduling;
/**
* Push service application entry point
*/
@SpringBootApplication
@EnableScheduling
public class PushServiceApplication {
public static void main(String[] args) {
SpringApplication.run(PushServiceApplication.class, args);
}
}
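package com.example.push.service;
import com.example.push.batch.BatchPushManager;
import com.example.push.connection.ConnectionManager;
import com.example.push.limiter.PushRateLimiter;
import com.example.push.protocol.Message;
import com.example.push.protocol.MessageType;
import com.example.push.protocol.SnowflakeIdGenerator;
import com.example.push.reliability.MessageReliabilityManager;
import com.example.push.store.OfflineMessageStore;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.netty.channel.Channel;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.time.LocalDateTime;
import java.util.List;
/**
* Push service API
*/
@Slf4j
@Service
public class PushService {
@Autowired
private ConnectionManager connectionManager;
@Autowired
private BatchPushManager batchPushManager;
@Autowired
private PushRateLimiter rateLimiter;
@Autowired
private MessageReliabilityManager reliabilityManager;
@Autowired
private OfflineMessageStore offlineMessageStore;
// findAndRegisterModules() registers jackson-datatype-jsr310 (must be on the classpath) for LocalDateTime
private final ObjectMapper objectMapper = new ObjectMapper().findAndRegisterModules();
/**
* Send a private message
*/
public void sendPrivateMessage(Long fromUserId, Long toUserId, String content) {
// Rate-limit check on the receiver: at most 100 messages per 60 s
if (!rateLimiter.allowReceive(toUserId, 100, 60)) {
log.warn("Message rejected by rate limiter: toUserId={}", toUserId);
throw new IllegalStateException("Too many messages for this receiver");
}
Message message = Message.createChatMessage(fromUserId, toUserId, content);
// isOnline() only reflects connections on this node
if (connectionManager.isOnline(toUserId)) {
// Online: push to every active device
connectionManager.getUserChannels(toUserId).forEach(channel -> {
if (channel.isActive()) {
sendToChannel(channel, message);
}
});
// Track the message until the client ACKs it
reliabilityManager.addPendingMessage(message, toUserId);
} else {
// Offline: keep the message until the user logs in again
offlineMessageStore.storeOfflineMessage(toUserId, message);
}
}
/**
* Batch push
*/
public void batchPush(List<Long> userIds, String title, String content) {
Message message = new Message();
message.setMessageId(SnowflakeIdGenerator.nextId());
message.setType(MessageType.SYSTEM);
message.setContent(content);
message.setTimestamp(LocalDateTime.now());
// Message has no title field; if clients need the title, carry it in extra
batchPushManager.addPushTask(userIds, message);
}
/**
* Send to a channel
*/
private void sendToChannel(Channel channel, Message message) {
try {
channel.writeAndFlush(new TextWebSocketFrame(objectMapper.writeValueAsString(message)));
} catch (Exception e) {
log.error("Failed to send message: channelId={}", channel.id(), e);
}
}
}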
package com.example.push.controller;
import com.example.push.dto.PushRequest;
import com.example.push.service.PushService;
import org.springframework.web.bind.annotation.*;
/**
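* Push API controller (Result<T> and PushRequest are the project's response wrapper and request DTO, not shown)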
*/
@RestController
@RequestMapping("/api/push")
public class PushController {
private final PushService pushService;
public PushController(PushService pushService) {
this.pushService = pushService;
}
/**
* Send a private message
*/
@PostMapping("/private")
public Result<Void> sendPrivate(@RequestBody PushRequest request) {
pushService.sendPrivateMessage(
request.getFromUserId(),
request.getToUserId(),
request.getContent()
);
return Result.success();
}
/**
* Batch push
*/
@PostMapping("/batch")
public Result<Void> batchPush(@RequestBody PushRequest request) {
pushService.batchPush(
request.getUserIds(),
request.getTitle(),
request.getContent()
);
return Result.success();
}
}
IV. Practical Scenarios
Scenario 1: Instant messaging
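import com.example.push.protocol.SnowflakeIdGenerator;
import com.example.push.service.PushService;
import java.time.LocalDateTime;
import java.util.List;
import java.util.stream.Collectors;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
/**
* Instant messaging service
* (ImMessage, MessageStore, getGroupMembers and updateConversation are project-specific and not shown)
*/
@Service
public class ImService {
@Autowired
private PushService pushService;
@Autowired
private MessageStore messageStore;
/**
* Send a one-to-one chat message
*/
public Long sendSingleChat(Long fromUserId, Long toUserId, String content, Integer msgType) {
// 1. Build the message
ImMessage message = new ImMessage();
message.setMessageId(SnowflakeIdGenerator.nextId());
message.setFromUserId(fromUserId);
message.setToUserId(toUserId);
message.setContent(content);
message.setMsgType(msgType);
message.setCreateTime(LocalDateTime.now());
message.setStatus(0); // 0 = unread
// 2. Persist the message
messageStore.saveMessage(message);
// 3. Push it. Note: sendPrivateMessage builds a new Message with its own ID; a real system
//    should pass this messageId through so client ACKs match the stored record
pushService.sendPrivateMessage(fromUserId, toUserId, content);
// 4. Update both users' conversation lists
updateConversation(fromUserId, toUserId, message);
return message.getMessageId();
}
/**
* Send a group chat message
*/
public Long sendGroupChat(Long fromUserId, Long groupId, String content) {
// 1. Load the group members
List<Long> memberIds = getGroupMembers(groupId);
// 2. Build and persist the message
ImMessage message = new ImMessage();
message.setMessageId(SnowflakeIdGenerator.nextId());
message.setFromUserId(fromUserId);
message.setGroupId(groupId);
message.setContent(content);
message.setCreateTime(LocalDateTime.now());
messageStore.saveMessage(message);
// 3. Push to the other members in bulk. batchPush sends a SYSTEM-type message carrying only
//    the content; a full IM would send GROUP_CHAT with fromUserId and groupId instead
List<Long> receivers = memberIds.stream()
.filter(id -> !id.equals(fromUserId))
.collect(Collectors.toList());
pushService.batchPush(receivers, null, content);
return message.getMessageId();
}
}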
Scenario 2: System notifications
import com.example.push.service.PushService;
import java.time.LocalDateTime;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
/**
* System notification service
* (Notification, NotificationType, UserDevice, Platform, UserDeviceService, saveNotification
* and isUserOnline are project-specific and not shown)
*/
@Service
public class NotificationService {
@Autowired
private PushService pushService;
@Autowired
private UserDeviceService userDeviceService;
/**
* Send a system notification
*/
public void sendNotification(Long userId, String title, String content,
NotificationType type) {
Notification notification = new Notification();
notification.setUserId(userId);
notification.setTitle(title);
notification.setContent(content);
notification.setType(type);
notification.setCreateTime(LocalDateTime.now());
notification.setReadStatus(0); // 0 = unread
// Persist the notification
saveNotification(notification);
// Online push (sender 0L denotes the system)
pushService.sendPrivateMessage(0L, userId, content);
// Offline push via APNs/FCM
if (!isUserOnline(userId)) {
pushToMobile(userId, title, content);
}
}
/**
* Push to the user's mobile devices
*/
private void pushToMobile(Long userId, String title, String content) {
List<UserDevice> devices = userDeviceService.getUserDevices(userId);
for (UserDevice device : devices) {
if (device.getPlatform() == Platform.IOS) {
// APNs push
apnsPush(device.getDeviceToken(), title, content);
} else if (device.getPlatform() == Platform.ANDROID) {
// FCM push
fcmPush(device.getDeviceToken(), title, content);
}
}
}
/**
* APNs push
*/
private void apnsPush(String deviceToken, String title, String content) {
// TODO: implement APNs push
}
/**
* FCM push
*/
private void fcmPush(String deviceToken, String title, String content) {
// TODO: implement FCM push
}
}
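V. Summary and Best Practices
5.1 Key Architecture Decisions
| Design dimension | Recommended approach | Notes |
|---|---|---|
| Connection management | Netty + custom protocol | High performance, easy to extend |
| Message routing | Redis Pub/Sub | Distributes messages across cluster nodes |
| Reliability | ACK mechanism + retransmission | Ensures messages are delivered |
| Offline storage | Redis sorted sets | Ordered by timestamp; expire automatically via the key TTL |
| Performance | Batch push + rate limiting | Batching raises throughput; rate limiting protects clients and servers from overload |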
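5.2 Performance Tuning Tips
- Thread and connection tuning: size Netty's event loop groups and tune connection options (e.g. SO_BACKLOG, TCP_NODELAY) for the workload
- Message compression: gzip large message bodies
- Batching: push and persist messages in batches
- Asynchronous processing: handle messages asynchronously on top of non-blocking I/O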
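One way to apply the compression tip, sketched under assumptions: gzip only payloads above a size threshold and flag compressed payloads, since gzip's fixed header and trailer (at least 18 bytes) make tiny messages such as heartbeats larger. The 1 KB threshold and the `PayloadCodec` name are illustrative, and how the flag reaches clients (for example a binary WebSocket frame or a field in `extra`) is left to the protocol design. As an alternative, Netty's `WebSocketServerCompressionHandler` negotiates the standard permessage-deflate extension (RFC 7692) and compresses transparently.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

/** Hypothetical codec that gzips message bodies only when they exceed a size threshold. */
final class PayloadCodec {
    static final int THRESHOLD_BYTES = 1024;

    /** Encoded payload plus a flag telling the receiver whether to gunzip it. */
    static final class Encoded {
        final boolean compressed;
        final byte[] bytes;

        Encoded(boolean compressed, byte[] bytes) {
            this.compressed = compressed;
            this.bytes = bytes;
        }
    }

    private PayloadCodec() {}

    static Encoded encode(String json) {
        byte[] raw = json.getBytes(StandardCharsets.UTF_8);
        if (raw.length < THRESHOLD_BYTES) {
            return new Encoded(false, raw);   // small payload: gzip overhead is not worth it
        }
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        try (GZIPOutputStream gzip = new GZIPOutputStream(out)) {
            gzip.write(raw);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return new Encoded(true, out.toByteArray());   // read after close(), so the trailer is written
    }

    static String decode(Encoded encoded) {
        if (!encoded.compressed) {
            return new String(encoded.bytes, StandardCharsets.UTF_8);
        }
        try (GZIPInputStream gzip = new GZIPInputStream(new ByteArrayInputStream(encoded.bytes))) {
            return new String(gzip.readAllBytes(), StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```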
5.3 Reliability Safeguards
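import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.retry.backoff.ExponentialBackOffPolicy;
import org.springframework.retry.policy.SimpleRetryPolicy;
import org.springframework.retry.support.RetryTemplate;
/**
* Message reliability configuration (requires the spring-retry dependency)
*/
@Configuration
public class ReliabilityConfig {
/**
* Retry policy for message sending
*/
@Bean
public RetryTemplate retryTemplate() {
RetryTemplate template = new RetryTemplate();
// Exponential backoff: start at 1 s, double each time, cap at 10 s
ExponentialBackOffPolicy backOffPolicy = new ExponentialBackOffPolicy();
backOffPolicy.setInitialInterval(1000);
backOffPolicy.setMultiplier(2.0);
backOffPolicy.setMaxInterval(10000);
template.setBackOffPolicy(backOffPolicy);
// Simple retry policy: at most 3 attempts, including the first call
SimpleRetryPolicy retryPolicy = new SimpleRetryPolicy();
retryPolicy.setMaxAttempts(3);
template.setRetryPolicy(retryPolicy);
return template;
}
}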
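To make the policy concrete: with an initial interval of 1000 ms, a multiplier of 2.0, a maximum interval of 10000 ms and 3 attempts, a failing call is tried three times with waits of 1 s and 2 s in between; the 10 s cap would only take effect from the fifth wait onward. The plain-Java sketch below reproduces that arithmetic so the schedule can be checked without Spring. It mirrors the documented behavior of the two policies rather than calling Spring Retry, and `BackoffSchedule` is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical re-implementation of exponential backoff arithmetic (not Spring Retry itself). */
final class BackoffSchedule {
    private BackoffSchedule() {}

    /** Waits between attempts for maxAttempts tries, i.e. maxAttempts - 1 intervals. */
    static List<Long> waits(long initialMillis, double multiplier, long maxMillis, int maxAttempts) {
        List<Long> waits = new ArrayList<>();
        double next = initialMillis;
        for (int attempt = 1; attempt < maxAttempts; attempt++) {
            waits.add((long) Math.min(next, maxMillis));   // each wait is capped at maxMillis
            next *= multiplier;
        }
        return waits;
    }
}
```

In the worst case the configured template therefore adds about 3 s of waiting on top of the calls themselves, which is worth knowing before wrapping a latency-sensitive push path in it.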
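5.4 Monitoring Metrics
- Connections: online users, total connections
- Message volume: send QPS, receive QPS
- Latency: per-message delivery latency and its average
- Reliability: ACK rate, retransmission rate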
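The reliability and latency metrics are ratios of counters, so a node can track them cheaply with `LongAdder`. The definitions below are an assumption about how to derive them (ACK rate = ACKs / messages sent, retransmission rate = retransmissions / messages sent, average latency = total ACK latency / ACK count). `PushMetrics` is hypothetical, and a real deployment would more likely export these through a metrics library than compute them by hand.

```java
import java.util.concurrent.atomic.LongAdder;

/** Hypothetical per-node counters for push reliability and latency metrics. */
class PushMetrics {
    private final LongAdder sent = new LongAdder();
    private final LongAdder acked = new LongAdder();
    private final LongAdder retransmitted = new LongAdder();
    private final LongAdder ackLatencyMillis = new LongAdder();

    void onSent() { sent.increment(); }

    void onRetransmit() { retransmitted.increment(); }

    /** latencyMillis: time from first send to the client's ACK. */
    void onAck(long latencyMillis) {
        acked.increment();
        ackLatencyMillis.add(latencyMillis);
    }

    double ackRate() { return ratio(acked.sum(), sent.sum()); }

    double retransmitRate() { return ratio(retransmitted.sum(), sent.sum()); }

    double avgAckLatencyMillis() { return ratio(ackLatencyMillis.sum(), acked.sum()); }

    private static double ratio(long numerator, long denominator) {
        return denominator == 0 ? 0.0 : (double) numerator / denominator;
    }
}
```

An average hides latency spikes; if tail latency matters, record a histogram rather than a running sum.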
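5.5 Scalability Considerations
- Horizontal scaling: deploy across multiple servers
- Message queues: integrate Kafka/RocketMQ to absorb traffic peaks
- Multi-protocol support: WebSocket, TCP, MQTT
- Internationalization: support for multiple time zones and languages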
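Horizontal scaling needs some way to assign each client to a gateway node. One option, named here as an assumption since the chapter does not prescribe one, is a consistent-hash ring: each node gets several virtual points on a ring, and a user maps to the first point clockwise from the hash of their userId, so adding or removing a node only remaps the users on the affected arcs. `HashRing` is hypothetical; it uses the first 8 bytes of an MD5 digest as the hash purely for spread, not for security.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Map;
import java.util.TreeMap;

/** Hypothetical consistent-hash ring mapping userIds to gateway nodes. */
class HashRing {
    private final TreeMap<Long, String> ring = new TreeMap<>();
    private final int virtualNodes;

    HashRing(int virtualNodes) {
        this.virtualNodes = virtualNodes;
    }

    void addNode(String node) {
        for (int i = 0; i < virtualNodes; i++) {
            ring.put(hash(node + "#" + i), node);
        }
    }

    void removeNode(String node) {
        for (int i = 0; i < virtualNodes; i++) {
            ring.remove(hash(node + "#" + i), node);   // only remove points owned by this node
        }
    }

    String nodeFor(long userId) {
        if (ring.isEmpty()) {
            throw new IllegalStateException("no nodes");
        }
        // First virtual point clockwise from the user's hash, wrapping around to the start
        Map.Entry<Long, String> e = ring.ceilingEntry(hash(Long.toString(userId)));
        return (e != null ? e : ring.firstEntry()).getValue();
    }

    private static long hash(String key) {
        try {
            byte[] digest = MessageDigest.getInstance("MD5").digest(key.getBytes(StandardCharsets.UTF_8));
            return ByteBuffer.wrap(digest).getLong();   // first 8 bytes of the digest
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

A login service could hand each client the address returned by `nodeFor(userId)`; the 100 virtual points per node used in the check below are only an illustrative value for smoothing the distribution.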
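VI. Questions and Exercises
Discussion questions
- Basic: What advantages does WebSocket have over HTTP long polling, and in which scenarios should you choose WebSocket?
- Intermediate: How do you guarantee reliable message delivery, and how do the ACK mechanism and the retransmission strategy work together?
- Practical: Design the architecture of an IM system that supports a million concurrent online users, covering connection management, message routing, and cluster scaling.
Programming exercise
Exercise: implement a simple instant messaging system that:
- Implements the WebSocket server on Netty
- Supports one-to-one and group chat messages
- Implements the message ACK mechanism and offline message storage
- Supports heartbeat detection and reconnection after disconnects
Related chapters
- Previous chapter: "Short URL System Design in Depth"
- Next chapter: "Order System Design in Depth"
- Further reading: *Netty in Action*, the WebSocket specification (RFC 6455)
📝 Next Chapter Preview
The next chapter covers order system design, exploring core e-commerce problems such as order state machines, idempotency, distributed transactions, and inventory deduction.
End of chapter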