增加配置类 MyClientEndpointConfig（实现 javax.websocket.ClientEndpointConfig）
package com.feng.im.demos.web;
import javax.websocket.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * {@link ClientEndpointConfig} derived from the {@link ClientEndpoint} annotation of a POJO
 * endpoint class, combined with a caller-supplied user-properties map.
 *
 * <p>Decoders, encoders, preferred subprotocols and the configurator are all taken from the
 * annotation. The user-properties map is kept by reference so that entries the caller puts into
 * it (for example an SSL context property for the container) are visible through
 * {@link #getUserProperties()}.
 */
final class MyClientEndpointConfig implements ClientEndpointConfig {
    private final Map<String, Object> userProperties;
    private final ClientEndpointConfig config;

    /**
     * Builds the configuration from {@code annotatedEndpointClass}'s {@code @ClientEndpoint}.
     *
     * @param annotatedEndpointClass class carrying the {@code @ClientEndpoint} annotation
     * @param userProperties map returned by {@link #getUserProperties()}; kept by reference.
     *     {@code null} is replaced by a new empty mutable map.
     * @throws IllegalArgumentException if the class is {@code null}, is not annotated with
     *     {@code @ClientEndpoint}, or its declared configurator cannot be instantiated
     */
    MyClientEndpointConfig(Class<?> annotatedEndpointClass, Map<String, Object> userProperties) {
        if (annotatedEndpointClass == null) {
            throw new IllegalArgumentException("annotatedEndpointClass must not be null");
        }
        ClientEndpoint annotation = annotatedEndpointClass.getAnnotation(ClientEndpoint.class);
        if (annotation == null) {
            throw new IllegalArgumentException(
                    annotatedEndpointClass.getName() + " is not annotated with @ClientEndpoint");
        }
        // EndpointConfig.getUserProperties() must return a modifiable map, never null.
        this.userProperties = userProperties != null ? userProperties : new HashMap<>();
        this.config = Builder.create()
                .configurator(createConfigurator(annotation))
                .decoders(Arrays.asList(annotation.decoders()))
                .encoders(Arrays.asList(annotation.encoders()))
                .preferredSubprotocols(Arrays.asList(annotation.subprotocols()))
                .build();
    }

    /**
     * Instantiates the configurator declared on the annotation, so that
     * {@code @ClientEndpoint(configurator = ...)} is honoured instead of silently ignored.
     * The annotation's default value (the base {@code Configurator}) yields a plain instance.
     */
    private static Configurator createConfigurator(ClientEndpoint annotation) {
        Class<? extends Configurator> type = annotation.configurator();
        if (Configurator.class.equals(type)) {
            return new Configurator();
        }
        try {
            return type.getConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new IllegalArgumentException(
                    "Cannot instantiate configurator " + type.getName(), e);
        }
    }

    @Override
    public List<String> getPreferredSubprotocols() {
        return config.getPreferredSubprotocols();
    }

    @Override
    public List<Extension> getExtensions() {
        return config.getExtensions();
    }

    @Override
    public Configurator getConfigurator() {
        return config.getConfigurator();
    }

    @Override
    public List<Class<? extends Encoder>> getEncoders() {
        return config.getEncoders();
    }

    @Override
    public List<Class<? extends Decoder>> getDecoders() {
        return config.getDecoders();
    }

    /** Returns the caller-supplied map (by reference), or the empty map created for {@code null}. */
    @Override
    public Map<String, Object> getUserProperties() {
        return userProperties;
    }
}
增加客户端类 WebSocketClient
package com.feng.im.demos.web;
import com.feng.im.sign.SignUtils;
import com.feng.im.sign.bean.SignDTO;
import org.apache.tomcat.websocket.Constants;
import org.apache.tomcat.websocket.pojo.PojoEndpointClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509ExtendedTrustManager;
import javax.websocket.*;
import java.io.IOException;
import java.net.Socket;
import java.net.URI;
import java.nio.ByteBuffer;
import java.security.SecureRandom;
import java.security.cert.X509Certificate;
import java.util.Enumeration;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
/**
 * Annotated WebSocket client that keeps a single "forward" session to the server.
 *
 * <p>When a session opens it is registered in the static registries (keyed by its session id),
 * and a background thread pings it every {@value #PING_INTERVAL_SECONDS} seconds. When a
 * still-registered session closes, the client reconnects after
 * {@value #RECONNECT_DELAY_SECONDS} seconds. Failed connection attempts are retried every
 * {@value #RETRY_DELAY_SECONDS} seconds.
 *
 * <p>SECURITY: {@link #trustManagers()} accepts every server certificate, so TLS here gives
 * encryption but no server authentication. It is only suitable for trusted/local endpoints.
 */
@ClientEndpoint
public class WebSocketClient {
    private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketClient.class);

    /** Server endpoint used by {@link #connect()}. */
    private static final URI SERVER_URI = URI.create("wss://127.0.0.1/websocket");

    /** Seconds between keep-alive pings. */
    private static final long PING_INTERVAL_SECONDS = 5;
    /** Seconds to wait before reconnecting after a registered session closes. */
    private static final long RECONNECT_DELAY_SECONDS = 5;
    /** Seconds to wait between failed connection attempts. */
    private static final long RETRY_DELAY_SECONDS = 15;

    /** Registered sessions keyed by user id (this client uses the session id as the user id). */
    private static final ConcurrentHashMap<String, Session> webSocketSet = new ConcurrentHashMap<>();
    /** Reverse index: session id -> user id. */
    private static final ConcurrentHashMap<String, String> sessionMap = new ConcurrentHashMap<>();

    /** Most recently opened session; written by connecting threads, read by {@link #send}. */
    private volatile Session forwardSession;

    @OnOpen
    public void onOpen(Session session, EndpointConfig config) {
        LOGGER.info("Connected to WebSocket server");
        // Register the session, using its id as the user id. Without this the registries stay
        // empty: circlePing() then treats the session as orphaned and closes it, and onClose()
        // finds no user id and never reconnects.
        final String id = session.getId();
        webSocketSet.put(id, session);
        sessionMap.put(id, id);
        circlePing(session);
    }

    /**
     * Starts a background thread that keeps {@code session} alive with periodic pings.
     *
     * @param session the session to ping until it closes or is deregistered
     */
    public void circlePing(Session session) {
        new Thread(() -> pingUntilClosed(session), "websocket-ping-" + session.getId()).start();
    }

    /**
     * Ping loop body. It stops when the session closes, when its entry is removed from
     * {@code sessionMap}, or when the thread is interrupted. A session that is open but absent
     * from {@code webSocketSet} is deregistered and closed.
     */
    private static void pingUntilClosed(Session session) {
        final String id = session.getId();
        while (session.isOpen()) {
            try {
                TimeUnit.SECONDS.sleep(PING_INTERVAL_SECONDS);
                if (webSocketSet.values().stream().anyMatch(s -> id.equals(s.getId()))) {
                    if (!sessionMap.containsKey(id)) {
                        break;
                    }
                    session.getBasicRemote().sendPing(ByteBuffer.allocate(0));
                    LOGGER.info("ping success...");
                } else {
                    sessionMap.remove(id);
                    session.close();
                }
            } catch (InterruptedException e) {
                // Restore the flag and stop instead of spinning on an interrupted thread.
                Thread.currentThread().interrupt();
                LOGGER.warn("ping thread interrupted", e);
                break;
            } catch (IOException e) {
                LOGGER.error("ping ByteBuffer warn...", e);
            }
        }
        LOGGER.warn("Ping循环关闭");
    }

    /**
     * Logs the received text and the keys currently registered.
     * NOTE(review): nothing is actually forwarded to those keys yet.
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        LOGGER.info("【websocket 接收消息】: {}", message);
        for (String userId : webSocketSet.keySet()) {
            LOGGER.info("【websocket 发送消息】: {}", userId);
        }
    }

    /**
     * Cleans up a still-registered session and schedules a reconnect. Sessions already
     * deregistered (e.g. via {@link #closeSession}) are treated as intentionally closed.
     * The {@code throws} clause is kept for source compatibility; nothing is thrown any more.
     */
    @OnClose
    public void onClose(Session session) throws InterruptedException {
        LOGGER.info("【websocket 关闭】{}", session.getId());
        String userId = sessionMap.get(session.getId());
        if (userId == null) {
            return;
        }
        closeSession(userId);
        scheduleReconnect();
    }

    /** Reconnects on a separate thread so the container's close callback is not blocked. */
    private void scheduleReconnect() {
        new Thread(() -> {
            try {
                TimeUnit.SECONDS.sleep(RECONNECT_DELAY_SECONDS);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                LOGGER.warn("【websocket 重连被中断】", e);
                return;
            }
            connect();
        }, "websocket-reconnect").start();
    }

    /**
     * Connects to {@link #SERVER_URI}, retrying every {@value #RETRY_DELAY_SECONDS} seconds until
     * it succeeds. Returns early, with the interrupt flag set, if the thread is interrupted
     * while waiting between attempts.
     */
    public void connect() {
        WebSocketContainer container = ContainerProvider.getWebSocketContainer();
        // Iterative retry: the previous recursive retry grew the stack on every failed attempt.
        while (true) {
            try {
                Map<String, Object> userProperties = new ConcurrentHashMap<>();
                SSLContext sc = SSLContext.getInstance("TLS");
                sc.init(null, trustManagers(), new SecureRandom());
                userProperties.put(Constants.SSL_CONTEXT_PROPERTY, sc);
                final ClientEndpointConfig endpointConfig =
                        new MyClientEndpointConfig(WebSocketClient.class, userProperties);
                final PojoEndpointClient endpointClient =
                        new PojoEndpointClient(this, endpointConfig.getDecoders());
                Session session = container.connectToServer(endpointClient, endpointConfig, SERVER_URI);
                LOGGER.info("【websocket 连接成功】 {}", session);
                this.forwardSession = session;
                return;
            } catch (Exception e) {
                LOGGER.error("【websocket 连接异常】", e);
            }
            try {
                TimeUnit.SECONDS.sleep(RETRY_DELAY_SECONDS);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                LOGGER.warn("【websocket 重连被中断】", e);
                return;
            }
        }
    }

    /**
     * Returns trust managers that accept every certificate without any validation.
     *
     * <p>SECURITY: this disables server (and client) authentication and leaves the connection
     * open to man-in-the-middle attacks. It is only acceptable for local/test servers with
     * self-signed certificates; use a real trust store elsewhere.
     */
    public static TrustManager[] trustManagers() {
        return new TrustManager[]{new X509ExtendedTrustManager() {
            @Override
            public void checkClientTrusted(X509Certificate[] x509Certificates, String s, Socket socket) {
            }

            @Override
            public void checkServerTrusted(X509Certificate[] x509Certificates, String s, Socket socket) {
            }

            @Override
            public void checkClientTrusted(X509Certificate[] x509Certificates, String s, SSLEngine sslEngine) {
            }

            @Override
            public void checkServerTrusted(X509Certificate[] x509Certificates, String s, SSLEngine sslEngine) {
            }

            @Override
            public X509Certificate[] getAcceptedIssuers() {
                // The X509TrustManager contract requires a non-null (possibly empty) array.
                return new X509Certificate[0];
            }

            @Override
            public void checkClientTrusted(X509Certificate[] arg0, String arg1) {
                LOGGER.debug("checkClientTrusted");
            }

            @Override
            public void checkServerTrusted(X509Certificate[] arg0, String arg1) {
                LOGGER.debug("checkServerTrusted");
            }
        }};
    }

    /**
     * Closes the session registered under the given user id (通过用户ID关闭session).
     * The session is deregistered before it is closed, so the resulting {@link #onClose}
     * callback treats the close as intentional and does not reconnect.
     *
     * @param userId user id; {@code null} or unknown ids are ignored
     */
    public static void closeSession(String userId) {
        LOGGER.info("close Session:{}", userId);
        if (null == userId) {
            return;
        }
        final Session session = webSocketSet.remove(userId);
        if (null == session) {
            return;
        }
        sessionMap.remove(session.getId());
        try {
            session.close();
            LOGGER.info("【websocket 关闭成功】");
        } catch (Exception e) {
            LOGGER.error("【websocket 关闭异常】", e);
        }
    }

    /** Returns the most recently opened session, or {@code null} before the first connect. */
    public Session getForwardSession() {
        return forwardSession;
    }

    /**
     * Sends a text message over the forward session.
     *
     * @throws IOException if no session is open or the send fails
     */
    public void send(String message) throws IOException {
        LOGGER.info("send:{}", message);
        final Session session = forwardSession;
        if (session == null || !session.isOpen()) {
            throw new IOException("WebSocket session is not open");
        }
        session.getBasicRemote().sendText(message);
    }
}
final PojoEndpointClient endpointClient = new PojoEndpointClient(this, endpointConfig.getDecoders());