WebSocket集群实现

453 阅读6分钟

概述

WebSocket是一种在单个TCP连接上进行全双工通信的协议。WebSocket通信协议于2011年被IETF定为标准RFC 6455,并由RFC7936补充规范。WebSocket API也被W3C定为标准。

WebSocket使得客户端和服务器之间的数据交换变得更加简单,允许服务端主动向客户端推送数据。在WebSocket API中,浏览器和服务器只需要完成一次握手,两者之间就直接可以创建持久性的连接,并进行双向数据传输。 [百度上搜索的 0.0]

目的

相信有很多小伙伴在实际工作中会接触到WebSocket,毕竟这玩意在交互上还是很友好滴~那么接下来由我tom老弟跟大家捋一捋如何实现WebSocket集群。

流程图

image.png 在此,推荐一下ProcessOn 在线画图网站,非常不错滴~~

准备

首先得先配置一下nginx.conf,具体配置如下:

stream {
  upstream lb {
      server 127.0.0.1:9092 weight=5 max_fails=3 fail_timeout=30s;
      server 127.0.0.1:9098 weight=5 max_fails=3 fail_timeout=30s;
  }
  server {
      listen 9000;
      proxy_pass lb;
  }
}

代码实现

前端html

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="utf-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <title>测试界面</title>
    <script src="https://code.jquery.com/jquery-3.2.1.min.js" ></script>
</head>

<body>

<div>
    <input type="text" style="width: 20%" value="ws://127.0.0.1:9000/websocket/message/10001" id="url">
    <button id="userId" value="10002" hidden="hidden"/>
	<button id="btn_join">连接</button>
	<button id="btn_exit">断开</button>
</div>
<br/>
<textarea id="message" cols="100" rows="9"></textarea> <button id="btn_send">发送消息</button>
<br/>
<br/>
<textarea id="text_content" readonly="readonly" cols="100" rows="9"></textarea>返回内容
<br/>
<br/>


<div>
    <input type="text" style="width: 20%" value="ws://127.0.0.1:9000/websocket/message/10002" id="url2">
    <button id="userId2" value="10001" hidden="hidden"/>
    <button id="btn_join2">连接</button>
    <button id="btn_exit2">断开</button>
</div>
<br/>
<textarea id="message2" cols="100" rows="9"></textarea> <button id="btn_send2">发送消息</button>
<br/>
<br/>
<textarea id="text_content2" readonly="readonly" cols="100" rows="9"></textarea>返回内容
<br/>
<br/>


<script type="text/javascript">
    $(document).ready(function(){
        var ws = null;
        // 连接
        $('#btn_join').click(function() {
        	var url = $("#url").val();
            ws = new WebSocket(url);
            ws.onopen = function(event) {
                $('#text_content').append('已经打开连接!' + '\n');
            }
            ws.onmessage = function(event) {
                $('#text_content').append(event.data + '\n');
                console.log(event.data);
            }
            ws.onclose = function(event) {
                $('#text_content').append('已经关闭连接!' + '\n');
            }
        });
        //心跳检查


        // 发送消息
        $('#btn_send').click(function() {
            var message = $('#message').val();
            //if (ws) {
            //    ws.send(message);
            //} else {
            //    alert("未连接到服务器");
            //}
            $.post('http://127.0.0.1:9000/webSocket/send',
                {
                    'userId':$('#userId').val(),
                    'message':message
                });
        });
        //断开
        $('#btn_exit').click(function() {
            if (ws) {
                ws.close();
                ws = null;
            }
        });

        //第二个用户操作
        send();
    })

    function send(){
        var ws = null;
        // 连接
        $('#btn_join2').click(function() {
            var url = $("#url2").val();
            ws = new WebSocket(url);
            ws.onopen = function(event) {
                $('#text_content2').append('已经打开连接!' + '\n');
            }
            ws.onmessage = function(event) {
                $('#text_content2').append(event.data + '\n');
                console.log(event.data);
            }
            ws.onclose = function(event) {
                $('#text_content2').append('已经关闭连接!' + '\n');
            }
        });
        // 发送消息
        $('#btn_send2').click(function() {
            var message = $('#message2').val();
            //if (ws) {
            //    ws.send(message);
            //} else {
            //    alert("未连接到服务器");
            //}
            $.post('http://127.0.0.1:9000/webSocket/send',
                {
                    'userId':$('#userId2').val(),
                    'message':message
                });
        });
        //断开
        $('#btn_exit2').click(function() {
            if (ws) {
                ws.close();
                ws = null;
            }
        });
    }
</script>
</body>
</html>

后端代码

package com.xxx.websocket;

import java.io.IOException;
import java.util.Collection;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import javax.websocket.Session;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * websocket 客户端用户集
 * 
 * @author 
 */
public class WebSocketUsers
{
    /**
     * WebSocketUsers 日志控制器
     */
    private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketUsers.class);

    /**
     * 用户集
     */
    private static Map<String, Session> USERS = new ConcurrentHashMap<String, Session>();

    /**
     * 存储用户
     *
     * @param key 唯一键
     * @param session 用户信息
     */
    public static void put(String key, Session session)
    {
        USERS.put(key, session);
    }

    /**
     * 移除用户
     *
     * @param session 用户信息
     *
     * @return 移除结果
     */
    public static boolean remove(Session session)
    {
        String key = null;
        boolean flag = USERS.containsValue(session);
        if (flag)
        {
            Set<Map.Entry<String, Session>> entries = USERS.entrySet();
            for (Map.Entry<String, Session> entry : entries)
            {
                Session value = entry.getValue();
                if (value.equals(session))
                {
                    key = entry.getKey();
                    break;
                }
            }
        }
        else
        {
            return true;
        }
        return remove(key);
    }

    /**
     * 移出用户
     *
     * @param key 键
     */
    public static boolean remove(String key)
    {
        LOGGER.info("\n 正在移出用户 - {}", key);
        Session remove = USERS.remove(key);
        if (remove != null)
        {
            boolean containsValue = USERS.containsValue(remove);
            LOGGER.info("\n 移出结果 - {}", containsValue ? "失败" : "成功");
            return containsValue;
        }
        else
        {
            return true;
        }
    }

    /**
     * 获取在线用户列表
     *
     * @return 返回用户集合
     */
    public static Map<String, Session> getUsers()
    {
        return USERS;
    }

    /**
     * 群发消息文本消息
     *
     * @param message 消息内容
     */
    public static void sendMessageToUsersByText(String message)
    {
        Collection<Session> values = USERS.values();
        for (Session value : values)
        {
            sendMessageToUserByText(value, message);
        }
    }

    /**
     * 发送文本消息
     *
     * @param session 自己的用户名
     * @param message 消息内容
     */
    public static void sendMessageToUserByText(Session session, String message)
    {
        if (session != null)
        {
            try
            {
                session.getBasicRemote().sendText(message);
            }
            catch (IOException e)
            {
                LOGGER.error("\n[发送消息异常]", e);
            }
        }
        else
        {
            LOGGER.info("\n[你已离线]");
        }
    }
}
package com.xxx.websocket;

import java.util.Objects;
import java.util.concurrent.Semaphore;
import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

/**
 * websocket 消息处理
 * 
 * @author 
 */
@Component
@ServerEndpoint("/websocket/message/{userId}")
public class WebSocketServer
{
    /**
     * WebSocketServer 日志控制器
     */
    private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketServer.class);

    /**
     * 默认最多允许同时在线人数100
     */
    public static int socketMaxOnlineCount = 100;

    private static Semaphore socketSemaphore = new Semaphore(socketMaxOnlineCount);

    /**
     * redis 推送消息
     */
    @Autowired
    private RedisPubSub redisPubSub;

    /**
     * 连接建立成功调用的方法
     */
    @OnOpen
    public void onOpen(Session session,@PathParam(value = "userId") String userId) throws Exception
    {
        boolean semaphoreFlag = false;
        // 尝试获取信号量
        semaphoreFlag = SemaphoreUtils.tryAcquire(socketSemaphore);
        if (!semaphoreFlag)
        {
            // 未获取到信号量
            LOGGER.error("\n 当前在线人数超过限制数- {}", socketMaxOnlineCount);
            WebSocketUsers.sendMessageToUserByText(session, "当前在线人数超过限制数:" + socketMaxOnlineCount);
            session.close();
        }
        else
        {
            if(! Objects.equals(session,WebSocketUsers.getUsers().get(userId))){
                // 添加用户
                WebSocketUsers.put(userId, session);
                LOGGER.info("\n 建立连接 - {}", session);
                LOGGER.info("\n 当前人数 - {}", WebSocketUsers.getUsers().size());
                WebSocketUsers.sendMessageToUserByText(session, "连接成功");
            }
        }
    }

    /**
     * 连接关闭时处理
     */
    @OnClose
    public void onClose(Session session,@PathParam(value = "userId") String userId)
    {
        LOGGER.info("\n 关闭连接 - {}", session);
        // 移除用户
        WebSocketUsers.remove(userId);
        // 获取到信号量则需释放
        SemaphoreUtils.release(socketSemaphore);
    }

    /**
     * 抛出异常时处理
     */
    @OnError
    public void onError(Session session, Throwable exception) throws Exception
    {
        if (session.isOpen())
        {
            // 关闭连接
            session.close();
        }
        // todo sessionId 与 userId关联 防止连接失败没有将用户关闭
        String sessionId = session.getId();
        LOGGER.info("\n 连接异常 - {}", sessionId);
        LOGGER.info("\n 异常信息 - {}", exception);
        // 移出用户
        WebSocketUsers.remove(sessionId);
        // 获取到信号量则需释放
        SemaphoreUtils.release(socketSemaphore);
    }

    /**
     * 服务器接收到客户端消息时调用的方法
     */
    @OnMessage
    public void onMessage(String message, Session session)
    {
        LOGGER.info("message == >>> {}", message);
        //测试的时候使用,仅供参考
        WebSocketUsers.sendMessageToUserByText(session, message);
    }
}
package com.xxx.websocket;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;

/**
 * websocket 配置
 * 
 * @author 
 */
@Configuration
public class WebSocketConfig
{
    @Bean
    public ServerEndpointExporter serverEndpointExporter()
    {
        return new ServerEndpointExporter();
    }
}
package com.xxx.websocket;

import lombok.Data;

/**
 * @Description TODO 发送消息实体
 * @Author tom
 * @Date 2021-08-18 15:26
 **/
@Data
public class UserRedis {

    private String userId;

    private String message;

}
package com.xxx.websocket;

import java.util.concurrent.Semaphore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * 信号量相关处理
 * 
 * @author 
 */
public class SemaphoreUtils
{

    /**
     * SemaphoreUtils 日志控制器
     */
    private static final Logger LOGGER = LoggerFactory.getLogger(SemaphoreUtils.class);

    /**
     * 获取信号量
     * 
     * @param semaphore
     * @return
     */
    public static boolean tryAcquire(Semaphore semaphore)
    {
        boolean flag = false;

        try
        {
            flag = semaphore.tryAcquire();
        }
        catch (Exception e)
        {
            LOGGER.error("获取信号量异常", e);
        }

        return flag;
    }

    /**
     * 释放信号量
     * 
     * @param semaphore
     */
    public static void release(Semaphore semaphore)
    {

        try
        {
            semaphore.release();
        }
        catch (Exception e)
        {
            LOGGER.error("释放信号量异常", e);
        }
    }
}
package com.xxx.websocket;

import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.listener.ChannelTopic;
import org.springframework.stereotype.Service;

/**
 * @Description TODO 推送消息
 * @Author tom
 * @Date 2021-08-18 15:00
 **/
@Slf4j
@Service
public class RedisPubSub {

    private static final Logger logger = LoggerFactory.getLogger(RedisPubSub.class);

    @Autowired
    private RedisTemplate redisTemplate;

    private ChannelTopic topic = new ChannelTopic("/redis/pubsub");

    /**
     * 推送消息
     *
     * @param userId 用户id
     */
    public void publish(String userId,String message) {
        log.info("向redis中推送消息");
        UserRedis userRedis = new UserRedis();
        userRedis.setUserId(userId);
        userRedis.setMessage(message);
        redisTemplate.convertAndSend(topic.getTopic(), userRedis);
    }

}
package com.xxx.websocket;

import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.listener.PatternTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.data.redis.listener.adapter.MessageListenerAdapter;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;

/**
 * @Description TODO redis 监听配置类
 * @Author tom
 * @Date 2021-08-18 15:12
 **/
@Configuration
public class RedisListenerConfig {


    @Bean
    public MessageListenerAdapter listener(MessageSubscriber messageSubscriber) {
        Jackson2JsonRedisSerializer jackson2JsonRedisSerializer = new
                Jackson2JsonRedisSerializer(Object.class);
        ObjectMapper om = new ObjectMapper();
        om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        om.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);
        jackson2JsonRedisSerializer.setObjectMapper(om);
        MessageListenerAdapter adapter = new MessageListenerAdapter(messageSubscriber, "onMessage");
        adapter.setSerializer(jackson2JsonRedisSerializer);
        adapter.afterPropertiesSet();
        return adapter;
    }

    /**
     * 将订阅器绑定到容器
     *
     * @param factory
     * @param messageListenerAdapter
     * @return
     */
    @Bean
    public RedisMessageListenerContainer container(RedisConnectionFactory factory,
                                                   MessageListenerAdapter messageListenerAdapter) {

        RedisMessageListenerContainer container = new RedisMessageListenerContainer();
        container.setConnectionFactory(factory);
        container.addMessageListener(messageListenerAdapter, new PatternTopic("/redis/pubsub"));
        return container;
    }

}
package com.xxx.websocket;

import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import javax.websocket.Session;
import java.util.Objects;

/**
 * @Description TODO redis 订阅消息
 * @Author tom
 * @Date 2021-08-18 15:04
 **/
@Slf4j
@Component
public class MessageSubscriber {

    public void onMessage(UserRedis userRedis) {
        log.info("有消息来啦~~");
        //判断当前用户是否与该节点建立websocket连接
        Session session = WebSocketUsers.getUsers().get(userRedis.getUserId());
        if (! Objects.isNull(session)) {
            log.info("命中当前节点~~");
            WebSocketUsers.sendMessageToUserByText(session, userRedis.getMessage());
        }
    }
}
package com.xxx.lms.controller;

import com.xxx.common.api.R;
import com.xxx.websocket.RedisPubSub;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

/**
 * @Description TODO
 * @Author tom
 * @Date 2021-08-18 15:16
 **/
@RestController
@RequestMapping("/webSocket")
public class WebSocketController {

    @Autowired
    private RedisPubSub redisPubSub;

    @PostMapping("/send")
    public R send(String userId,String message){
        redisPubSub.publish(userId,message);
        return R.success("发送成功");
    }

}

以上就是核心代码 拿过去就能只能用

Q&A

1、网络波动时,如何保证一直连接。
心跳机制,需要在前端进行处理,实例代码中没有心跳机制,这个小伙伴可以自行实现。
2、如何保证连接信息不会被多个节点所持有。
在关闭网页时需要通知服务断开连接这样就不会导致多个节点持有同一个连接信息。
3、连接数过大,如何解决。
tom老弟的思路是通过连接复用的方式解决,这个以后再深究吧。。。
4、通过redis来实现消息推送,是不是不够优雅或效率上是否有缺陷
是的,我认为也是不够优雅。每次发送消息需要通知多个节点来完成,在效率上是会有缺陷,但是,考虑到session无法被序列化的问题,tom老弟暂时没有想到更好的解决方案。(其实有个思路,根据咱运维老哥-飞哥科普ip_hash可以做到用户每次连接都能指向同一个服务来提高效率问题,但是没有落地...)

最后

每天多学一点知识。提升自己,影响他人~~