使用 WebSocketContainer 连接 wss（TLS）服务端

120 阅读 · 2 分钟

增加配置类 MyClientEndpointConfig（实现 ClientEndpointConfig）：从 @ClientEndpoint 注解读取编解码器和子协议，并使用调用方传入的 userProperties（例如 SSL 上下文）。

package com.feng.im.demos.web;

import javax.websocket.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
 * A {@link ClientEndpointConfig} derived from the {@link ClientEndpoint} annotation of a POJO
 * endpoint class, exposing a caller-supplied user-properties map.
 *
 * <p>Decoders, encoders, preferred subprotocols and the configurator are taken from the
 * annotation; {@link #getUserProperties()} returns the map given to the constructor, so callers
 * can pre-populate container-specific properties before connecting.
 */
final class MyClientEndpointConfig implements ClientEndpointConfig {

    /** Properties handed to the container; returned as-is by {@link #getUserProperties()}. */
    private final Map<String, Object> userProperties;
    /** Delegate built from the endpoint annotation; answers every other getter. */
    private final ClientEndpointConfig config;

    /**
     * Builds the configuration for an annotated client endpoint.
     *
     * @param annotatedEndpointClass class carrying the {@link ClientEndpoint} annotation
     * @param userProperties         map returned by {@link #getUserProperties()}; if {@code null},
     *                               an empty mutable map is used instead
     * @throws IllegalArgumentException if the class is not annotated with {@link ClientEndpoint},
     *                                  or its declared configurator cannot be instantiated
     */
    MyClientEndpointConfig(Class<?> annotatedEndpointClass, Map<String, Object> userProperties) {
        // Never expose a null map: the container reads and may write user properties.
        this.userProperties = userProperties != null
                ? userProperties
                : new java.util.concurrent.ConcurrentHashMap<>();
        ClientEndpoint annotation = annotatedEndpointClass.getAnnotation(ClientEndpoint.class);
        if (annotation == null) {
            throw new IllegalArgumentException(
                    annotatedEndpointClass.getName() + " is not annotated with @ClientEndpoint");
        }
        config = Builder.create()
                // Honour the annotation's configurator; the default attribute value yields a
                // plain Configurator, matching the previous behaviour.
                .configurator(newConfigurator(annotation.configurator()))
                .decoders(Arrays.asList(annotation.decoders()))
                .encoders(Arrays.asList(annotation.encoders()))
                .preferredSubprotocols(Arrays.asList(annotation.subprotocols()))
                .build();
    }

    /** Instantiates the configurator class declared on the endpoint annotation. */
    private static Configurator newConfigurator(Class<? extends Configurator> configuratorClass) {
        try {
            return configuratorClass.getConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new IllegalArgumentException(
                    "Cannot instantiate configurator " + configuratorClass.getName(), e);
        }
    }

    @Override
    public List<String> getPreferredSubprotocols() {
        return config.getPreferredSubprotocols();
    }

    @Override
    public List<Extension> getExtensions() {
        return config.getExtensions();
    }

    @Override
    public Configurator getConfigurator() {
        return config.getConfigurator();
    }

    @Override
    public List<Class<? extends Encoder>> getEncoders() {
        return config.getEncoders();
    }

    @Override
    public List<Class<? extends Decoder>> getDecoders() {
        return config.getDecoders();
    }

    /** Returns the map supplied at construction (or the empty fallback map). */
    @Override
    public Map<String, Object> getUserProperties() {
        return userProperties;
    }

}

客户端 WebSocketClient：把 SSLContext 放入 userProperties（Constants.SSL_CONTEXT_PROPERTY），再通过自定义配置连接 wss 服务端。

package com.feng.im.demos.web;

import com.feng.im.sign.SignUtils;
import com.feng.im.sign.bean.SignDTO;
import org.apache.tomcat.websocket.Constants;
import org.apache.tomcat.websocket.pojo.PojoEndpointClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509ExtendedTrustManager;
import javax.websocket.*;
import java.io.IOException;
import java.net.Socket;
import java.net.URI;
import java.nio.ByteBuffer;
import java.security.SecureRandom;
import java.security.cert.X509Certificate;
import java.util.Enumeration;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

/**
 * Annotated WebSocket client that connects to a {@code wss} server, keeps the connection alive
 * with periodic pings and reconnects after the server closes it.
 *
 * <p>Open sessions are tracked in two static registries: {@code webSocketSet} (user id to
 * session) and {@code sessionMap} (session id to user id). A session's ping loop keeps running
 * only while the session is present in both; {@link #onClose(Session)} reconnects only sessions
 * still registered in {@code sessionMap}.
 */
@ClientEndpoint
public class WebSocketClient {

    private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketClient.class);

    /** Server endpoint this client connects to. */
    private static final String SERVER_URI = "wss://127.0.0.1/websocket";
    /** Registry key used by {@link #connect()} when no user id is supplied. */
    private static final String DEFAULT_USER_ID = "default";
    /** Seconds between keep-alive pings. */
    private static final long PING_INTERVAL_SECONDS = 5;
    /** Seconds to wait after a close before reconnecting. */
    private static final long RECONNECT_DELAY_SECONDS = 5;
    /** Seconds to wait between failed connection attempts. */
    private static final long RETRY_DELAY_SECONDS = 15;

    /** User id to open session. */
    private static final ConcurrentHashMap<String, Session> webSocketSet = new ConcurrentHashMap<>();
    /** Session id to user id (reverse index of {@code webSocketSet}). */
    private static final ConcurrentHashMap<String, String> sessionMap = new ConcurrentHashMap<>();

    /** Most recently opened session; written by the connecting thread, read by {@link #send}. */
    private volatile Session forwardSession;

    @OnOpen
    public void onOpen(Session session, EndpointConfig config) {
        LOGGER.info("Connected to WebSocket server");
        circlePing(session);
    }

    /**
     * Starts a background thread that pings {@code session} every few seconds while it is open.
     *
     * @param session the session to keep alive
     */
    public void circlePing(Session session) {
        new Thread(() -> pingLoop(session), "ws-ping-" + session.getId()).start();
    }

    /**
     * Ping loop body. Each iteration waits, then:
     * <ul>
     *   <li>if the session is registered in both maps, sends an empty ping;</li>
     *   <li>if it is in {@code webSocketSet} but not in {@code sessionMap}, stops pinging;</li>
     *   <li>if it is not in {@code webSocketSet}, unregisters and closes it.</li>
     * </ul>
     * The loop ends when the session closes or the thread is interrupted.
     */
    private static void pingLoop(Session session) {
        final String id = session.getId();
        while (session.isOpen()) {
            try {
                TimeUnit.SECONDS.sleep(PING_INTERVAL_SECONDS);
                if (webSocketSet.values().stream().anyMatch(s -> id.equals(s.getId()))) {
                    if (!sessionMap.containsKey(id)) {
                        break;
                    }
                    session.getBasicRemote().sendPing(ByteBuffer.allocate(0));
                    LOGGER.info("ping success...");
                } else {
                    sessionMap.remove(id);
                    session.close();
                }
            } catch (InterruptedException e) {
                // Interruption is a request to stop: restore the flag and leave the loop.
                Thread.currentThread().interrupt();
                break;
            } catch (IOException e) {
                LOGGER.error("ping ByteBuffer warn...", e);
            }
        }
        LOGGER.warn("Ping循环关闭");
    }

    /**
     * Logs an incoming text message and the ids of all registered users.
     * NOTE(review): despite the log text, nothing is forwarded to those users — looks like a
     * placeholder.
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        LOGGER.info("【websocket 接收消息】: {}", message);
        final Enumeration<String> keys = webSocketSet.keys();
        while (keys.hasMoreElements()) {
            final String userId = keys.nextElement();
            LOGGER.info("【websocket 发送消息】: {}", userId);
        }
    }

    /**
     * If the closed session is still registered, unregisters it, waits briefly and reconnects
     * under the same user id. Unregistered sessions (e.g. closed by the ping loop) are ignored.
     *
     * @throws InterruptedException if interrupted while waiting to reconnect
     */
    @OnClose
    public void onClose(Session session) throws InterruptedException {
        LOGGER.info("【websocket 关闭】{}", session.getId());
        String userId = sessionMap.get(session.getId());
        if (userId == null) {
            return;
        }
        closeSession(userId);
        TimeUnit.SECONDS.sleep(RECONNECT_DELAY_SECONDS);
        connect(userId);
    }

    /** Connects under the default user id; see {@link #connect(String)}. */
    public void connect() {
        connect(DEFAULT_USER_ID);
    }

    /**
     * Connects to the server, retrying every {@value #RETRY_DELAY_SECONDS} seconds until it
     * succeeds or the calling thread is interrupted. On success the session becomes the
     * forward session and is registered under {@code userId}.
     *
     * @param userId registry key for the new session; must not be {@code null}
     */
    public void connect(String userId) {
        java.util.Objects.requireNonNull(userId, "userId");
        while (true) {
            try {
                Session session = openSession();
                LOGGER.info("【websocket 连接成功】 {}", session);
                this.forwardSession = session;
                register(userId, session);
                return;
            } catch (Exception e) {
                LOGGER.error("【websocket 连接异常】", e);
                try {
                    TimeUnit.SECONDS.sleep(RETRY_DELAY_SECONDS);
                } catch (InterruptedException ie) {
                    Thread.currentThread().interrupt();
                    LOGGER.warn("websocket reconnect interrupted; giving up", ie);
                    return;
                }
            }
        }
    }

    /** Opens one TLS connection to {@link #SERVER_URI} with this instance as the endpoint. */
    private Session openSession()
            throws java.security.GeneralSecurityException, DeploymentException, IOException {
        WebSocketContainer container = ContainerProvider.getWebSocketContainer();
        URI uri = URI.create(SERVER_URI);
        Map<String, Object> userProperties = new ConcurrentHashMap<>();
        SSLContext sc = SSLContext.getInstance("TLS");
        sc.init(null, trustManagers(), new SecureRandom());
        // Tomcat's client reads the SSL context to use from this user property.
        userProperties.put(Constants.SSL_CONTEXT_PROPERTY, sc);
        final ClientEndpointConfig endpointConfig =
                new MyClientEndpointConfig(WebSocketClient.class, userProperties);
        final PojoEndpointClient endpointClient =
                new PojoEndpointClient(this, endpointConfig.getDecoders());
        return container.connectToServer(endpointClient, endpointConfig, uri);
    }

    /**
     * Records {@code session} under {@code userId} in both registries. A previously registered
     * session for the same user loses its reverse entry, so its ping loop closes it without its
     * close callback triggering another reconnect.
     */
    private static void register(String userId, Session session) {
        // Reverse entry first, so the ping loop never sees the session in webSocketSet without it.
        sessionMap.put(session.getId(), userId);
        final Session previous = webSocketSet.put(userId, session);
        if (previous != null && !previous.getId().equals(session.getId())) {
            sessionMap.remove(previous.getId());
        }
    }

    /**
     * Returns trust managers that accept every certificate.
     *
     * <p>WARNING: every check method is empty, so no certificate validation happens at all and
     * the connection is open to man-in-the-middle attacks. Only use this against a trusted
     * development server (e.g. a self-signed certificate on localhost).
     */
    public static TrustManager[] trustManagers() {
        return new TrustManager[]{new X509ExtendedTrustManager() {
            @Override
            public void checkClientTrusted(X509Certificate[] x509Certificates, String s, Socket socket) {
            }

            @Override
            public void checkServerTrusted(X509Certificate[] x509Certificates, String s, Socket socket) {
            }

            @Override
            public void checkClientTrusted(X509Certificate[] x509Certificates, String s, SSLEngine sslEngine) {
            }

            @Override
            public void checkServerTrusted(X509Certificate[] x509Certificates, String s, SSLEngine sslEngine) {
            }

            @Override
            public X509Certificate[] getAcceptedIssuers() {
                // The X509TrustManager contract requires a non-null (possibly empty) array.
                return new X509Certificate[0];
            }

            @Override
            public void checkClientTrusted(X509Certificate[] arg0, String arg1) {
                LOGGER.debug("checkClientTrusted");
            }

            @Override
            public void checkServerTrusted(X509Certificate[] arg0, String arg1) {
                LOGGER.debug("checkServerTrusted");
            }
        }};
    }

    /**
     * Closes the session registered under the given user id and removes it from both
     * registries. Does nothing if {@code userId} is {@code null} or not registered.
     *
     * @param userId user id the session was registered under
     */
    public static void closeSession(String userId) {
        LOGGER.info("close Session:{}", userId);
        if (null == userId) {
            return;
        }
        final Session session = webSocketSet.remove(userId);
        if (null == session) {
            return;
        }
        try {
            session.close();
            LOGGER.info("【websocket 关闭成功】");
        } catch (Exception e) {
            LOGGER.error("【websocket 关闭异常】", e);
        } finally {
            sessionMap.remove(session.getId());
        }
    }

    /** Returns the most recently opened session, or {@code null} if never connected. */
    public Session getForwardSession() {
        return forwardSession;
    }

    /**
     * Sends a text message over the forward session.
     *
     * @param message text to send
     * @throws IOException           if sending fails
     * @throws IllegalStateException if {@link #connect()} has not yet succeeded
     */
    public void send(String message) throws IOException {
        LOGGER.info("send:{}", message);
        final Session session = getForwardSession();
        if (session == null) {
            throw new IllegalStateException("WebSocket is not connected; call connect() first");
        }
        session.getBasicRemote().sendText(message);
    }
}

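关键代码：`final PojoEndpointClient endpointClient = new PojoEndpointClient(this, endpointConfig.getDecoders());` —— 用 PojoEndpointClient 把带注解的 POJO 包装成 Endpoint，才能调用接受自定义 ClientEndpointConfig 的 `connectToServer(Endpoint, ClientEndpointConfig, URI)` 重载。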