Spring WebSocket 实时消息推送

770 阅读 · 约 2 分钟

1. 前言

心有猛虎细嗅蔷薇

工作之余复习一下 WebSocket，如有错误，敬请指正。

2. 工程结构

3. Maven导入

Spring 版本建议选择更新的版本（下例中的 4.0.0.RELEASE 已较旧）。

<!-- Spring WebSocket module. The version pinned here (4.0.0.RELEASE) is old; a newer release is
     recommended. NOTE(review): keep it aligned with the project's other Spring artifacts - verify. -->
<dependency>    
    <groupId>org.springframework</groupId>    
    <artifactId>spring-websocket</artifactId>    
    <version>4.0.0.RELEASE</version>
</dependency>

4. WebSocket配置

package com.*.websocket.config;

import com.*.websocket.handler.MyWebSocketHandler;
import com.*.websocket.interceptor.WebSocketInterceptor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;

/**
 * Spring configuration that exposes the application's WebSocket endpoint.
 * <p>
 * The endpoint is served by {@link MyWebSocketHandler}, and every handshake first passes through
 * {@link WebSocketInterceptor}.
 * <p>
 * Created by @author fuxj on 2020-4-1 13:33
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    /** URL path the WebSocket endpoint is mapped to. */
    private static final String ENDPOINT_PATH = "/websocket";

    /**
     * Maps the handler bean to {@link #ENDPOINT_PATH} and attaches the handshake interceptor.
     *
     * @param registry registry supplied by Spring's WebSocket configuration support
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        WebSocketHandler handler = myWebSocketHandler();
        WebSocketInterceptor handshakeInterceptor = new WebSocketInterceptor();
        // Before Spring WebSocket 4.1.5, cross-origin handshakes were accepted by default. From 4.1.5
        // on, only same-origin requests are accepted, so cross-origin clients need
        // .setAllowedOrigins(...) on this registration. Note that "*" admits every origin.
        registry.addHandler(handler, ENDPOINT_PATH)
                .addInterceptors(handshakeInterceptor);
    }

    /**
     * Declares the WebSocket handler as a Spring bean.
     *
     * @return a new {@link MyWebSocketHandler}
     */
    @Bean
    public WebSocketHandler myWebSocketHandler() {
        return new MyWebSocketHandler();
    }
}


5. Handler

package com.*.websocket.handler;

import lombok.extern.slf4j.Slf4j;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/**
 * Created by @author fuxj on 2020-4-1 13:36
 */
@Slf4j
public class MyWebSocketHandler extends TextWebSocketHandler {

    /**
     * 在线用户列表
     */
    private static final Map<String, WebSocketSession> users;
    /**
     * 用户标识
     */
    private static final String CLIENT_ID = "username";

    static {
        users = new HashMap<>();
    }

    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        log.info("成功建立连接");
        String username = getClientId(session);
        if (username != null) {
            users.put(username, session);
            // 初次连接发送消息
            session.sendMessage(new TextMessage("成功建立socket连接"));
        }
    }

    @Override
    public void handleTextMessage(WebSocketSession session, TextMessage message) {
        WebSocketMessage message1 = new TextMessage("server:" + message);
        try {
            session.sendMessage(message1);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * 发送信息给指定用户
     *
     * @param clientId
     * @param message
     * @return
     */
    public boolean sendToUsername(String clientId, TextMessage message) {
        // 如果不存在 不发送
        if (users.get(clientId) == null) {
            return false;
        }
        WebSocketSession session = users.get(clientId);
        // 连接未打开不发送
        if (!session.isOpen()) {
            return false;
        }
        try {
            session.sendMessage(message);
        } catch (IOException e) {
            return false;
        }
        return true;
    }

    /**
     * 发送给所有用户
     *
     * @param message
     * @return
     */
    public boolean sendToAll(TextMessage message) {
        boolean allSendSuccess = true;
        Set<String> clientIds = users.keySet();
        WebSocketSession session = null;
        for (String clientId : clientIds) {
            try {
                session = users.get(clientId);
                if (session.isOpen()) {
                    session.sendMessage(message);
                }
            } catch (IOException e) {
                e.printStackTrace();
                allSendSuccess = false;
            }
        }
        return allSendSuccess;
    }


    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        if (session.isOpen()) {
            session.close();
        }
        users.remove(getClientId(session));
    }

    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        users.remove(getClientId(session));
    }

    @Override
    public boolean supportsPartialMessages() {
        return false;
    }

    /**
     * 获取用户标识
     *
     * @param session
     * @return
     */
    private String getClientId(WebSocketSession session) {
        try {
            String clientId = (String) session.getHandshakeAttributes().get(CLIENT_ID);
            return clientId;
        } catch (Exception e) {
            return null;
        }
    }
}

6. Interceptor

package com.*.websocket.interceptor;

import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;

import javax.servlet.http.HttpSession;
import java.util.Map;

/**
 * Handshake interceptor that copies the logged-in user's name from the HTTP session into the
 * WebSocket handshake attributes.
 * <p>
 * Handshakes are never rejected. A request without an existing HTTP session, or without a
 * {@code username} attribute in it, is still upgraded, just without a username attribute.
 * <p>
 * Created by @author fuxj on 2020-4-1 13:36
 */
@Component
public class WebSocketInterceptor implements HandshakeInterceptor {

    /** Attribute key used both in the HTTP session and in the handshake attributes. */
    private static final String USERNAME_ATTRIBUTE = "username";

    /**
     * Copies {@code username} from the existing HTTP session (if any) into the handshake attributes.
     *
     * @param request          the handshake request
     * @param response         the handshake response
     * @param webSocketHandler the target WebSocket handler
     * @param map              handshake attributes that will be attached to the WebSocket session
     * @return always {@code true}, so the handshake proceeds
     * @throws Exception declared by the interface; not thrown here
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler webSocketHandler, Map<String, Object> map) throws Exception {
        if (request instanceof ServletServerHttpRequest) {
            ServletServerHttpRequest serverHttpRequest = (ServletServerHttpRequest) request;
            // getSession(false): only an existing (logged-in) session can carry a username. A
            // plain getSession() would create a fresh HTTP session for every anonymous handshake
            // and make the null check below dead code.
            HttpSession session = serverHttpRequest.getServletRequest().getSession(false);
            if (session != null) {
                Object username = session.getAttribute(USERNAME_ATTRIBUTE);
                // Do not store a null value; absence already means "no username".
                if (username != null) {
                    map.put(USERNAME_ATTRIBUTE, username);
                }
            }
        }
        return true;
    }

    /**
     * No post-handshake processing is needed.
     */
    @Override
    public void afterHandshake(ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, Exception e) {

    }
}


7. js调用

连接建立后，调用 ws.send('消息内容') 向服务端发送消息，其中 ws 是下面 createSocket 中创建的 WebSocket 对象。
 function createSocket() {
        if ('WebSocket' in window) {
            var ws = new WebSocket("ws://localhost:9002/websocket");
            ws.onmessage = function (msg) {
                alert(msg.data);
            }
            ws.onerror = function (e) {
                alert(e);
            }
        } else {
            alert('浏览器不支持WebSocket');
        }
    }

8. 注意事项

IE10 以下版本的浏览器不支持 WebSocket；如需兼容旧浏览器，请自行查阅降级方案（例如 Spring 支持的 SockJS）。