若依cloud集成websocket

799 阅读8分钟

基于ruoyi-cloud版本集成websocket

​ 考虑到若依框架并未内置WebSocket功能,而项目又需要实时向用户发送通知,我决定深入探索并实现WebSocket的集成。这一过程不仅能够满足项目的实际需求,还能为开发者提供一个详细的参考指南。以下是我将记录并分享的WebSocket集成过程,希望能为同样需求的开发者提供帮助。

一、后端部分

​ 最初,由于考虑不周,我将WebSocket功能直接放置在system服务下。然而,在后续开发中发现,其他模块也需要使用WebSocket功能。因此,我决定】将WebSocket功能独立出来,并建立一个API供其他服务调用。开发者可以根据自己的项目需求,选择将WebSocket功能放置在common包下,以实现更好的模块化和复用性。

我是直接放在system服务下,代码如下:大家ctrl + c 、 ctrl + v 即可!

WebSocketConfig

/**
 * 首先注入一个ServerEndpointExporterBean,该Bean会自动注册使用@ServerEndpoint注解申明的websocket endpoint
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    private static final Logger log = LoggerFactory.getLogger(WebSocketConfig.class);

    @Autowired
    private MyWebSocketHandler myWebSocketHandler;

    @Autowired
    private WebSocketInterceptor webSocketInterceptor;

    @Value("#{'${websocket.wsHandlers}'.split(',')}")
    private String[] paths;

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        registry.addHandler(myWebSocketHandler, paths)
                .setAllowedOrigins("*")
                .addInterceptors(webSocketInterceptor);
    }
}

WebSocketServer

@Component
@ServerEndpoint("/websocket/message")
public class WebSocketServer {
    /*========================声明类变量,意在所有实例共享=================================================*/
    /**
     * WebSocketServer 日志控制器
     */
    private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketServer.class);

    /**
     * 默认最多允许同时在线人数100
     */
    public static int socketMaxOnlineCount = 100;

    private static Semaphore socketSemaphore = new Semaphore(socketMaxOnlineCount);

    HashedWheelTimer timer = new HashedWheelTimer(1, TimeUnit.SECONDS, 8);
    /**
     * concurrent包的线程安全Set,用来存放每个客户端对应的MyWebSocket对象。
     */
    private static final CopyOnWriteArraySet<WebSocketServer> webSocketSet = new CopyOnWriteArraySet<>();
    /**
     * 连接数
     */
    private static final AtomicInteger count = new AtomicInteger();

    /*========================声明实例变量,意在每个实例独享=======================================================*/
    /**
     * 与某个客户端的连接会话,需要通过它来给客户端发送数据
     */
    private Session session;
    /**
     * 用户id
     */
    private String sid = "";

    /**
     * 连接建立成功调用的方法
     */
    @OnOpen
    public void onOpen(Session session) throws Exception {
        // 尝试获取信号量
        boolean semaphoreFlag = SemaphoreUtils.tryAcquire(socketSemaphore);
        if (!semaphoreFlag) {
            // 未获取到信号量
            LOGGER.error("\n 当前在线人数超过限制数- {}", socketMaxOnlineCount);
            // 给当前Session 登录用户发送消息
            sendMessageToUserByText(session, "当前在线人数超过限制数:" + socketMaxOnlineCount);
            session.close();
        } else {
            // 返回此会话的经过身份验证的用户,如果此会话没有经过身份验证的用户,则返回null
//            Authentication authentication = (Authentication) session.getUserPrincipal();
//            SecurityUtils.setAuthentication(authentication);
//            String username = SecurityUtils.getUsername();
//            this.session = session;
//            //如果存在就先删除一个,防止重复推送消息
//            for (WebSocketServer webSocket : webSocketSet) {
//                if (webSocket.sid.equals(username)) {
//                    webSocketSet.remove(webSocket);
//                    count.getAndDecrement();
//                }
//            }
//            count.getAndIncrement();
//            webSocketSet.add(this);
//            this.sid = username;
//            LOGGER.info("\n 当前人数 - {}", count);
//            sendMessageToUserByText(session, "连接成功");
        }
    }

    /**
     * 连接关闭时处理
     */
    @OnClose
    public void onClose(Session session) {
        LOGGER.info("\n 关闭连接 - {}", session);
        // 移除用户
        webSocketSet.remove(session);
        // 获取到信号量则需释放
        SemaphoreUtils.release(socketSemaphore);
    }

    /**
     * 抛出异常时处理
     */
    @OnError
    public void onError(Session session, Throwable exception) throws Exception {
        if (session.isOpen()) {
            // 关闭连接
            session.close();
        }
        String sessionId = session.getId();
        LOGGER.info("\n 连接异常 - {}", sessionId);
        LOGGER.info("\n 异常信息 - {}", exception);
        // 移出用户
        webSocketSet.remove(session);
        // 获取到信号量则需释放
        SemaphoreUtils.release(socketSemaphore);
    }

    /**
     * 服务器接收到客户端消息时调用的方法
     */
    @OnMessage
    public void onMessage(String message, Session session) {
//        Authentication authentication = (Authentication) session.getUserPrincipal();
//        LOGGER.info("收到来自" + sid + "的信息:" + message);
//        // 实时更新
//        this.refresh(sid, authentication);
//        sendMessageToUserByText(session, "我收到了你的新消息哦");
    }

//    /**
//     * 刷新定时任务,发送信息
//     */
//    private void refresh(String userId, Authentication authentication) {
//        this.start(5000L, task -> {
//            // 判断用户是否在线,不在线则不用处理,因为在内部无法关闭该定时任务,所以通过返回值在外部进行判断。
//            if (WebSocketServer.isConn(userId)) {
//                // 因为这里是长链接,不会和普通网页一样,每次发送http 请求可以走拦截器【doFilterInternal】续约,所以需要手动续约
//                SecurityUtils.setAuthentication(authentication);
//                // 从数据库或者缓存中获取信息,构建自定义的Bean
//                DeviceInfo deviceInfo = DeviceInfo.builder().Macaddress("de5a735951ee").Imei("351517175516665")
//                        .Battery("99").Charge("0").Latitude("116.402649").Latitude("39.914859").Altitude("80")
//                        .Method(SecurityUtils.getUsername()).build();
//                // TODO判断数据是否有更新
//                // 发送最新数据给前端
//                WebSocketServer.sendInfo("JSON", deviceInfo, userId);
//                // 设置返回值,判断是否需要继续执行
//                return true;
//            }
//            return false;
//        });
//    }

//    private void start(long delay, Function<Timeout, Boolean> function) {
//        timer.newTimeout(t -> {
//            // 获取返回值,判断是否执行
//            Boolean result = function.apply(t);
//            if (result) {
//                timer.newTimeout(t.task(), delay, TimeUnit.MILLISECONDS);
//            }
//        }, delay, TimeUnit.MILLISECONDS);
//    }

    /**
     * 判断是否有链接
     *
     * @return
     */
    public static boolean isConn(String sid) {
        for (WebSocketServer item : webSocketSet) {
            if (item.sid.equals(sid)) {
                return true;
            }
        }
        return false;
    }

    /**
     * 群发自定义消息
     * 或者指定用户发送消息
     */
    public static void sendInfo(String type, Object data, @PathParam("sid") String sid) {
        // 遍历WebSocketServer对象集合,如果符合条件就推送
        for (WebSocketServer item : webSocketSet) {
            try {
                //这里可以设定只推送给这个sid的,为null则全部推送
                if (sid == null) {
                    item.sendMessage(type, data);
                } else if (item.sid.equals(sid)) {
                    item.sendMessage(type, data);
                }
            } catch (IOException ignored) {
            }
        }
    }

    /**
     * 实现服务器主动推送
     */
    private void sendMessage(String type, Object data) throws IOException {
        Map<String, Object> result = new HashMap<>();
        result.put("type", type);
        result.put("data", data);
        this.session.getAsyncRemote().sendText(JSON.toJSONString(result));
    }

    /**
     * 实现服务器主动推送-根据session
     */
    public static void sendMessageToUserByText(Session session, String message) {
        if (session != null) {
            try {
                session.getBasicRemote().sendText(message);
            } catch (IOException e) {
                LOGGER.error("\n[发送消息异常]", e);
            }
        } else {
            LOGGER.info("\n[你已离线]");
        }
    }
}

WebSocketInterceptor

@Component
public class WebSocketInterceptor implements HandshakeInterceptor {

    private static final Logger log = LoggerFactory.getLogger(WebSocketInterceptor.class);

    //在握手之前执行该方法, 继续握手返回true, 中断握手返回false. 通过attributes参数设置WebSocketSession的属性
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes){
        if (request instanceof ServletServerHttpRequest) {
            String uri = request.getURI().getPath();
            String token = uri.substring(uri.lastIndexOf("/")+1);
            attributes.put(WebSocketConstant.TOKEN,token);
            log.info("current token is:"+token);
        }
        return true;
    }
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
        log.info("coming webSocketInterceptor afterHandshake method...");
    }
}

MyWebSocketHandler

@Component
public class MyWebSocketHandler implements WebSocketHandler {

    @Autowired
    private RedisTemplate redisTemplate;

    /**
     * 日志
     */
    private static final Logger log = LoggerFactory.getLogger(MyWebSocketHandler.class);

    /**
     * 连接建立后调用
     *
     * @param session
     */
    @Override
    public void afterConnectionEstablished(@NonNull WebSocketSession session) {
        Long userId = WebSocketCommon.getUserId(session);
        if(ObjectUtil.isNull(userId)){
            return;
        }
        log.info("连接websocket成功==={}", userId);
        WebSocketSession oldSession = WebSocketCommon.CLIENT_SESSION.get(WebSocketCommon.getUserId(session));
        if (ObjectUtil.isNotNull(oldSession)) {
            log.info("关闭原始会话开始==={}", userId);
            try {
                oldSession.close();
            } catch (IOException e) {
                log.info("关闭原始会话失败==={}", userId);
            }
        }
        //新的会话放入
        WebSocketCommon.CLIENT_SESSION.put(userId, session);
        //放入redis

    }

    /**
     * 接收客户端发送的消息-用作客户端心跳
     *
     * @param session
     * @param message
     */
    @Override
    public void handleMessage(@NonNull WebSocketSession session, @NonNull WebSocketMessage<?> message) {
        log.info("处理消息启动");
        try {
            //获取客户端发送的消息
            String mess = (String) message.getPayload();
            Long userId = WebSocketCommon.getUserId(session);
            log.info("发送人:{},消息内容:{}", userId, mess);
            //这边可能需要处理更新map里session机制,防止map里面保存的失效,待定,等后面实际运行观察
            if (session.isOpen()) {
                //心跳响应包
                session.sendMessage(new PongMessage());
            }
        } catch (Exception e) {
            log.error("e", e);
        }
    }

    /**
     * 处理错误
     *
     * @param session
     * @param exception
     * @throws Exception
     */
    @Override
    public void handleTransportError(@NonNull WebSocketSession session, @NonNull Throwable exception) throws Exception {
        log.info("处理消息启动");
        if (session.isOpen()) {
            session.close();
        }
        log.error("连接错误", exception);
        Object userid = session.getAttributes().get(WebSocketConstant.TOKEN).toString();
        if (userid == null) {
            return;
        }
        Long userId = Objects.requireNonNull(WebSocketCommon.getUserId(session));
        WebSocketCommon.CLIENT_SESSION.remove(userId);
    }

    /**
     * 连接关闭后调用
     *
     * @param session
     * @param closeStatus
     */
    @Override
    public void afterConnectionClosed(@NonNull WebSocketSession session, @NonNull CloseStatus closeStatus) {
        log.error("在线人数: {}", WebSocketCommon.CLIENT_SESSION.size());
        log.error("连接已关闭: " + closeStatus);
        WebSocketCommon.CLIENT_SESSION.remove(Objects.requireNonNull(WebSocketCommon.getUserId(session)));
        log.error("在线人数: {}", WebSocketCommon.CLIENT_SESSION.size());
    }

    /**
     * 连接建立后调用
     *
     * @return
     */
    @Override
    public boolean supportsPartialMessages() {
        return false;
    }
}

WebSocketCommon

public class WebSocketCommon {

    /**
     * 保存已登录的会话session
     */
    public static ConcurrentHashMap<Long, WebSocketSession> CLIENT_SESSION = new ConcurrentHashMap<Long, WebSocketSession>();

    /**
     * 获取用户id
     *
     * @param session
     * @return
     */
    public static Long getUserId(WebSocketSession session) {
        Object userIdObj = session.getAttributes().get(WebSocketConstant.TOKEN);
        if (ObjectUtil.isNull(userIdObj)) {
            return null;
        }
        String userIdStr = userIdObj.toString();
        if (StrUtil.isBlank(userIdStr) || "undefined".equals(userIdStr)) {
            return null;
        }
        return Long.parseLong(userIdStr);
    }

    /**
     * 获取指定用户的WebSocketSession
     *
     * @param userId 用户ID
     * @return WebSocketSession 或 null(如果未找到)
     */
    public static WebSocketSession getSessionByUserId(Long userId) {
        return CLIENT_SESSION.get(userId);
    }
}

WebSocketConstant

public class WebSocketConstant {
    /**
     * 用户标识
     */
    public static String CLIENT_FLAG = "clientId";

    /**
     * 用户标识key
     */
    public static String TOKEN = "token";

    /**
     * 每个连接key前缀标识
     */
    public static String PREFIX = "prefix";
}

WebSocketUsers

public class WebSocketUsers {
    /**
     * WebSocketUsers 日志控制器
     */
    private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketUsers.class);

    /**
     * 用户集
     */
    private static Map<String, Session> USERS = new ConcurrentHashMap<String, Session>();

    /**
     * 存储用户
     *
     * @param key     唯一键
     * @param session 用户信息
     */
    public static void put(String key, Session session) {
        USERS.put(key, session);
    }

    /**
     * 移除用户
     *
     * @param session 用户信息
     * @return 移除结果
     */
    public static boolean remove(Session session) {
        String key = null;
        boolean flag = USERS.containsValue(session);
        if (flag) {
            Set<Map.Entry<String, Session>> entries = USERS.entrySet();
            for (Map.Entry<String, Session> entry : entries) {
                Session value = entry.getValue();
                if (value.equals(session)) {
                    key = entry.getKey();
                    break;
                }
            }
        } else {
            return true;
        }
        return remove(key);
    }

    /**
     * 移出用户
     *
     * @param key 键
     */
    public static boolean remove(String key) {
        LOGGER.info("\n 正在移出用户 - {}", key);
        Session remove = USERS.remove(key);
        if (remove != null) {
            boolean containsValue = USERS.containsValue(remove);
            LOGGER.info("\n 移出结果 - {}", containsValue ? "失败" : "成功");
            return containsValue;
        } else {
            return true;
        }
    }

    /**
     * 获取在线用户列表
     *
     * @return 返回用户集合
     */
    public static Map<String, Session> getUsers() {
        return USERS;
    }

    /**
     * 群发消息文本消息
     *
     * @param message 消息内容
     */
    public static void sendMessageToUsersByText(String message) {
        Collection<Session> values = USERS.values();
        for (Session value : values) {
            sendMessageToUserByText(value, message);
        }
    }

    /**
     * 发送文本消息
     *
     * @param session 缓存
     * @param message 消息内容
     */
    public static void sendMessageToUserByText(Session session, String message) {
        if (session != null) {
            try {
                session.getBasicRemote().sendText(message);
            } catch (IOException e) {
                LOGGER.error("\n[发送消息异常]", e);
            }
        } else {
            LOGGER.info("\n[你已离线]");
        }
    }
}

SemaphoreUtils

public class SemaphoreUtils{
    /**
     * SemaphoreUtils 日志控制器
     */
    private static final Logger LOGGER = LoggerFactory.getLogger(SemaphoreUtils.class);

    /**
     * 获取信号量
     *
     * @param semaphore
     * @return
     */
    public static boolean tryAcquire(Semaphore semaphore)
    {
        boolean flag = false;
        try
        {
            flag = semaphore.tryAcquire();
        }
        catch (Exception e)
        {
            LOGGER.error("获取信号量异常", e);
        }
        return flag;
    }

    /**
     * 释放信号量
     *
     * @param semaphore
     */
    public static void release(Semaphore semaphore)
    {
        try
        {
            semaphore.release();
        }
        catch (Exception e)
        {
            LOGGER.error("释放信号量异常", e);
        }
    }
}

yaml配置里新增

# websocket链接配置
websocket:
  wsHandlers: /ws/{token}

修改一下AuthFilter 过滤

// AuthFilter filter 代码
String token = getToken(request);
if (StringUtils.isEmpty(token)) {
   return unauthorizedResponse(exchange, "令牌不能为空");
}
Claims claims = JwtUtils.parseToken(token);
if (claims == null) {
   return unauthorizedResponse(exchange, "令牌已过期或验证不正确!");
}
String userkey = JwtUtils.getUserKey(claims);
boolean islogin = redisService.hasKey(getTokenKey(userkey));
if (!islogin) {
   return unauthorizedResponse(exchange, "登录状态已过期");
}
String userid = JwtUtils.getUserId(claims);
String username = JwtUtils.getUserName(claims);
if (StringUtils.isEmpty(userid) || StringUtils.isEmpty(username)) {
   return unauthorizedResponse(exchange, "令牌验证失败");
}
    /**
     * 获取缓存key
     */
    private String getTokenKey(String token) {
        return CacheConstants.LOGIN_TOKEN_KEY + token;
    }

    /**
     * 获取请求token
     */
    private String getToken(ServerHttpRequest request) {
        String token = request.getHeaders().getFirst(TokenConstants.AUTHENTICATION);
        if (StringUtils.isEmpty(token)) {
            //尝试从拼接参数中获取token,这步是为了websocket鉴权
            List<?> authorization = request.getQueryParams().get(TokenConstants.AUTHENTICATION);
            if (authorization != null) {
                token = authorization.get(0).toString();
            }
        }
        // 如果前端设置了令牌前缀,则裁剪掉前缀
        if (token != null && StringUtils.isNotEmpty(token) && token.startsWith(TokenConstants.PREFIX)) {
            token = token.replaceFirst(TokenConstants.PREFIX, StringUtils.EMPTY);
        }
        return token;
    }

然后我们测试连接,连接成功!

![1729762959257](C:\Users\LENOVO\Documents\WeChat Files\wxid_02gnk7iazrm522\FileStorage\Temp\1729762959257.png)

二、前端部分

我们封装一个工具类即可.

import {getToken} from "@/utils/auth";

let socketUrl = ''; // socket地址
let websocket = null; // websocket 实例
let heartTime = null; // 心跳定时器实例
let socketHeart = 0; // 心跳次数
const HeartTimeOut = 10000; // 心跳超时时间 10000 = 10s
let socketError = 0; // 错误次数

// 初始化socket
export const initWebSocket = (url, userId) => {
  socketUrl = url;
  // 初始化 websocket
  websocket = new WebSocket(url + userId + '?Authorization=Bearer ' + getToken());
  websocketonopen();
  websocketonmessage();
  websocketonerror();
  websocketclose();
  sendSocketHeart();
  return websocket;
};

// socket 连接成功
export const websocketonopen = () => {
  websocket.onopen = function () {
    console.log('%cHaoXin: Websocket连接成功', 'color: green; font-weight: bold; font-size: 16px;');
    resetHeart();
  };
};

// socket 连接失败
export const websocketonerror = () => {
  websocket.onerror = function (e) {
    console.log('%cHaoXin: Websocket连接失败', 'color: red; font-weight: bold; font-size: 16px;');
  };
};

// socket 断开链接
export const websocketclose = () => {
  websocket.onclose = function (e) {
    console.log('%cHaoXin: Websocket断开连接', 'color: red; font-weight: bold; font-size: 16px;');
  };
};

// socket 重置心跳
export const resetHeart = () => {
  socketHeart = 0;
  socketError = 0;
  clearInterval(heartTime);
  sendSocketHeart();
};

// socket心跳发送
export const sendSocketHeart = () => {
  heartTime = setInterval(() => {
    // 如果连接正常则发送心跳
    if (websocket.readyState == 1) {
      websocket.send(
        JSON.stringify({
          type: 'ping'
        })
      );
      socketHeart = socketHeart + 1;
    } else {
      // 重连
      reconnect();
    }
  }, HeartTimeOut);
};

// socket重连
export const reconnect = () => {
  if (socketError <= 2) {
    clearInterval(heartTime);
    initWebSocket(socketUrl);
    socketError = socketError + 1;
    console.log('%cHaoXin: socket重连', 'color: green; font-weight: bold; font-size: 16px;');
  } else {
    console.log('%cHaoXin: 重试次数已用完', 'color: red; font-weight: bold; font-size: 16px;');
    clearInterval(heartTime);
  }
};

// socket 发送数据
export const sendMsg = (data) => {
  websocket.send(data);
};

// socket 接收数据
export const websocketonmessage = () => {
  websocket.onmessage = function (e) {
    if (e.data.indexOf('heartbeat') > 0) {
      resetHeart();
    }
    if (e.data.indexOf('ping') > 0) {
      return;
    }
    console.log(e.data);
    return e.data;
  };
};