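// 高性能 Netty WebSocket 工具类优化方案
// (High-performance Netty WebSocket utility class: an optimization approach)
//
// Article metadata: 60 阅读 · 3分钟 (60 reads, 3-minute read)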
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.*;
import io.netty.channel.epoll.Epoll;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollServerSocketChannel;
import io.netty.channel.epoll.EpollSocketChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.*;
import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketClientCompressionHandler;
import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerCompressionHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.util.AttributeKey;
import io.netty.util.concurrent.DefaultEventExecutorGroup;
import io.netty.util.concurrent.EventExecutorGroup;

import java.net.URI;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

/**
 * High-performance Netty WebSocket utility class.
 *
 * <p>Provides a WebSocket server ({@link #startServer}) and client ({@link #connectToServer}),
 * a process-wide registry of live connections keyed by channel id, an optional
 * user-id to connection-id index, and helpers for unicast, broadcast and close.
 *
 * <p>The registries are {@link ConcurrentHashMap}s, so the static helpers may be called from any
 * thread. Application callbacks ({@link WebSocketMessageHandler}) run on the shared
 * {@code businessGroup} executors rather than on Netty I/O threads.
 */
public class WebSocketUtil {

    // Channel attribute keys carrying connection metadata.
    public static final AttributeKey<String> CONNECTION_ID = AttributeKey.valueOf("connectionId");
    public static final AttributeKey<String> USER_ID = AttributeKey.valueOf("userId");

    // connectionId -> channel for every live connection (server-accepted and client-initiated).
    private static final Map<String, Channel> CONNECTION_MAP = new ConcurrentHashMap<>(1024);
    // userId -> connectionId; holds the most recently bound connection for each user.
    private static final Map<String, String> USER_CONNECTION_MAP = new ConcurrentHashMap<>(1024);

    // Event-loop sizing.
    private static final int BOSS_THREADS = Math.max(1, Runtime.getRuntime().availableProcessors() / 2);
    private static final int WORKER_THREADS = Runtime.getRuntime().availableProcessors() * 2;

    // Executors for application callbacks, so slow handlers do not stall Netty I/O threads.
    private static final EventExecutorGroup businessGroup = new DefaultEventExecutorGroup(
            WORKER_THREADS * 2, r -> new Thread(r, "WebSocket-Business-" + r.hashCode()));

    // Heartbeat settings (seconds): close after READ_IDLE_TIME without inbound traffic, send a
    // ping after WRITE_IDLE_TIME without outbound traffic; 0 disables the all-idle check.
    private static final int READ_IDLE_TIME = 30;
    private static final int WRITE_IDLE_TIME = 25;
    private static final int ALL_IDLE_TIME = 0;

    // Upper bound (bytes) for aggregated HTTP messages and WebSocket messages/frames.
    private static final int MAX_CONTENT_LENGTH = 65536;

    // How long connectToServer waits for the WebSocket upgrade to complete.
    private static final long HANDSHAKE_TIMEOUT_SECONDS = 10;

    /**
     * Starts the WebSocket server and blocks until its server channel is closed.
     *
     * <p>Clients upgrade on path {@code /ws}; the permessage-deflate extension is accepted. The
     * native epoll transport is used when running on Linux and it can actually be loaded,
     * otherwise NIO. Errors are printed and swallowed; an interrupt is re-asserted on the thread.
     *
     * <p>NOTE(review): on return this also shuts down the shared {@code businessGroup}, which
     * {@link #connectToServer} uses as well; client use in the same JVM afterwards will likely
     * fail. Consider moving that shutdown to an explicit application-level hook.
     *
     * @param port           TCP port to bind
     * @param messageHandler callbacks invoked for every accepted connection
     */
    public static void startServer(int port, WebSocketMessageHandler messageHandler) {
        final boolean epoll = useEpoll();
        EventLoopGroup bossGroup = epoll
                ? new EpollEventLoopGroup(BOSS_THREADS)
                : new NioEventLoopGroup(BOSS_THREADS);
        EventLoopGroup workerGroup = epoll
                ? new EpollEventLoopGroup(WORKER_THREADS)
                : new NioEventLoopGroup(WORKER_THREADS);

        try {
            ServerBootstrap b = new ServerBootstrap();
            b.group(bossGroup, workerGroup)
             .channel(epoll ? EpollServerSocketChannel.class : NioServerSocketChannel.class)
             .option(ChannelOption.SO_BACKLOG, 1024)
             .option(ChannelOption.SO_REUSEADDR, true)
             .childOption(ChannelOption.SO_KEEPALIVE, true)
             .childOption(ChannelOption.TCP_NODELAY, true)
             .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
             .childHandler(new ChannelInitializer<SocketChannel>() {
                 @Override
                 protected void initChannel(SocketChannel ch) {
                     ChannelPipeline pipeline = ch.pipeline();
                     // Server-side codec: decodes the HTTP upgrade request, encodes the response.
                     pipeline.addLast(new HttpServerCodec());
                     pipeline.addLast(new HttpObjectAggregator(MAX_CONTENT_LENGTH));
                     // Not @Sharable (no INSTANCE singleton, unlike the client variant):
                     // one instance per channel.
                     pipeline.addLast(new WebSocketServerCompressionHandler());
                     pipeline.addLast(new IdleStateHandler(
                             READ_IDLE_TIME, WRITE_IDLE_TIME, ALL_IDLE_TIME, TimeUnit.SECONDS));
                     pipeline.addLast(new WebSocketServerProtocolHandler("/ws", null, true, MAX_CONTENT_LENGTH));
                     // Reassemble fragmented messages so continuation frames are not dropped.
                     pipeline.addLast(new WebSocketFrameAggregator(MAX_CONTENT_LENGTH));
                     pipeline.addLast(businessGroup, new WebSocketFrameHandler(messageHandler));
                 }
             });

            ChannelFuture f = b.bind(port).sync();
            System.out.println("WebSocket服务器启动在端口: " + port + " (使用" +
                    (epoll ? "Epoll" : "NIO") + "事件循环)");
            f.channel().closeFuture().sync();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
            businessGroup.shutdownGracefully();
        }
    }

    /**
     * Connects to a WebSocket server and blocks until the upgrade handshake has completed.
     *
     * <p>The scheme defaults to {@code ws} and the host to {@code 127.0.0.1}; without an explicit
     * port, 80 ({@code ws}) or 443 ({@code wss}) is used. Each call owns a private event-loop
     * group, which is shut down when the returned channel closes (or immediately on failure).
     *
     * @param uri            target URI, e.g. {@code ws://host:8080/ws}
     * @param messageHandler callbacks for this connection
     * @return the connected, registered channel, or {@code null} if connecting or the handshake
     *         failed (the error is printed)
     */
    public static Channel connectToServer(String uri, WebSocketMessageHandler messageHandler) {
        EventLoopGroup group = null;
        Channel ch = null;
        try {
            URI webSocketUri = new URI(uri);
            String scheme = webSocketUri.getScheme() == null ? "ws" : webSocketUri.getScheme();
            final String host = webSocketUri.getHost() == null ? "127.0.0.1" : webSocketUri.getHost();
            final int port = resolvePort(webSocketUri, scheme);

            // SECURITY: InsecureTrustManagerFactory accepts ANY server certificate, so "wss" here
            // encrypts but does not authenticate the server. Use a real trust manager before
            // connecting over untrusted networks.
            final SslContext sslCtx = "wss".equalsIgnoreCase(scheme)
                    ? SslContextBuilder.forClient().trustManager(InsecureTrustManagerFactory.INSTANCE).build()
                    : null;

            // allowExtensions=true is required for permessage-deflate negotiation.
            final WebSocketClientHandshaker handshaker = WebSocketClientHandshakerFactory.newHandshaker(
                    webSocketUri, WebSocketVersion.V13, null, true, new DefaultHttpHeaders(), MAX_CONTENT_LENGTH);
            final CompletableFuture<Void> handshakeDone = new CompletableFuture<>();

            final boolean epoll = useEpoll();
            group = epoll ? new EpollEventLoopGroup() : new NioEventLoopGroup();

            Bootstrap b = new Bootstrap();
            b.group(group)
             .channel(epoll ? EpollSocketChannel.class : NioSocketChannel.class)
             .option(ChannelOption.SO_KEEPALIVE, true)
             .option(ChannelOption.TCP_NODELAY, true)
             .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
             .handler(new ChannelInitializer<SocketChannel>() {
                 @Override
                 protected void initChannel(SocketChannel ch) {
                     ChannelPipeline p = ch.pipeline();
                     if (sslCtx != null) {
                         p.addLast(sslCtx.newHandler(ch.alloc(), host, port));
                     }
                     p.addLast(
                             new HttpClientCodec(),
                             new HttpObjectAggregator(MAX_CONTENT_LENGTH),
                             WebSocketClientCompressionHandler.INSTANCE,
                             new IdleStateHandler(READ_IDLE_TIME, WRITE_IDLE_TIME, ALL_IDLE_TIME, TimeUnit.SECONDS),
                             // Sends the upgrade request on connect and finishes it on the response.
                             new WebSocketClientProtocolHandler(handshaker),
                             new WebSocketFrameAggregator(MAX_CONTENT_LENGTH),
                             new HandshakeCompletionHandler(handshakeDone));
                     // The executor group must be the FIRST argument of addLast(group, handlers...).
                     p.addLast(businessGroup, new WebSocketFrameHandler(messageHandler));
                 }
             });

            ch = b.connect(host, port).sync().channel();
            handshakeDone.get(HANDSHAKE_TIMEOUT_SECONDS, TimeUnit.SECONDS);

            // Register the connection.
            String connectionId = ch.id().asLongText();
            ch.attr(CONNECTION_ID).set(connectionId);
            CONNECTION_MAP.put(connectionId, ch);

            // Registered after the put: if the channel already closed, this runs immediately and
            // still removes the entry and releases this connection's private event-loop group.
            final Channel connected = ch;
            final EventLoopGroup ownedGroup = group;
            ChannelFutureListener cleanup = future -> {
                CONNECTION_MAP.remove(connectionId, connected);
                ownedGroup.shutdownGracefully();
            };
            ch.closeFuture().addListener(cleanup);
            return ch;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } catch (Exception e) {
            e.printStackTrace();
        }
        // Failure path: do not leak the half-open channel or its event-loop threads.
        if (ch != null) {
            ch.close();
        }
        if (group != null) {
            group.shutdownGracefully();
        }
        return null;
    }

    /**
     * Returns the URI's explicit port, otherwise the default for the scheme.
     *
     * @throws IllegalArgumentException if there is no explicit port and the scheme is neither
     *                                  {@code ws} nor {@code wss}
     */
    private static int resolvePort(URI uri, String scheme) {
        if (uri.getPort() != -1) {
            return uri.getPort();
        }
        if ("ws".equalsIgnoreCase(scheme)) {
            return 80;
        }
        if ("wss".equalsIgnoreCase(scheme)) {
            return 443;
        }
        throw new IllegalArgumentException("Unsupported WebSocket scheme without explicit port: " + scheme);
    }

    /**
     * Sends a text message to one connection; silently does nothing if the id is unknown or the
     * channel is no longer active.
     */
    public static void sendTextMessage(String connectionId, String message) {
        Channel channel = CONNECTION_MAP.get(connectionId);
        if (channel != null && channel.isActive()) {
            channel.writeAndFlush(new TextWebSocketFrame(message));
        }
    }

    /** Sends a text message to the connection currently bound to {@code userId}, if any. */
    public static void sendTextMessageToUser(String userId, String message) {
        String connectionId = USER_CONNECTION_MAP.get(userId);
        if (connectionId != null) {
            sendTextMessage(connectionId, message);
        }
    }

    /** Broadcasts a text message to every active registered connection. */
    public static void broadcastTextMessage(String message) {
        TextWebSocketFrame frame = new TextWebSocketFrame(message);
        try {
            for (Channel channel : CONNECTION_MAP.values()) {
                if (channel.isActive()) {
                    // Each write gets its own reference to the shared payload; the write releases it.
                    channel.writeAndFlush(frame.retainedDuplicate());
                }
            }
        } finally {
            // Drop the original reference even if a write throws.
            frame.release();
        }
    }

    /** Broadcasts a text message to every active registered connection except {@code excludeConnectionId}. */
    public static void broadcastTextMessageExcept(String excludeConnectionId, String message) {
        TextWebSocketFrame frame = new TextWebSocketFrame(message);
        try {
            CONNECTION_MAP.forEach((id, channel) -> {
                if (!id.equals(excludeConnectionId) && channel.isActive()) {
                    channel.writeAndFlush(frame.retainedDuplicate());
                }
            });
        } finally {
            frame.release();
        }
    }

    /**
     * Sends a Close frame to the connection and then closes it. Both operations are queued in
     * order, so the close frame is flushed before the channel closes.
     */
    public static void closeConnection(String connectionId) {
        Channel channel = CONNECTION_MAP.get(connectionId);
        if (channel != null) {
            channel.writeAndFlush(new CloseWebSocketFrame());
            channel.close();
        }
    }

    /** Sends a Close frame to and closes every registered connection, then clears both registries. */
    public static void closeAllConnections() {
        CONNECTION_MAP.values().forEach(channel -> {
            channel.writeAndFlush(new CloseWebSocketFrame());
            channel.close();
        });
        CONNECTION_MAP.clear();
        USER_CONNECTION_MAP.clear();
    }

    /**
     * Binds {@code userId} to a registered connection. The newest binding for a user wins; if
     * the connection was previously bound to another user, that user's stale entry is removed.
     * Does nothing if {@code connectionId} is not registered.
     */
    public static void bindUserId(String connectionId, String userId) {
        Channel channel = CONNECTION_MAP.get(connectionId);
        if (channel != null) {
            String previousUserId = channel.attr(USER_ID).getAndSet(userId);
            if (previousUserId != null && !previousUserId.equals(userId)) {
                // Only remove the old user's entry if it still points at this connection.
                USER_CONNECTION_MAP.remove(previousUserId, connectionId);
            }
            USER_CONNECTION_MAP.put(userId, connectionId);
        }
    }

    /** Returns the number of currently registered connections. */
    public static int getConnectionCount() {
        return CONNECTION_MAP.size();
    }

    /**
     * Whether to use the native epoll transport: only on Linux, and only if Netty can actually
     * load it (e.g. the native library for this architecture is on the classpath).
     */
    private static boolean useEpoll() {
        // isLinux() first so the epoll classes are never touched on other platforms.
        return isLinux() && Epoll.isAvailable();
    }

    /** Checks whether the JVM reports a Linux operating system. */
    private static boolean isLinux() {
        return System.getProperty("os.name").toLowerCase(Locale.ROOT).contains("linux");
    }

    /**
     * Application callbacks for WebSocket events. Invoked on {@code businessGroup} executors.
     *
     * <p>NOTE(review): Netty's WebSocket protocol handlers in these pipelines appear to answer
     * pings and drop pongs/close frames themselves, so {@code onPing}/{@code onPong} may rarely or
     * never fire. Verify against the Netty version in use.
     */
    public interface WebSocketMessageHandler {
        /** A complete (possibly reassembled) text message arrived. */
        void onTextMessage(Channel channel, String message);

        /** A complete (possibly reassembled) binary message arrived; {@code data} is a copy. */
        void onBinaryMessage(Channel channel, byte[] data);

        /** A ping reached the application handler (a pong has already been queued). */
        void onPing(Channel channel);

        /** A pong reached the application handler. */
        void onPong(Channel channel);

        /** The channel became inactive; called once per connection. */
        void onClose(Channel channel);

        /** An exception occurred in the pipeline or in one of the callbacks above. */
        void onError(Channel channel, Throwable cause);
    }

    /**
     * Completes {@code done} when the client-side WebSocket handshake finishes, or fails it if
     * the channel raises an exception or closes first. This lets connectToServer block until
     * frames can be sent.
     */
    private static final class HandshakeCompletionHandler extends ChannelInboundHandlerAdapter {
        private final CompletableFuture<Void> done;

        HandshakeCompletionHandler(CompletableFuture<Void> done) {
            this.done = done;
        }

        @Override
        public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
            if (evt == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
                done.complete(null);
            }
            super.userEventTriggered(ctx, evt);
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
            // No-op if the handshake already completed.
            done.completeExceptionally(cause);
            super.exceptionCaught(ctx, cause);
        }

        @Override
        public void channelInactive(ChannelHandlerContext ctx) throws Exception {
            done.completeExceptionally(new IllegalStateException(
                    "Channel closed before the WebSocket handshake completed"));
            super.channelInactive(ctx);
        }
    }

    /**
     * Final pipeline handler: maintains the connection registries, dispatches frames to the
     * {@link WebSocketMessageHandler}, and implements the idle-based heartbeat.
     */
    private static class WebSocketFrameHandler extends SimpleChannelInboundHandler<WebSocketFrame> {
        private final WebSocketMessageHandler messageHandler;

        public WebSocketFrameHandler(WebSocketMessageHandler messageHandler) {
            this.messageHandler = messageHandler;
        }

        /**
         * Dispatches one inbound frame to the application callbacks.
         *
         * <p>{@link SimpleChannelInboundHandler} releases {@code frame} after this method returns,
         * so it must NOT be released here; a second release would throw and kill the connection.
         * Anything written back out is {@code retain()}ed first.
         */
        @Override
        protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame frame) {
            Channel channel = ctx.channel();

            try {
                if (frame instanceof TextWebSocketFrame) {
                    TextWebSocketFrame textFrame = (TextWebSocketFrame) frame;
                    messageHandler.onTextMessage(channel, textFrame.text());
                } else if (frame instanceof BinaryWebSocketFrame) {
                    BinaryWebSocketFrame binaryFrame = (BinaryWebSocketFrame) frame;
                    byte[] data = new byte[binaryFrame.content().readableBytes()];
                    binaryFrame.content().readBytes(data);
                    messageHandler.onBinaryMessage(channel, data);
                } else if (frame instanceof PingWebSocketFrame) {
                    // The pong reuses the ping payload; retain it past the auto-release.
                    channel.writeAndFlush(new PongWebSocketFrame(frame.content().retain()));
                    messageHandler.onPing(channel);
                } else if (frame instanceof PongWebSocketFrame) {
                    messageHandler.onPong(channel);
                } else if (frame instanceof CloseWebSocketFrame) {
                    // Echo the close frame, then close. onClose is reported once, from
                    // channelInactive, not here as well.
                    channel.writeAndFlush(frame.retain());
                    ctx.close();
                }
            } catch (Exception e) {
                messageHandler.onError(channel, e);
            }
        }

        /**
         * Heartbeat: closes the connection after a read-idle period and sends a ping after a
         * write-idle period. Other events are forwarded down the pipeline.
         */
        @Override
        public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
            if (evt instanceof IdleStateEvent) {
                IdleState state = ((IdleStateEvent) evt).state();
                if (state == IdleState.READER_IDLE) {
                    // Nothing received for READ_IDLE_TIME: treat the peer as dead.
                    ctx.close();
                } else if (state == IdleState.WRITER_IDLE) {
                    // Nothing sent for WRITE_IDLE_TIME: send a ping to keep the connection alive.
                    ctx.channel().writeAndFlush(new PingWebSocketFrame());
                }
                return;
            }
            ctx.fireUserEventTriggered(evt);
        }

        /** Registers the new connection under its channel id. */
        @Override
        public void channelActive(ChannelHandlerContext ctx) {
            Channel channel = ctx.channel();
            String connectionId = channel.id().asLongText();
            channel.attr(CONNECTION_ID).set(connectionId);
            CONNECTION_MAP.put(connectionId, channel);
            ctx.fireChannelActive();
        }

        /** Unregisters the connection and its user binding, then reports onClose exactly once. */
        @Override
        public void channelInactive(ChannelHandlerContext ctx) {
            Channel channel = ctx.channel();
            String connectionId = channel.id().asLongText();
            CONNECTION_MAP.remove(connectionId);

            // Remove the user binding only if it still points at THIS connection; the same user
            // may already have bound a newer connection that must not be clobbered.
            String userId = channel.attr(USER_ID).get();
            if (userId != null) {
                USER_CONNECTION_MAP.remove(userId, connectionId);
            }

            messageHandler.onClose(channel);
            ctx.fireChannelInactive();
        }

        /** Reports the error to the application and closes the connection. */
        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
            messageHandler.onError(ctx.channel(), cause);
            ctx.close();
        }
    }
}