Spring Boot 中集成 Netty 示例

138 阅读2分钟

在我们的业务场景中,客户端需要实时接收最新发布的活动信息,因此我们在服务端使用 Netty 作为网络编程框架,通过 WebSocket 向客户端推送消息。代码实现如下:

首先,添加所需的依赖,包括 Netty 和 Spring Boot 的 WebSocket starter:

<!-- Netty -->
<dependency>
    <groupId>io.netty</groupId>
    <artifactId>netty-all</artifactId>
    <version>4.1.69.Final</version>
</dependency>

<!-- Spring Boot WebSocket -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-websocket</artifactId>
</dependency>

接下来,创建WebSocket服务器,处理连接和消息传输。这里是完整的代码:

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.*;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.util.AttributeKey;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.springframework.stereotype.Component;

/**
 * Netty-based WebSocket server.
 *
 * <p>Accepts WebSocket connections on the {@code /websocket} path and installs a
 * {@link WebSocketHandler} on every accepted channel. All handlers receive the same
 * {@link ChannelGroup} instance.
 */
@Component
public class WebSocketServer {

    /** URI path on which the WebSocket handshake is accepted. */
    private static final String WEBSOCKET_PATH = "/websocket";

    /** Upper bound, in bytes, for an aggregated HTTP message (the handshake request). */
    private static final int MAX_HTTP_CONTENT_LENGTH = 65536;

    /** Channel group handed to every {@link WebSocketHandler} created by this server. */
    private final ChannelGroup channelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);

    /**
     * Binds the server to the given port and blocks the calling thread until the server
     * channel is closed or the thread is interrupted. Both event loop groups are shut down
     * before this method returns.
     *
     * @param port TCP port to listen on
     */
    public void run(int port) {
        // Only one port is bound, so a single acceptor thread is sufficient.
        EventLoopGroup bossGroup = new NioEventLoopGroup(1);
        EventLoopGroup workerGroup = new NioEventLoopGroup();

        try {
            ServerBootstrap bootstrap = new ServerBootstrap()
                    .group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) {
                            ChannelPipeline pipeline = ch.pipeline();
                            // HTTP is needed first: the WebSocket handshake is an HTTP upgrade.
                            pipeline.addLast(new HttpServerCodec());
                            pipeline.addLast(new ChunkedWriteHandler());
                            pipeline.addLast(new HttpObjectAggregator(MAX_HTTP_CONTENT_LENGTH));
                            pipeline.addLast(new WebSocketServerProtocolHandler(WEBSOCKET_PATH));
                            pipeline.addLast(new WebSocketHandler(channelGroup));
                        }
                    })
                    .option(ChannelOption.SO_BACKLOG, 128)
                    .childOption(ChannelOption.SO_KEEPALIVE, true);

            ChannelFuture future = bootstrap.bind(port).sync();
            // Block until the server channel is closed.
            future.channel().closeFuture().sync();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers can observe the interruption
            // instead of it being lost behind a printed stack trace.
            Thread.currentThread().interrupt();
        } finally {
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }
}

上述示例中的 WebSocketHandler 类将处理WebSocket连接、用户身份验证、消息接收和发送。接下来,创建 WebSocketHandler 类:

import io.netty.channel.*;
import io.netty.channel.group.ChannelGroup;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.util.AttributeKey;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.springframework.stereotype.Component;

/**
 * Handles text WebSocket frames and routes each received message to every channel in the
 * shared {@link ChannelGroup} that carries the sender's user id.
 *
 * <p>This class is intentionally not a Spring {@code @Component}: its only constructor needs
 * a {@link ChannelGroup}, and instances are created with {@code new} when a channel's
 * pipeline is built, so component scanning could not construct it.
 *
 * <p>The handler keeps no per-channel state in fields (the user id lives in a channel
 * attribute), which is why it is marked {@link ChannelHandler.Sharable}.
 */
@ChannelHandler.Sharable
public class WebSocketHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {

    /** Channel attribute under which the user id is stored. */
    private static final AttributeKey<String> USER_ID = AttributeKey.valueOf("userId");

    /** Placeholder user id assigned after the handshake; replace with real authentication. */
    private static final String PLACEHOLDER_USER_ID = "user123";

    private final ChannelGroup channelGroup;

    /**
     * @param channelGroup group in which this handler registers its channels and through
     *                     which it looks up message recipients
     */
    public WebSocketHandler(ChannelGroup channelGroup) {
        this.channelGroup = channelGroup;
    }

    /** Registers the channel in the group when this handler is added to its pipeline. */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) {
        channelGroup.add(ctx.channel());
    }

    /** Unregisters the channel from the group when this handler is removed from its pipeline. */
    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) {
        channelGroup.remove(ctx.channel());
    }

    /**
     * Forwards the text of the received frame to all channels of the sending user.
     * Frames from channels that have no user id yet are ignored.
     */
    @Override
    public void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame frame) {
        String userId = ctx.channel().attr(USER_ID).get();
        if (userId == null) {
            // No user id has been attached to this channel yet.
            return;
        }
        sendToUser(userId, frame.text());
    }

    /**
     * Writes {@code message} as a text frame to every channel in the group whose user id
     * equals {@code userId}.
     */
    private void sendToUser(String userId, String message) {
        // The group applies the matcher to each member channel and writes only to matches.
        channelGroup.writeAndFlush(
                new TextWebSocketFrame(message),
                channel -> userId.equals(channel.attr(USER_ID).get()));
    }

    /**
     * Attaches a user id to the channel once the WebSocket handshake has completed; every
     * other event is passed on unchanged.
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
            // Handshake succeeded: this is the place to authenticate the user.
            ctx.channel().attr(USER_ID).set(PLACEHOLDER_USER_ID);
        } else {
            super.userEventTriggered(ctx, evt);
        }
    }
}

上述示例中,我们使用 AttributeKey 来将用户ID与通道关联起来。在 userEventTriggered 方法中,您可以执行用户身份验证并将用户ID与通道关联。 channelRead0 方法用于接收和发送消息。

最后,在Spring Boot应用程序中启动WebSocket服务器:

/**
 * Application entry point: boots the Spring context, then starts the Netty WebSocket server.
 */
@SpringBootApplication
public class WebSocketApplication {

    /** Port the Netty WebSocket server listens on. */
    private static final int WEBSOCKET_PORT = 8080;

    /**
     * Starts Spring Boot and then runs the WebSocket server on the main thread.
     *
     * @param args command-line arguments forwarded to Spring Boot
     */
    public static void main(String[] args) {
        // Use the Spring-managed WebSocketServer bean instead of creating a second,
        // unmanaged instance with "new".
        // NOTE(review): if an embedded servlet container is on the classpath it defaults to
        // port 8080 as well, in which case this bind would fail - confirm that server.port
        // or WEBSOCKET_PORT is set so the two do not collide.
        SpringApplication.run(WebSocketApplication.class, args)
                .getBean(WebSocketServer.class)
                .run(WEBSOCKET_PORT);
    }
}

使用 HikariCP 来限制连接数量(注意:HikariCP 是 JDBC 数据库连接池,本身并不提供管理 Netty Channel 的 API;若只是限制 WebSocket 连接数,使用计数器或 ChannelGroup 即可):

<dependency>
    <groupId>com.zaxxer</groupId>
    <artifactId>HikariCP</artifactId>
    <version>4.0.3</version>
</dependency>

接下来,创建一个WebSocket连接池,并在WebSocketHandler中使用它来管理连接:

import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;

import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;

import javax.sql.DataSource;

/**
 * WebSocket handler that caps the number of simultaneously admitted connections.
 *
 * <p>The limit is enforced with an atomic counter rather than HikariCP: HikariCP pools JDBC
 * database connections and offers no API for tracking Netty channels, so it cannot be used
 * for this purpose.
 *
 * <p>This handler holds per-channel state ({@link #admitted}) and therefore must not be
 * shared between pipelines; create one instance per channel.
 */
public class WebSocketHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {
    /** Maximum number of connections admitted at the same time. */
    private static final int MAX_CONNECTIONS = 1000;

    /** Number of currently admitted connections, shared by all handler instances. */
    private static final java.util.concurrent.atomic.AtomicInteger ACTIVE_CONNECTIONS =
            new java.util.concurrent.atomic.AtomicInteger();

    /** Whether this handler's channel currently occupies one of the connection slots. */
    private boolean admitted;

    /**
     * Claims a connection slot for the new channel, or closes the channel when the limit
     * has already been reached.
     */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) {
        // Increment first and roll back on overflow so concurrent arrivals cannot both
        // slip past a separate "check, then add" sequence.
        if (ACTIVE_CONNECTIONS.incrementAndGet() > MAX_CONNECTIONS) {
            ACTIVE_CONNECTIONS.decrementAndGet();
            // Limit reached: reject the connection.
            ctx.close();
            return;
        }
        admitted = true;
    }

    /** Gives the slot back when the handler is removed from the pipeline. */
    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) {
        releaseSlot();
    }

    /** Gives the slot back when the connection closes, then forwards the event. */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) {
        releaseSlot();
        // Keep propagating so later handlers in the pipeline still see the event.
        ctx.fireChannelInactive();
    }

    /**
     * Releases this channel's slot at most once; both channelInactive and handlerRemoved
     * may fire for the same channel, and rejected channels never held a slot.
     */
    private void releaseSlot() {
        if (admitted) {
            admitted = false;
            ACTIVE_CONNECTIONS.decrementAndGet();
        }
    }

    // Other message-handling methods
}