分布式 WebSocket

251 阅读 · 约 2 分钟

单体服务中 WebSocket 存在的问题

  • 一个服务部署多个节点时,通过传统的 Map 数据结构分别存储 sessionId 与 WebSocket 对象的对应关系、用户 Id 与 sessionId 集合的对应关系来实现数据推送。但推送时按用户 Id 广播其下所有连接对象,只能推送到当前节点上的连接,无法推送到所有节点;
  • 分布式环境下,不同服务之间的数据推送请求也无法实现:WebSocket 的 broker 连接端只运行在某一个服务上,其他服务无法直接向其上的连接推送数据。

使用 Redis 的发布订阅模式解决 WebSocket 的单体推送问题

  • 将需要推送的数据发布到 Redis 监听的 topic(消息频道)上;各节点的订阅者收到消息后,根据实际业务场景将其转发给本节点的 WebSocket broker,从而实现消息推送。

Redis 发布订阅配置

配置 Redis 消息监听容器和消息监听器

/**
 * Redis wiring for the cross-node WebSocket push: builds the connection factory from the bound
 * {@link RedisConfigProperties} and subscribes {@link RedisReceiver} to the shared pub/sub channel.
 */
@Component
@EnableConfigurationProperties({RedisConfigProperties.class})
public class RedisConfig {

    @Autowired
    private RedisConfigProperties clusterProperties;

    /**
     * Builds the Redis topology description (cluster or standalone) from the configuration properties.
     */
    @Bean
    public RedisConfiguration getConfiguration() {
        return BaseRedisService.getRedisConfiguration(clusterProperties);
    }

    /**
     * Creates a Jedis connection factory matching the configured topology.
     *
     * @param redisConfiguration topology produced by {@link #getConfiguration()}
     * @return a factory for cluster mode or standalone mode
     * @throws IllegalStateException if the configuration is {@code null} or of an unsupported type
     */
    @Bean
    public JedisConnectionFactory getConnectionFactory(RedisConfiguration redisConfiguration) {
        if (redisConfiguration instanceof RedisClusterConfiguration) {
            return new JedisConnectionFactory((RedisClusterConfiguration) redisConfiguration);
        }
        if (redisConfiguration instanceof RedisStandaloneConfiguration) {
            return new JedisConnectionFactory((RedisStandaloneConfiguration) redisConfiguration);
        }
        // Returning null from a @Bean method only moves the failure to whichever bean needs a
        // connection factory (e.g. the listener container below), with a far less helpful error.
        throw new IllegalStateException("Unsupported Redis configuration: "
                + (redisConfiguration == null ? "null" : redisConfiguration.getClass().getName()));
    }

    /**
     * Creates a template with String keys and JSON hash values.
     *
     * <p>This method is not annotated with {@code @Bean}, so Spring never initialises the returned
     * template. It is initialised here; an uninitialised {@link RedisTemplate} rejects every
     * operation.
     *
     * @param factory connection factory to use
     * @return a ready-to-use template
     */
    public RedisTemplate<Serializable, Serializable> getRedisTemplate(JedisConnectionFactory factory) {
        RedisTemplate<Serializable, Serializable> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(factory);
        RedisSerializer<String> redisSerializer = new StringRedisSerializer();
        redisTemplate.setKeySerializer(redisSerializer);
        redisTemplate.setHashValueSerializer(new GenericJackson2JsonRedisSerializer());
        // Applies default serializers for the ones not set above and marks the template usable.
        redisTemplate.afterPropertiesSet();
        return redisTemplate;
    }

    /**
     * Wraps the receiver as a message listener.
     *
     * <p>Because {@link RedisReceiver} itself implements {@code MessageListener}, the adapter passes
     * messages straight through to its {@code onMessage(Message, byte[])}.
     */
    @Bean
    MessageListenerAdapter listenerAdapter(RedisReceiver receiver) {
        return new MessageListenerAdapter(receiver, "onMessage");
    }

    /**
     * Listener container that subscribes this node to {@code Constants.REDIS_CHANNEL}.
     *
     * <p>Every node runs this container, so a message published on the channel reaches all nodes.
     */
    @Bean
    RedisMessageListenerContainer container(RedisConnectionFactory connectionFactory,
                                            MessageListenerAdapter listenerAdapter) {
        RedisMessageListenerContainer container = new RedisMessageListenerContainer();
        container.setConnectionFactory(connectionFactory);
        // Subscribe to the message channel (PatternTopic: the name is treated as a glob pattern).
        container.addMessageListener(listenerAdapter, new PatternTopic(Constants.REDIS_CHANNEL));
        return container;
    }

}

消息订阅者收到消息后执行 onMessage 方法

/**
 * Redis pub/sub subscriber. Each node runs one, so a message published on the shared channel is
 * forwarded to the WebSocket sessions connected to this node.
 */
@Component
public class RedisReceiver implements MessageListener {
    Logger log = LoggerFactory.getLogger(this.getClass());

    @Autowired
    private WebSocketBroker webSocketBroker;

    @Autowired
    private StringRedisTemplate redisTemplate;

    /**
     * Handles a message received on the subscribed channel.
     *
     * @param message raw Redis message; its body is the published payload
     * @param pattern the subscription pattern that matched (not used)
     */
    @Override
    public void onMessage(Message message, byte[] pattern) {
        // Decode with the template's String serializer (UTF-8) rather than new String(bytes), which
        // would use the platform charset and garble non-ASCII text.
        String data = redisTemplate.getStringSerializer().deserialize(message.getBody());
        if (data == null) {
            return;
        }
        try {
            // NOTE(review): the original called webSocketBroker.sendMessage(data), but WebSocketBroker
            // only offers sendMessage(String, Long), so that call did not compile. Broadcasting to the
            // local sessions is the cross-node fan-out this subscriber exists for. If pushes must be
            // routed per user, parse the published JSON here and call sendMessage(payload, userId).
            webSocketBroker.sendMessageAll(data);
        } catch (IOException e) {
            log.error("Failed to push Redis message to local WebSocket sessions", e);
        }
    }
}

消息发布

// Publish the payload as a JSON string on Constants.REDIS_CHANNEL; every node whose
// RedisMessageListenerContainer subscribes to that channel receives it.
// NOTE(review): the receiver decodes the body with StringRedisTemplate's String serializer, so
// this redisTemplate presumably needs a String value serializer as well — confirm.
redisTemplate.convertAndSend(Constants.REDIS_CHANNEL, JSON.toJSONString(data));

WebSocket 的 broker 配置

/**
 * WebSocket endpoint ({@code /websocket/{id}/}) that tracks connections per user.
 *
 * <p>The session ids of each user are recorded in the Redis hash {@code WEBSOCKET:<userId>}, so
 * the set is visible to every node. The endpoint objects that can actually send live only in this
 * JVM's {@link #WEB_SOCKET_BROKER_HASH_MAP}. A push from this class therefore reaches only sessions
 * connected to the current node; reaching all nodes relies on each node receiving the same Redis
 * pub/sub message.
 *
 * <p>NOTE(review): container session ids are only guaranteed unique within one container and may
 * collide across nodes. The plain {@code sessionId -> userId} string key is therefore shared
 * between nodes; a node-unique prefix would be needed to make it fully reliable.
 */
@ServerEndpoint("/websocket/{id}/")
@Component
public class WebSocketBroker implements Serializable {

    /** Idle timeout applied to every session, in milliseconds (10 minutes). */
    private static final long MAX_IDLE_TIMEOUT_MS = 600000;

    private static final Logger LOG = LoggerFactory.getLogger(WebSocketBroker.class);

    /** Prefix of the Redis hash mapping a user's session ids to the user id. */
    private static final String WEB_SOCKET_KEY = "WEBSOCKET:";

    /** Redis key of the online-connection counter. */
    private static final String WEB_SOCKET_COUNT = "COUNT:TOTAL:";

    /**
     * The connection to one client; messages to that client are sent through it.
     */
    private Session session;

    /**
     * Endpoint objects for the sessions connected to this node, keyed by session id.
     */
    private static final ConcurrentHashMap<String, WebSocketBroker> WEB_SOCKET_BROKER_HASH_MAP =
            new ConcurrentHashMap<>();

    /**
     * Per-session parameter objects used by business logic (handling not shown here).
     */
    private static final ConcurrentHashMap<String, WebSocketVO> WEB_SOCKET_PARAM_HASH_MAP = new ConcurrentHashMap<>();

    /**
     * User id taken from the endpoint path.
     */
    private Long id;

    /** Whether onOpen created this session's Redis entries, so that onClose removes only those. */
    private volatile boolean registeredInRedis;

    private static volatile RedisTemplate redisTemplate;

    private static volatile StringRedisTemplate stringRedisTemplate;

    /**
     * Returns the hash template, resolving it from the Spring context on first use.
     *
     * <p>Resolving lazily (instead of only in onOpen) matters for the Spring-managed singleton
     * that other beans call. Without it, that instance would hit a null template until some
     * client had connected on this node.
     */
    private static RedisTemplate redis() {
        RedisTemplate template = redisTemplate;
        if (template == null) {
            template = SpringUtils.getBean("showRedisTemplate");
            redisTemplate = template;
        }
        return template;
    }

    /** Returns the String template, resolving it from the Spring context on first use. */
    private static StringRedisTemplate stringRedis() {
        StringRedisTemplate template = stringRedisTemplate;
        if (template == null) {
            template = SpringUtils.getBean(StringRedisTemplate.class);
            stringRedisTemplate = template;
        }
        return template;
    }

    /**
     * Called when a connection is established: registers the session locally and in Redis.
     *
     * @param session the new session
     * @param id      user id from the path
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("id") Long id) {
        session.setMaxIdleTimeout(MAX_IDLE_TIMEOUT_MS);
        this.session = session;
        this.id = id;
        String sessionId = session.getId();
        // Always register locally: delivery from this node must not depend on the Redis entry,
        // which may already exist if another node handed out the same container session id.
        WEB_SOCKET_BROKER_HASH_MAP.put(sessionId, this);
        // HSETNX: an atomic replacement for the original hasKey-then-put sequence.
        Boolean added = redis().opsForHash().putIfAbsent(WEB_SOCKET_KEY + id, sessionId, id);
        if (Boolean.TRUE.equals(added)) {
            registeredInRedis = true;
            // Record the session -> user relation.
            stringRedis().opsForValue().set(sessionId, id + "");
            stringRedis().opsForValue().increment(WEB_SOCKET_COUNT, 1);
        }
        LOG.info("用户id:{}连接,当前在线数为:{}", id, stringRedis().opsForValue().get(WEB_SOCKET_COUNT));
    }

    /**
     * Called when the connection closes: undoes the registration made in onOpen.
     *
     * @param session the closing session
     */
    @OnClose
    public void onClose(Session session) {
        String sessionId = session.getId();
        WEB_SOCKET_BROKER_HASH_MAP.remove(sessionId, this);
        if (registeredInRedis) {
            registeredInRedis = false;
            redis().opsForHash().delete(WEB_SOCKET_KEY + id, sessionId);
            stringRedis().delete(sessionId);
            // Decrement through the same String-keyed template that incremented the counter. The
            // original used redisTemplate here; its key serializer is not visible in this class
            // and may encode the key differently, so the counter would only ever grow.
            stringRedis().opsForValue().increment(WEB_SOCKET_COUNT, -1);
        }
        LOG.info("用户id:{}退出,当前在线数为:{}", id, stringRedis().opsForValue().get(WEB_SOCKET_COUNT));
    }

    /**
     * Called when a client message arrives; currently it is only logged.
     *
     * @param message text sent by the client
     * @param session the sending session
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        LOG.info("用户id消息:{},报文:{}", id, message);
    }

    /**
     * Called when an error occurs on the connection.
     *
     * @param session the affected session
     * @param error   the cause; logged with its stack trace
     */
    @OnError
    public void onError(Session session, Throwable error) {
        LOG.error("用户id错误:{},原因:{}", this.id, error.getMessage(), error);
    }

    /**
     * Pushes {@code message} to every session of {@code userId} that is connected to this node.
     *
     * <p>This does not publish to Redis. Session ids registered by other nodes are skipped, and a
     * failure on one session is logged without stopping delivery to the others.
     *
     * @param message text to send
     * @param userId  target user
     * @throws IOException kept for source compatibility; send failures are logged per session
     */
    public void sendMessage(String message, Long userId) throws IOException {
        Set<String> sessionKeys = (Set<String>) redis().opsForHash().keys(WEB_SOCKET_KEY + userId);
        if (CollectionUtils.isEmpty(sessionKeys)) {
            return;
        }
        for (String sessionId : sessionKeys) {
            WebSocketBroker broker = WEB_SOCKET_BROKER_HASH_MAP.get(sessionId);
            // No local endpoint means the session lives on another node (the original dereferenced
            // null here and aborted the loop). The user-id check stops a colliding session id of a
            // different local user from receiving this user's message.
            if (broker == null || broker.id == null || !broker.id.equals(userId)) {
                continue;
            }
            LOG.info("用户 {} 发送的消息 {}", userId, message);
            broker.sendQuietly(message);
        }
    }

    /** Returns the live per-session parameter map (mutable, shared). */
    public Map<String, WebSocketVO> getAllParam() {
        return WEB_SOCKET_PARAM_HASH_MAP;
    }

    /** Returns the parameter object of {@code sessionId}, or {@code null} if none. */
    public WebSocketVO getParam(String sessionId) {
        return WEB_SOCKET_PARAM_HASH_MAP.get(sessionId);
    }

    /** Returns the live map of endpoints connected to this node (mutable, shared). */
    public Map<String, WebSocketBroker> getAllWebSocket() {
        return WEB_SOCKET_BROKER_HASH_MAP;
    }

    /**
     * Broadcasts {@code message} to every session connected to this node.
     *
     * <p>The original looked up the Redis hash {@code "WEBSOCKET:*"}. HKEYS does not expand
     * wildcards, so that literal key was always empty and nothing was ever sent. Iterating the
     * local endpoints covers exactly the sessions this node can reach. When every node invokes
     * this for the same pub/sub message, all clients are covered.
     *
     * @param message text to send
     * @throws IOException kept for source compatibility; send failures are logged per session
     */
    public void sendMessageAll(String message) throws IOException {
        for (WebSocketBroker broker : WEB_SOCKET_BROKER_HASH_MAP.values()) {
            broker.sendQuietly(message);
        }
    }

    /**
     * Sends {@code message} to this endpoint's client. Does nothing if there is no open session.
     *
     * @param message text to send
     * @throws IOException if the container fails to send
     */
    public void sendMessageSingle(String message) throws IOException {
        Session target = this.session;
        if (target == null || !target.isOpen()) {
            return;
        }
        // RemoteEndpoint.Basic must not be used for concurrent sends on one session (the container
        // may throw IllegalStateException), and pushes can arrive from several threads.
        synchronized (target) {
            target.getBasicRemote().sendText(message);
        }
    }

    /** Sends to this session, logging instead of propagating so that one dead session cannot block a fan-out. */
    private void sendQuietly(String message) {
        try {
            sendMessageSingle(message);
        } catch (IOException | IllegalStateException e) {
            LOG.warn("WebSocket push failed for user {}", id, e);
        }
    }
}