【Spring boot】spring-boot-starter-websocket方式:后端实时推送数据到前端

2,821 阅读 · 阅读时长约 5 分钟

1 添加依赖

<!-- Spring Boot WebSocket starter, required by both approaches below. -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-websocket</artifactId>
</dependency>

2 新建配置类

方式一:

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;

/**
 * Approach 1: enables standard (JSR-356) {@code @ServerEndpoint} classes inside Spring Boot.
 */
@Configuration
public class WebSocketConfig {

    /**
     * Exposes a {@link ServerEndpointExporter}, which detects beans annotated with
     * {@code @ServerEndpoint} and registers them with the standard Java WebSocket runtime.
     *
     * @return a new exporter instance
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        ServerEndpointExporter exporter = new ServerEndpointExporter();
        return exporter;
    }
}
@Slf4j
@ServerEndpoint("/websocket/{sid}")
@Component
public class WebSocketServer {

    private static int onlineCount = 0;
    private static CopyOnWriteArraySet<WebSocketServer> webSocketSet = new CopyOnWriteArraySet<WebSocketServer>();
    private Session session;
    private String sid = "";

    private static DataService dataService;

    /**
     * webSocket并非单例,所以无法使用Autowired直接注入
     * 可以使用static的set注入
     *
     * @param dataService
     */
    @Autowired
    public void setDataService(DataService dataService) {
        WebSocketServer.dataService = dataService;
    }


    /**
     * 连接建立成功调用的方法
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("sid") String sid) {
        this.session = session;
        ## 提高超时时间:在阻塞模式下发送WebSocket消息时,同一时间段只能有一个线程进行访问。首先程序通过tryAcquire尝试在规定的时间内获取一个许可,如果获取不到(有其他线程阻塞住了,未及时释放信号量。就会发生异常
        ## 比如:服务端向a客户端推送数据时a异常关闭链接,此时就会被阻塞住),tomcat官方文档中有说明,默认的超时时间是20秒,可以通过org.apache.tomcat.websocket.BLOCKING_SEND_TIMEOUT属性修改默认的超时时间。
        session.getUserProperties().put( "org.apache.tomcat.websocket.BLOCKING_SEND_TIMEOUT",25000L);
        webSocketSet.add(this);
        addOnlineCount();
        log.info("new socket is open,listen : sid = " + sid + ",onlineCount = " + getOnlineCount());
        this.sid = sid;
        try {
            sendMessage("connection is succeed");
        } catch (IOException e) {
            log.error("websocket IOException {} ", e);
        }
        listenerDate();
    }

    /**
     * 连接关闭调用的方法
     */
    @OnClose
    public void onClose() {
        webSocketSet.remove(this);
        subOnlineCount();
        log.info("a socket is close,onlineCount =  " + getOnlineCount());
    }

    /**
     * 收到客户端消息后调用的方法
     *
     * @param message
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        log.info("received message from sid = " + sid + ",message = " + message);
        for (WebSocketServer item : webSocketSet) {
            try {
                item.sendMessage(message);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * 发生异常时方法
     * @param session
     * @param error
     */
    @OnError
    public void onError(Session session, Throwable error) {
        log.error("{}", error);
        error.printStackTrace();
    }

    /**
     * 实现服务器主动推送
     */
    public void sendMessage(String message) throws IOException {
        log.info("send message to socket,sid = " + sid + ",message = " + message);
        this.session.getBasicRemote().sendText(message);
    }


    /**
     * 群发自定义消息
     */
    public static void sendInfo(String message, @PathParam("sid") String sid) throws IOException {
        for (WebSocketServer item : webSocketSet) {
            try {
                if (sid == null) {
                    item.sendMessage(message);
                } else if (item.sid.equals(sid)) {
                    item.sendMessage(message);
                }
            } catch (IOException e) {
                continue;
            }
        }
    }

    public static synchronized int getOnlineCount() {
        return onlineCount;
    }

    public static synchronized void addOnlineCount() {
        WebSocketServer.onlineCount++;
    }

    public static synchronized void subOnlineCount() {
        WebSocketServer.onlineCount--;
    }

    public synchronized void listenerDate() {
        Thread thread = new Thread(new Runnable() {
            @SneakyThrows
            @Override
            public void run() {
                while (true) {
                    String data = dataService.findData();
                    if (StringUtils.isNotBlank(data)) {
                        for (WebSocketServer item : webSocketSet) {
                            try {
                                item.sendMessage(data);
                            } catch (IOException e) {
                                e.printStackTrace();
                            }
                        }
                    }
                    Thread.sleep(5000);
                }

            }
        });
        thread.start();
    }

方式二:可用拦截器进行权限检验

/**
 * Approach 2: Spring {@code WebSocketConfigurer} setup, which allows a handshake interceptor to
 * perform authorization checks.
 */
@Configuration("webSocketconfig")
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    @Autowired
    private WebSocketHandler webSocketHandler;

    @Autowired
    private WebSocketAuthInterceptor webSocketAuthInterceptor;

    /**
     * Maps {@code webSocketHandler} to {@code /websocket} and {@code /websocket1}, runs
     * {@code webSocketAuthInterceptor} around each handshake and accepts any Origin.
     *
     * @param registry the registry provided by {@code @EnableWebSocket}
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        String[] paths = {"/websocket", "/websocket1"};
        // NOTE(review): "*" accepts handshakes from pages on any origin; consider restricting it
        // to known origins in production.
        registry.addHandler(webSocketHandler, paths)
                .addInterceptors(webSocketAuthInterceptor)
                .setAllowedOrigins("*");
    }
}
@Slf4j
@Component
public class WebSocketAuthInterceptor implements HandshakeInterceptor {

    /**
     * 握手前
     *
     * @param request
     * @param response
     * @param wsHandler
     * @param attributes
     * @return
     * @throws Exception
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        log.info("socket before handshake:");
        ##  获得请求参数
        List<String> uid = request.getHeaders().get("token");
        if (uid != null) {
            ##  放入属性域,区别是哪个断点,可以分别对不同请求做个性化处理
            attributes.put("token_"+request.getURI().getPath(), uid);
            return true;
        }
        log.info("socket before handshake:no authrorization,access denied");
        return true;
    }

    /**
     * 握手后
     *
     * @param request
     * @param response
     * @param wsHandler
     * @param exception
     */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
        log.info("socket handshake success");
    }

}
@Slf4j
@Component
public class WebSocketHandler extends TextWebSocketHandler {
    private volatile Object lock = new Object();
    private static Map<String, WebSocketSession> SESSION_POOL = new ConcurrentHashMap<>();

    @Autowired
    DataService dataService;
    
    /**
     * socket 建立成功事件
     *
     * @param session
     * @throws Exception
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        Object token = session.getAttributes().get("token");
        if (token != null) {
            add(token.toString(), session);
            listenerDate();
        } else {
            throw new RuntimeException("user no authrorization,access denied");
        }
    }

    /**
     * 接收消息事件
     *
     * @param session
     * @param message
     * @throws Exception
     */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        // 获得客户端传来的消息
        String payload = message.getPayload();
        Object token = session.getAttributes().get("token");
        log.info("server received message from token = " + token + ",message = " + payload);
        session.sendMessage(new TextMessage("server 发送给 " + token + " 消息 " + payload + " " + LocalDateTime.now().toString()));
    }

    /**
     * socket 断开连接时
     *
     * @param session
     * @param status
     * @throws Exception
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        Object token = session.getAttributes().get("token");
        if (token != null) {
            // 用户退出,移除缓存
            remove(token.toString());
            log.info("client socket is closed");
        }
    }

    public void listenerDate() {
        Thread thread = new Thread(new Runnable() {
            @SneakyThrows
            @Override
            public void run() {
                while (true) {
                    UserData data = dataService.findData("2020-03-03", "2020-08-06");
                    if (null != data) {
                        Set<Map.Entry<String, WebSocketSession>> keySetView = SESSION_POOL.entrySet();
                        for (Map.Entry<String, WebSocketSession> entry : keySetView) {
                            synchronized (lock) {
                                WebSocketSession socketSession = entry.getValue();
                                socketSession.sendMessage(new TextMessage(JSONObject.toJSONString(data)));
                            }
                        }
                    }
                    Thread.sleep(5000);
                }
            }
        });
        thread.start();
    }


    public static void add(String key, WebSocketSession session) {
        SESSION_POOL.put(key, session);
    }


    public static WebSocketSession remove(String key) {
        return SESSION_POOL.remove(key);
    }

    public static void removeAndClose(String key) {
        WebSocketSession session = remove(key);
        if (session != null) {
            try {
                // 关闭连接
                session.close();
            } catch (IOException e) {
                // todo: 关闭出现异常处理
                e.printStackTrace();
            }
        }
    }

    public static WebSocketSession get(String key) {
        // 获得 session
        return SESSION_POOL.get(key);
    }

}

注意:方式一中,@ServerEndpoint 类的实例由 WebSocket 容器为每个连接单独创建,并非 Spring 管理的单例,所以无法通过 @Autowired 直接注入字段,故采用"static 字段 + set 方法"的方式注入。

3 前端代码

<!DOCTYPE HTML>
<html>
<head>
    <title>My WebSocket</title>
</head>

<body>
Welcome<br/>
<input id="text" type="text" /><button onclick="send()">Send</button>    <button onclick="closeWebSocket()">Close</button>
<div id="message">
</div>
</body>

<script type="text/javascript">
    var websocket = null;

    // Connect only when the browser supports WebSocket. Otherwise websocket stays null and no
    // handlers are attached (assigning them on null would throw a TypeError).
    if ('WebSocket' in window) {
        websocket = new WebSocket("ws://localhost:8533/data/websocket/111");

        // Called when the connection reports an error.
        websocket.onerror = function () {
            setMessageInnerHTML("error");
        };

        // Called once the connection is open.
        websocket.onopen = function (event) {
            setMessageInnerHTML("open");
        };

        // Called for every message received from the server.
        websocket.onmessage = function (event) {
            setMessageInnerHTML(event.data);
        };

        // Called when the connection is closed.
        websocket.onclose = function () {
            alert("close");
            setMessageInnerHTML("close");
        };
    } else {
        alert('Not support websocket');
    }

    // Close the socket when the window is closing, so the window is not gone before the connection
    // is closed (which makes the server side throw). No alert here: browsers block dialogs during unload.
    window.onbeforeunload = function () {
        closeWebSocket();
    };

    // Append one line to the message panel. The text is inserted as a text node, not parsed as
    // HTML, because it comes from the server and may echo other clients' input.
    function setMessageInnerHTML(innerHTML) {
        var panel = document.getElementById('message');
        panel.appendChild(document.createTextNode(innerHTML));
        panel.appendChild(document.createElement('br'));
    }

    // Close the connection (no-op when WebSocket is unsupported).
    function closeWebSocket() {
        if (websocket) {
            websocket.close();
        }
    }

    // Send the input box content; only possible while the connection is open.
    function send() {
        if (!websocket || websocket.readyState !== WebSocket.OPEN) {
            setMessageInnerHTML("not connected");
            return;
        }
        var message = document.getElementById('text').value;
        websocket.send(message);
    }
</script>
</html>

4 nginx配置

由于前端使用 ws 协议,通过域名(经 nginx 反向代理)访问时,nginx 需要以 HTTP/1.1 转发请求并透传 Upgrade、Connection 请求头,才能完成 WebSocket 协议升级。

location /data/ {
   # Backend protocol and address. The location and the proxy_pass URI both end with "/", so
   # /data/websocket/111 is forwarded unchanged; "location /data" with this proxy_pass would
   # forward /data//websocket/111 instead.
   proxy_pass http://localhost:8533/data/;
   proxy_buffers 256 4k;
   proxy_max_temp_file_size 0k;
   proxy_connect_timeout 30;
   proxy_send_timeout 60;
   # Time (seconds) allowed between two reads from the backend; an idle WebSocket is closed
   # after this, so keep it longer than the server's push interval.
   proxy_read_timeout 60;
   # ... other settings ...
   # The three directives below enable the WebSocket (ws://) upgrade.
   proxy_http_version 1.1;
   proxy_set_header Upgrade $http_upgrade;
   proxy_set_header Connection "upgrade";
}

5 问题

Exception in thread "Thread-41" java.lang.IllegalStateException: The remote endpoint was in state [TEXT_PARTIAL_WRITING] which is an invalid state for called method
	at org.apache.tomcat.websocket.WsRemoteEndpointImplBase$StateMachine.checkState(WsRemoteEndpointImplBase.java:1234)
	at org.apache.tomcat.websocket.WsRemoteEndpointImplBase$StateMachine.textPartialStart(WsRemoteEndpointImplBase.java:1191)
	at org.apache.tomcat.websocket.WsRemoteEndpointImplBase.sendPartialString(WsRemoteEndpointImplBase.java:222)
	at org.apache.tomcat.websocket.WsRemoteEndpointBasic.sendText(WsRemoteEndpointBasic.java:49)
	at org.springframework.web.socket.adapter.standard.StandardWebSocketSession.sendTextMessage(StandardWebSocketSession.java:215)
	at org.springframework.web.socket.adapter.AbstractWebSocketSession.sendMessage(AbstractWebSocketSession.java:106)
	at com.test.data.config.webSocket.HttpAuthHandler$1.run(HttpAuthHandler.java:101)
	at java.lang.Thread.run(Thread.java:748)

原因:多个线程同时向同一个 session 发送消息(例如每个连接都各自启动了一个推送线程,且与收到消息后的回复并发),而 Tomcat 阻塞式发送不允许在同一 session 上并发发送。

  1. 利用队列:将待发送的消息(及其目标 session)放入队列,由单个线程依次取出并发送
  2. 同步:在发送方法上加锁,保证同一时刻只有一个线程向同一 session 发送

参考: www.cnblogs.com/bianzy/p/58…