基于ruoyi-cloud版本集成websocket
考虑到若依框架并未内置WebSocket功能,而项目又需要实时向用户发送通知,我决定深入探索并实现WebSocket的集成。这一过程不仅能够满足项目的实际需求,还能为开发者提供一个详细的参考指南。以下是我将记录并分享的WebSocket集成过程,希望能为同样需求的开发者提供帮助。
一、后端部分
最初,由于考虑不周,我将WebSocket功能直接放置在system服务下。然而,在后续开发中发现,其他模块也需要使用WebSocket功能。因此,我决定】将WebSocket功能独立出来,并建立一个API供其他服务调用。开发者可以根据自己的项目需求,选择将WebSocket功能放置在common包下,以实现更好的模块化和复用性。
我是直接放在system服务下,代码如下:大家ctrl + c 、 ctrl + v 即可!
WebSocketConfig
/**
* 首先注入一个ServerEndpointExporterBean,该Bean会自动注册使用@ServerEndpoint注解申明的websocket endpoint
*/
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {
private static final Logger log = LoggerFactory.getLogger(WebSocketConfig.class);
@Autowired
private MyWebSocketHandler myWebSocketHandler;
@Autowired
private WebSocketInterceptor webSocketInterceptor;
@Value("#{'${websocket.wsHandlers}'.split(',')}")
private String[] paths;
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry.addHandler(myWebSocketHandler, paths)
.setAllowedOrigins("*")
.addInterceptors(webSocketInterceptor);
}
}
WebSocketServer
@Component
@ServerEndpoint("/websocket/message")
public class WebSocketServer {
/*========================声明类变量,意在所有实例共享=================================================*/
/**
* WebSocketServer 日志控制器
*/
private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketServer.class);
/**
* 默认最多允许同时在线人数100
*/
public static int socketMaxOnlineCount = 100;
private static Semaphore socketSemaphore = new Semaphore(socketMaxOnlineCount);
HashedWheelTimer timer = new HashedWheelTimer(1, TimeUnit.SECONDS, 8);
/**
* concurrent包的线程安全Set,用来存放每个客户端对应的MyWebSocket对象。
*/
private static final CopyOnWriteArraySet<WebSocketServer> webSocketSet = new CopyOnWriteArraySet<>();
/**
* 连接数
*/
private static final AtomicInteger count = new AtomicInteger();
/*========================声明实例变量,意在每个实例独享=======================================================*/
/**
* 与某个客户端的连接会话,需要通过它来给客户端发送数据
*/
private Session session;
/**
* 用户id
*/
private String sid = "";
/**
* 连接建立成功调用的方法
*/
@OnOpen
public void onOpen(Session session) throws Exception {
// 尝试获取信号量
boolean semaphoreFlag = SemaphoreUtils.tryAcquire(socketSemaphore);
if (!semaphoreFlag) {
// 未获取到信号量
LOGGER.error("\n 当前在线人数超过限制数- {}", socketMaxOnlineCount);
// 给当前Session 登录用户发送消息
sendMessageToUserByText(session, "当前在线人数超过限制数:" + socketMaxOnlineCount);
session.close();
} else {
// 返回此会话的经过身份验证的用户,如果此会话没有经过身份验证的用户,则返回null
// Authentication authentication = (Authentication) session.getUserPrincipal();
// SecurityUtils.setAuthentication(authentication);
// String username = SecurityUtils.getUsername();
// this.session = session;
// //如果存在就先删除一个,防止重复推送消息
// for (WebSocketServer webSocket : webSocketSet) {
// if (webSocket.sid.equals(username)) {
// webSocketSet.remove(webSocket);
// count.getAndDecrement();
// }
// }
// count.getAndIncrement();
// webSocketSet.add(this);
// this.sid = username;
// LOGGER.info("\n 当前人数 - {}", count);
// sendMessageToUserByText(session, "连接成功");
}
}
/**
* 连接关闭时处理
*/
@OnClose
public void onClose(Session session) {
LOGGER.info("\n 关闭连接 - {}", session);
// 移除用户
webSocketSet.remove(session);
// 获取到信号量则需释放
SemaphoreUtils.release(socketSemaphore);
}
/**
* 抛出异常时处理
*/
@OnError
public void onError(Session session, Throwable exception) throws Exception {
if (session.isOpen()) {
// 关闭连接
session.close();
}
String sessionId = session.getId();
LOGGER.info("\n 连接异常 - {}", sessionId);
LOGGER.info("\n 异常信息 - {}", exception);
// 移出用户
webSocketSet.remove(session);
// 获取到信号量则需释放
SemaphoreUtils.release(socketSemaphore);
}
/**
* 服务器接收到客户端消息时调用的方法
*/
@OnMessage
public void onMessage(String message, Session session) {
// Authentication authentication = (Authentication) session.getUserPrincipal();
// LOGGER.info("收到来自" + sid + "的信息:" + message);
// // 实时更新
// this.refresh(sid, authentication);
// sendMessageToUserByText(session, "我收到了你的新消息哦");
}
// /**
// * 刷新定时任务,发送信息
// */
// private void refresh(String userId, Authentication authentication) {
// this.start(5000L, task -> {
// // 判断用户是否在线,不在线则不用处理,因为在内部无法关闭该定时任务,所以通过返回值在外部进行判断。
// if (WebSocketServer.isConn(userId)) {
// // 因为这里是长链接,不会和普通网页一样,每次发送http 请求可以走拦截器【doFilterInternal】续约,所以需要手动续约
// SecurityUtils.setAuthentication(authentication);
// // 从数据库或者缓存中获取信息,构建自定义的Bean
// DeviceInfo deviceInfo = DeviceInfo.builder().Macaddress("de5a735951ee").Imei("351517175516665")
// .Battery("99").Charge("0").Latitude("116.402649").Latitude("39.914859").Altitude("80")
// .Method(SecurityUtils.getUsername()).build();
// // TODO判断数据是否有更新
// // 发送最新数据给前端
// WebSocketServer.sendInfo("JSON", deviceInfo, userId);
// // 设置返回值,判断是否需要继续执行
// return true;
// }
// return false;
// });
// }
// private void start(long delay, Function<Timeout, Boolean> function) {
// timer.newTimeout(t -> {
// // 获取返回值,判断是否执行
// Boolean result = function.apply(t);
// if (result) {
// timer.newTimeout(t.task(), delay, TimeUnit.MILLISECONDS);
// }
// }, delay, TimeUnit.MILLISECONDS);
// }
/**
* 判断是否有链接
*
* @return
*/
public static boolean isConn(String sid) {
for (WebSocketServer item : webSocketSet) {
if (item.sid.equals(sid)) {
return true;
}
}
return false;
}
/**
* 群发自定义消息
* 或者指定用户发送消息
*/
public static void sendInfo(String type, Object data, @PathParam("sid") String sid) {
// 遍历WebSocketServer对象集合,如果符合条件就推送
for (WebSocketServer item : webSocketSet) {
try {
//这里可以设定只推送给这个sid的,为null则全部推送
if (sid == null) {
item.sendMessage(type, data);
} else if (item.sid.equals(sid)) {
item.sendMessage(type, data);
}
} catch (IOException ignored) {
}
}
}
/**
* 实现服务器主动推送
*/
private void sendMessage(String type, Object data) throws IOException {
Map<String, Object> result = new HashMap<>();
result.put("type", type);
result.put("data", data);
this.session.getAsyncRemote().sendText(JSON.toJSONString(result));
}
/**
* 实现服务器主动推送-根据session
*/
public static void sendMessageToUserByText(Session session, String message) {
if (session != null) {
try {
session.getBasicRemote().sendText(message);
} catch (IOException e) {
LOGGER.error("\n[发送消息异常]", e);
}
} else {
LOGGER.info("\n[你已离线]");
}
}
}
WebSocketInterceptor
@Component
public class WebSocketInterceptor implements HandshakeInterceptor {
private static final Logger log = LoggerFactory.getLogger(WebSocketInterceptor.class);
//在握手之前执行该方法, 继续握手返回true, 中断握手返回false. 通过attributes参数设置WebSocketSession的属性
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes){
if (request instanceof ServletServerHttpRequest) {
String uri = request.getURI().getPath();
String token = uri.substring(uri.lastIndexOf("/")+1);
attributes.put(WebSocketConstant.TOKEN,token);
log.info("current token is:"+token);
}
return true;
}
@Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
log.info("coming webSocketInterceptor afterHandshake method...");
}
}
MyWebSocketHandler
@Component
public class MyWebSocketHandler implements WebSocketHandler {
@Autowired
private RedisTemplate redisTemplate;
/**
* 日志
*/
private static final Logger log = LoggerFactory.getLogger(MyWebSocketHandler.class);
/**
* 连接建立后调用
*
* @param session
*/
@Override
public void afterConnectionEstablished(@NonNull WebSocketSession session) {
Long userId = WebSocketCommon.getUserId(session);
if(ObjectUtil.isNull(userId)){
return;
}
log.info("连接websocket成功==={}", userId);
WebSocketSession oldSession = WebSocketCommon.CLIENT_SESSION.get(WebSocketCommon.getUserId(session));
if (ObjectUtil.isNotNull(oldSession)) {
log.info("关闭原始会话开始==={}", userId);
try {
oldSession.close();
} catch (IOException e) {
log.info("关闭原始会话失败==={}", userId);
}
}
//新的会话放入
WebSocketCommon.CLIENT_SESSION.put(userId, session);
//放入redis
}
/**
* 接收客户端发送的消息-用作客户端心跳
*
* @param session
* @param message
*/
@Override
public void handleMessage(@NonNull WebSocketSession session, @NonNull WebSocketMessage<?> message) {
log.info("处理消息启动");
try {
//获取客户端发送的消息
String mess = (String) message.getPayload();
Long userId = WebSocketCommon.getUserId(session);
log.info("发送人:{},消息内容:{}", userId, mess);
//这边可能需要处理更新map里session机制,防止map里面保存的失效,待定,等后面实际运行观察
if (session.isOpen()) {
//心跳响应包
session.sendMessage(new PongMessage());
}
} catch (Exception e) {
log.error("e", e);
}
}
/**
* 处理错误
*
* @param session
* @param exception
* @throws Exception
*/
@Override
public void handleTransportError(@NonNull WebSocketSession session, @NonNull Throwable exception) throws Exception {
log.info("处理消息启动");
if (session.isOpen()) {
session.close();
}
log.error("连接错误", exception);
Object userid = session.getAttributes().get(WebSocketConstant.TOKEN).toString();
if (userid == null) {
return;
}
Long userId = Objects.requireNonNull(WebSocketCommon.getUserId(session));
WebSocketCommon.CLIENT_SESSION.remove(userId);
}
/**
* 连接关闭后调用
*
* @param session
* @param closeStatus
*/
@Override
public void afterConnectionClosed(@NonNull WebSocketSession session, @NonNull CloseStatus closeStatus) {
log.error("在线人数: {}", WebSocketCommon.CLIENT_SESSION.size());
log.error("连接已关闭: " + closeStatus);
WebSocketCommon.CLIENT_SESSION.remove(Objects.requireNonNull(WebSocketCommon.getUserId(session)));
log.error("在线人数: {}", WebSocketCommon.CLIENT_SESSION.size());
}
/**
* 连接建立后调用
*
* @return
*/
@Override
public boolean supportsPartialMessages() {
return false;
}
}
WebSocketCommon
public class WebSocketCommon {
/**
* 保存已登录的会话session
*/
public static ConcurrentHashMap<Long, WebSocketSession> CLIENT_SESSION = new ConcurrentHashMap<Long, WebSocketSession>();
/**
* 获取用户id
*
* @param session
* @return
*/
public static Long getUserId(WebSocketSession session) {
Object userIdObj = session.getAttributes().get(WebSocketConstant.TOKEN);
if (ObjectUtil.isNull(userIdObj)) {
return null;
}
String userIdStr = userIdObj.toString();
if (StrUtil.isBlank(userIdStr) || "undefined".equals(userIdStr)) {
return null;
}
return Long.parseLong(userIdStr);
}
/**
* 获取指定用户的WebSocketSession
*
* @param userId 用户ID
* @return WebSocketSession 或 null(如果未找到)
*/
public static WebSocketSession getSessionByUserId(Long userId) {
return CLIENT_SESSION.get(userId);
}
}
WebSocketConstant
public class WebSocketConstant {
/**
* 用户标识
*/
public static String CLIENT_FLAG = "clientId";
/**
* 用户标识key
*/
public static String TOKEN = "token";
/**
* 每个连接key前缀标识
*/
public static String PREFIX = "prefix";
}
WebSocketUsers
public class WebSocketUsers {
/**
* WebSocketUsers 日志控制器
*/
private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketUsers.class);
/**
* 用户集
*/
private static Map<String, Session> USERS = new ConcurrentHashMap<String, Session>();
/**
* 存储用户
*
* @param key 唯一键
* @param session 用户信息
*/
public static void put(String key, Session session) {
USERS.put(key, session);
}
/**
* 移除用户
*
* @param session 用户信息
* @return 移除结果
*/
public static boolean remove(Session session) {
String key = null;
boolean flag = USERS.containsValue(session);
if (flag) {
Set<Map.Entry<String, Session>> entries = USERS.entrySet();
for (Map.Entry<String, Session> entry : entries) {
Session value = entry.getValue();
if (value.equals(session)) {
key = entry.getKey();
break;
}
}
} else {
return true;
}
return remove(key);
}
/**
* 移出用户
*
* @param key 键
*/
public static boolean remove(String key) {
LOGGER.info("\n 正在移出用户 - {}", key);
Session remove = USERS.remove(key);
if (remove != null) {
boolean containsValue = USERS.containsValue(remove);
LOGGER.info("\n 移出结果 - {}", containsValue ? "失败" : "成功");
return containsValue;
} else {
return true;
}
}
/**
* 获取在线用户列表
*
* @return 返回用户集合
*/
public static Map<String, Session> getUsers() {
return USERS;
}
/**
* 群发消息文本消息
*
* @param message 消息内容
*/
public static void sendMessageToUsersByText(String message) {
Collection<Session> values = USERS.values();
for (Session value : values) {
sendMessageToUserByText(value, message);
}
}
/**
* 发送文本消息
*
* @param session 缓存
* @param message 消息内容
*/
public static void sendMessageToUserByText(Session session, String message) {
if (session != null) {
try {
session.getBasicRemote().sendText(message);
} catch (IOException e) {
LOGGER.error("\n[发送消息异常]", e);
}
} else {
LOGGER.info("\n[你已离线]");
}
}
}
SemaphoreUtils
public class SemaphoreUtils{
/**
* SemaphoreUtils 日志控制器
*/
private static final Logger LOGGER = LoggerFactory.getLogger(SemaphoreUtils.class);
/**
* 获取信号量
*
* @param semaphore
* @return
*/
public static boolean tryAcquire(Semaphore semaphore)
{
boolean flag = false;
try
{
flag = semaphore.tryAcquire();
}
catch (Exception e)
{
LOGGER.error("获取信号量异常", e);
}
return flag;
}
/**
* 释放信号量
*
* @param semaphore
*/
public static void release(Semaphore semaphore)
{
try
{
semaphore.release();
}
catch (Exception e)
{
LOGGER.error("释放信号量异常", e);
}
}
}
yaml配置里新增
# websocket链接配置
websocket:
wsHandlers: /ws/{token}
修改一下AuthFilter 过滤
// AuthFilter filter 代码
String token = getToken(request);
if (StringUtils.isEmpty(token)) {
return unauthorizedResponse(exchange, "令牌不能为空");
}
Claims claims = JwtUtils.parseToken(token);
if (claims == null) {
return unauthorizedResponse(exchange, "令牌已过期或验证不正确!");
}
String userkey = JwtUtils.getUserKey(claims);
boolean islogin = redisService.hasKey(getTokenKey(userkey));
if (!islogin) {
return unauthorizedResponse(exchange, "登录状态已过期");
}
String userid = JwtUtils.getUserId(claims);
String username = JwtUtils.getUserName(claims);
if (StringUtils.isEmpty(userid) || StringUtils.isEmpty(username)) {
return unauthorizedResponse(exchange, "令牌验证失败");
}
/**
* 获取缓存key
*/
private String getTokenKey(String token) {
return CacheConstants.LOGIN_TOKEN_KEY + token;
}
/**
* 获取请求token
*/
private String getToken(ServerHttpRequest request) {
String token = request.getHeaders().getFirst(TokenConstants.AUTHENTICATION);
if (StringUtils.isEmpty(token)) {
//尝试从拼接参数中获取token,这步是为了websocket鉴权
List<?> authorization = request.getQueryParams().get(TokenConstants.AUTHENTICATION);
if (authorization != null) {
token = authorization.get(0).toString();
}
}
// 如果前端设置了令牌前缀,则裁剪掉前缀
if (token != null && StringUtils.isNotEmpty(token) && token.startsWith(TokenConstants.PREFIX)) {
token = token.replaceFirst(TokenConstants.PREFIX, StringUtils.EMPTY);
}
return token;
}
然后我们测试连接,连接成功!

二、前端部分
我们封装一个工具类即可.
import {getToken} from "@/utils/auth";
let socketUrl = ''; // socket地址
let websocket = null; // websocket 实例
let heartTime = null; // 心跳定时器实例
let socketHeart = 0; // 心跳次数
const HeartTimeOut = 10000; // 心跳超时时间 10000 = 10s
let socketError = 0; // 错误次数
// 初始化socket
export const initWebSocket = (url, userId) => {
socketUrl = url;
// 初始化 websocket
websocket = new WebSocket(url + userId + '?Authorization=Bearer ' + getToken());
websocketonopen();
websocketonmessage();
websocketonerror();
websocketclose();
sendSocketHeart();
return websocket;
};
// socket 连接成功
export const websocketonopen = () => {
websocket.onopen = function () {
console.log('%cHaoXin: Websocket连接成功', 'color: green; font-weight: bold; font-size: 16px;');
resetHeart();
};
};
// socket 连接失败
export const websocketonerror = () => {
websocket.onerror = function (e) {
console.log('%cHaoXin: Websocket连接失败', 'color: red; font-weight: bold; font-size: 16px;');
};
};
// socket 断开链接
export const websocketclose = () => {
websocket.onclose = function (e) {
console.log('%cHaoXin: Websocket断开连接', 'color: red; font-weight: bold; font-size: 16px;');
};
};
// socket 重置心跳
export const resetHeart = () => {
socketHeart = 0;
socketError = 0;
clearInterval(heartTime);
sendSocketHeart();
};
// socket心跳发送
export const sendSocketHeart = () => {
heartTime = setInterval(() => {
// 如果连接正常则发送心跳
if (websocket.readyState == 1) {
websocket.send(
JSON.stringify({
type: 'ping'
})
);
socketHeart = socketHeart + 1;
} else {
// 重连
reconnect();
}
}, HeartTimeOut);
};
// socket重连
export const reconnect = () => {
if (socketError <= 2) {
clearInterval(heartTime);
initWebSocket(socketUrl);
socketError = socketError + 1;
console.log('%cHaoXin: socket重连', 'color: green; font-weight: bold; font-size: 16px;');
} else {
console.log('%cHaoXin: 重试次数已用完', 'color: red; font-weight: bold; font-size: 16px;');
clearInterval(heartTime);
}
};
// socket 发送数据
export const sendMsg = (data) => {
websocket.send(data);
};
// socket 接收数据
export const websocketonmessage = () => {
websocket.onmessage = function (e) {
if (e.data.indexOf('heartbeat') > 0) {
resetHeart();
}
if (e.data.indexOf('ping') > 0) {
return;
}
console.log(e.data);
return e.data;
};
};