Springboot整合WebSocket

479 阅读2分钟

最近在做一个给公司大屏上飘弹幕的功能,弹幕里是人的照片 + 一句话,比如:【照片】欢迎某总前来画饼!热烈欢迎!!!

因为要实时更新,就想到用WebSocket了,简单来个demo复习一下。

1.什么是WebSocket?

WebSocket是一种在单个TCP连接上进行全双工通信的协议,浏览器和服务器只需要完成一次握手,就直接可以创建持久性的连接,进行双向数据传输。

2.它可以做什么?

场景很多,比如:发通告、聊天、待办待阅,还有我们要做的弹幕等等。

4.demo

  • maven依赖
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
  • 配置类
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
​
@Configuration
public class WebSocketConfig {
    
    /**
     *  ServerEndpointExporter会自动注册使用了@ServerEndpoint声明的Websocket endpoint
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        return new ServerEndpointExporter();
    }
    
}
​
  • 操作类
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArraySet;
import javax.websocket.OnClose;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import org.springframework.stereotype.Component;
import lombok.extern.slf4j.Slf4j;
​
@Component
@Slf4j
@ServerEndpoint("/websocket/{userId}") 
public class WebSocket {
    
    //客户端会话
    private Session session;
    private String userId;
    
    private static CopyOnWriteArraySet<WebSocket> webSockets =new CopyOnWriteArraySet<>();
    private static ConcurrentHashMap<String,Session> sessionPool = new ConcurrentHashMap<String,Session>();
    
    /**
     * 连接成功
     */
    @OnOpen
    public void onOpen(Session session, @PathParam(value="userId")String userId) {
        try {
            this.session = session;
            this.userId = userId;
            webSockets.add(this);
            sessionPool.put(userId, session);
            log.info("【websocket消息】有新的连接,总数为:"+webSockets.size());
        } catch (Exception e) {
        }
    }
    
    /**
     * 连接关闭
     */
    @OnClose
    public void onClose() {
        try {
            webSockets.remove(this);
            sessionPool.remove(this.userId);
            log.info("【websocket消息】连接断开,总数为:"+webSockets.size());
        } catch (Exception e) {
        }
    }
    
    /**
     * 收到客户端消息
     */
    @OnMessage
    public void onMessage(String message) {
        log.info("【websocket消息】收到客户端消息:"+message);
    }
    
    /*
     * 发送错误
     */
    @OnError
    public void onError(Session session, Throwable error) {
        log.error("用户错误,原因:"+error.getMessage());
        error.printStackTrace();
    }
​
    /*
     * 广播
     */
    public void sendAllMessage(String message) {
        log.info("【websocket消息】广播消息:"+message);
        for(WebSocket webSocket : webSockets) {
            try {
                if(webSocket.session.isOpen()) {
                    webSocket.session.getAsyncRemote().sendText(message);
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }
    
    /*
     * 单独发送
     */
    public void sendOneMessage(String userId, String message) {
        Session session = sessionPool.get(userId);
        if (session != null&&session.isOpen()) {
            try {
                log.info("【websocket消息】 单点消息:"+message);
                session.getAsyncRemote().sendText(message);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }
    
    /*
     * 发送多人
     */
    public void sendMoreMessage(String[] userIds, String message) {
        for(String userId:userIds) {
            Session session = sessionPool.get(userId);
            if (session != null&&session.isOpen()) {
                try {
                    log.info("【websocket消息】 单点消息:"+message);
                    session.getAsyncRemote().sendText(message);
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }
    }
    
}
​

5.测试

可以用ApiPost这个工具来测试,给服务端发送消息试试。要测试服务端往客户端推送消息的话可以用接口或者定时任务。

image-20230608184544711.png

image-20230608184333814.png

image-20230608184620018.png

6.问题
  IllegalStateException: The remote endpoint was in state [TEXT_FULL_WRITING] which is an invalid state for called method

原因同一个session发送消息产生冲突(就是说,同一时刻,多个线程向一个socket写数据发生冲突了),就会出现 TEXT_FULL_WRITING 异常,可以使用 getBasicRemote() 取代 getAsyncRemote() ,并对业务报错处加锁,确保数据传输的同步。

  synchronized(this){
      session.getBasicRemote().sendText(message);
  }