基于springboot netty 实现的自定义通讯协议

4,335 阅读8分钟

基于netty实现的自定义协议通讯协议

1. 通讯协议定义

字段占用的字节数描述
帧 头2 bytes固定为 0x55 0xAA
长 度2 bytes长度 = 命令字 + 参数 + 校验和 ,不包括帧头和长度字节
命 令1 bytes0 心跳, 1 认证, 2 获取信息
参 数0~65535 bytes业务数据
校验和2 bytes校验和 = 帧头 + 长度 + 命令字 + 参数的字节累加和

框架功能

  1. 心跳机制

  2. TCP半包,黏包处理

  3. IP过滤

  4. 日志打印

  5. 自定义协议解析

业务描述

(1)Netty 协议栈客户端发送握手请求消息,携带认证信息;

(2)Netty 协议栈服务端对握手请求消息进行合法性校验,校验通过后,返回登录成功的握手应答消息;

(3)链路建立成功之后,客户端发送心跳消息, 客户端发送业务消息;

(6)服务端响应心跳和业务消息;

(7)服务端退出时,服务端关闭连接,客户端感知对方关闭连接后,被动关闭客户端连接。

完整代码下载地址

代码截图

image.png

客户端启动代码

package com.king.netty.core.client;

import com.king.netty.core.DataFrameDecoder;
import com.king.netty.core.DataFrameEncoder;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;

/**
 * @author King
 * @date 2021/7/14
 */
public class NettyClient {

    public static void main(String[] args) throws InterruptedException {
        startServer();
    }

    static void startServer() throws InterruptedException {
        NioEventLoopGroup group = new NioEventLoopGroup();
        Bootstrap bootstrap = new Bootstrap();
        bootstrap
                // 设置线程组
                .group(group)
                // 设置为NIO模式
                .channel(NioSocketChannel.class)
                // 设置pipeline中的全部的channelHandler
                // 入站方向的channelHandler需要保证顺序
                // 出站方向的channelHandler需要保证顺序
                .handler(new ClientHandlerInit());
        bootstrap.connect("127.0.0.1", 8888).sync();
    }

    static class ClientHandlerInit extends ChannelInitializer<SocketChannel>{
        @Override
        protected void initChannel(SocketChannel ch) throws Exception {
            ChannelPipeline pipeline = ch.pipeline();
            // 日志打印
            pipeline.addLast(new LoggingHandler(LogLevel.INFO));
            // LengthFieldBasedFrameDecoder 用于解决TCP黏包半包问题
            pipeline.addLast(new LengthFieldBasedFrameDecoder(
                    65535,              // maxFrameLength       消息最大长度
                    2,               // lengthFieldOffset    指的是长度域的偏移量,表示跳过指定个数字节之后的才是长度域
                    2,               // lengthFieldLength    记录该帧数据长度的字段,也就是长度域本身的长度
                    0,                // lengthAdjustment     长度的一个修正值,可正可负,Netty 在读取到数据包的长度值 N 后, 认为接下来的 N 个字节都是需要读取的,但是根据实际情况,有可能需要增加 N 的值,也 有可能需要减少 N 的值,具体增加多少,减少多少,写在这个参数里
                    2              // initialBytesToStrip  从数据帧中跳过的字节数,表示得到一个完整的数据包之后,扔掉 这个数据包中多少字节数,才是后续业务实际需要的业务数据。
            ));
            // 自定义协议解码器
            pipeline.addLast(new DataFrameDecoder());
            // 自定义协议编码器
            pipeline.addLast(new DataFrameEncoder());
            // 处理认证请求的handler
            pipeline.addLast(new AuthorizationRequestHandler());
            // 处理心跳的handler
            pipeline.addLast(new HeartBeatRequestHandler());
            // 客户端业务handler
            pipeline.addLast(new ClientBusinessHandler());
        }
    }
}

客户端请求认证Handler代码

package com.king.netty.core.client;

import com.king.netty.core.DataFrame;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.ReferenceCountUtil;

/**
 * @author King
 * @date 2021/7/14
 */
public class AuthorizationRequestHandler extends ChannelInboundHandlerAdapter {

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        // 连接成功后发起认证请求
        ctx.writeAndFlush(DataFrame.getAuthorizationDataFrame());
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        DataFrame dataFrame = (DataFrame) msg;
        // 处理认证响应
        if (dataFrame.getCmd() == DataFrame.CMD_AUTHORIZATION) {
            byte[] params = dataFrame.getParams();
            if (! "success".equals(new String(params))){
                // 认证失败,关闭连接
                ReferenceCountUtil.release(msg);
                ctx.close();
            }
        }
        // 认证成功,继续传递消息
        // 非认证的响应,交给后续业务处理
        ctx.fireChannelRead(msg);
    }
}

客户端心跳Handler代码

package com.king.netty.core.client;

import com.king.netty.core.DataFrame;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.ReferenceCountUtil;

import java.util.concurrent.TimeUnit;

/**
 * @author King
 * @date 2021/7/14
 */
public class HeartBeatRequestHandler extends ChannelInboundHandlerAdapter {

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        DataFrame dataFrame = (DataFrame) msg;
        switch (dataFrame.getCmd()){
            // 如果是心跳应答, release掉, 因为后续的业务handler关心
            case DataFrame.CMD_HEART_BEAT:
                ReferenceCountUtil.release(msg);
                break;
            // 如果是认证成功的响应, 定时发送心跳
            case DataFrame.CMD_AUTHORIZATION:
                // 使用netty自带的任务处理器, 10s发送一次心跳
                ctx.executor().scheduleAtFixedRate(() -> {
                    ctx.writeAndFlush(DataFrame.getHeartBeatDataFrame());
                }, 0, 10, TimeUnit.SECONDS);
                ctx.fireChannelRead(msg);
                break;
            default:
                // 向后传递消息,让业务handler处理
                ctx.fireChannelRead(msg);
                break;
        }
    }
}

客户端业务Handler代码

package com.king.netty.core.client;

import com.king.netty.core.DataFrame;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.ReferenceCountUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * @author King
 * @date 2021/7/14
 */
public class ClientBusinessHandler extends ChannelInboundHandlerAdapter {

    public static final Logger logger = LoggerFactory.getLogger(ClientBusinessHandler.class);

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        DataFrame dataFrame = (DataFrame) msg;
        if (dataFrame.getCmd() == DataFrame.CMD_AUTHORIZATION) {
            // 发送业务请求
            ctx.writeAndFlush(new DataFrame(DataFrame.CMD_GET_INFO, "which language is the best ?".getBytes()));
        }else {
            // 打印服务器发送的消息
            logger.debug("receive message: " + dataFrame);
        }
        ReferenceCountUtil.release(msg);
    }
}

服务器启动代码

package com.king.netty.core.server;

import com.king.netty.core.DataFrameDecoder;
import com.king.netty.core.DataFrameEncoder;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.ipfilter.IpFilterRule;
import io.netty.handler.ipfilter.IpFilterRuleType;
import io.netty.handler.ipfilter.RuleBasedIpFilter;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.timeout.ReadTimeoutHandler;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Component;

import java.net.InetSocketAddress;

/**
 * @author King
 * @date 2021/7/14
 */
@Component
public class NettyServer implements InitializingBean, DisposableBean {

    private boolean started;
    private Channel channel;
    private NioEventLoopGroup parentGroup;
    private NioEventLoopGroup childGroup;

    @Override
    public void destroy() throws Exception {
        // spring销毁对象时调用stop释放服务器
        if (started){
            stopServer();
        }
    }

    @Override
    public void afterPropertiesSet() throws Exception {
        // spring初始化对象后, 调用启动方法,启动服务
        if (started){
            return;
        }
        startServer();
    }

    void startServer() throws InterruptedException {
        this.parentGroup = new NioEventLoopGroup();
        this.childGroup = new NioEventLoopGroup();
        ServerBootstrap serverBootstrap = new ServerBootstrap();
        serverBootstrap
                // 设置线程组
                .group(parentGroup, childGroup)
                // 设置为NIO模式
                .channel(NioServerSocketChannel.class)
                // 设置TCP sync队列大小, 防止洪泛攻击
                .childOption(ChannelOption.SO_BACKLOG, 1024)
                // 设置pipeline中的全部的channelHandler
                // 入站方向的channelHandler需要保证顺序
                // 出站方向的channelHandler需要保证顺序
                .childHandler(new ServerHandlerInit());
        this.channel = serverBootstrap.bind(8888).sync().channel();
        started = true;
    }

    void stopServer(){
        try{
            parentGroup.shutdownGracefully();
            childGroup.shutdownGracefully();
            channel.closeFuture().syncUninterruptibly();
        }finally {
            this.parentGroup = null;
            this.childGroup = null;
            this.channel = null;
            started = false;
        }
    }

    static class ServerHandlerInit extends ChannelInitializer<SocketChannel>{
        @Override
        protected void initChannel(SocketChannel ch) throws Exception {
            ChannelPipeline pipeline = ch.pipeline();
            // 日志打印
            pipeline.addLast(new LoggingHandler(LogLevel.INFO));
            // IP过滤
            pipeline.addLast(new RuleBasedIpFilter(new IpFilterRule() {
                @Override
                public boolean matches(InetSocketAddress remoteAddress) {
                    // 自定义IP地址拦截器,  非127开头的IP不允许连接
                    return ! remoteAddress.getHostName().startsWith("127");
                }
                @Override
                public IpFilterRuleType ruleType() {
                    return IpFilterRuleType.REJECT;
                }
            }));
            // LengthFieldBasedFrameDecoder 用于解决TCP黏包半包问题
            pipeline.addLast(new LengthFieldBasedFrameDecoder(
                    65535,              // maxFrameLength       消息最大长度
                    2,               // lengthFieldOffset    指的是长度域的偏移量,表示跳过指定个数字节之后的才是长度域
                    2,               // lengthFieldLength    记录该帧数据长度的字段,也就是长度域本身的长度
                    0,                // lengthAdjustment     长度的一个修正值,可正可负,Netty 在读取到数据包的长度值 N 后, 认为接下来的 N 个字节都是需要读取的,但是根据实际情况,有可能需要增加 N 的值,也 有可能需要减少 N 的值,具体增加多少,减少多少,写在这个参数里
                    2              // initialBytesToStrip  从数据帧中跳过的字节数,表示得到一个完整的数据包之后,扔掉 这个数据包中多少字节数,才是后续业务实际需要的业务数据。
            ));
            // 设置心跳的超时时间 30s, 如果30s内未收到心跳则会抛出ReadTimeoutException
            pipeline.addLast(new ReadTimeoutHandler(30));
            // 自定义协议解码器
            pipeline.addLast(new DataFrameDecoder());
            // 自定义协议编码器
            pipeline.addLast(new DataFrameEncoder());
            // 认证处理
            pipeline.addLast(new AuthorizationResponseHandler());
            // 心跳处理
            pipeline.addLast(new HeartBeatResponseHandler());
            // 业务处理handler
            pipeline.addLast(new ServerBusinessHandler());
        }
    }
}

服务器认证处理Handler

package com.king.netty.core.server;

import com.king.netty.core.DataFrame;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.ReferenceCountUtil;

/**
 * @author King
 * @date 2021/7/14
 */
public class AuthorizationResponseHandler extends ChannelInboundHandlerAdapter {

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        DataFrame dataFrame = (DataFrame) msg;
        if (dataFrame.getCmd() == DataFrame.CMD_AUTHORIZATION) {
            String auth = "{\"username\":\"test\", \"password\":\"abcdef\"}";
            byte[] params = dataFrame.getParams();
            if (auth.equals(new String(params))){
                // 认证成功
                ctx.writeAndFlush(new DataFrame(dataFrame.getCmd(), "success".getBytes()));
            }else {
                // 认证失败
                ctx.writeAndFlush(new DataFrame(dataFrame.getCmd(), "fail".getBytes()));
            }
            // 释放消息
            ReferenceCountUtil.release(msg);
        }else {
            // 非认证的请求,交给后续业务处理
            ctx.fireChannelRead(msg);
        }
    }
}

服务器心跳处理Handler

package com.king.netty.core.server;

import com.king.netty.core.DataFrame;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.timeout.ReadTimeoutException;
import io.netty.util.ReferenceCountUtil;

/**
 * @author King
 * @date 2021/7/14
 */
public class HeartBeatResponseHandler extends ChannelInboundHandlerAdapter {

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        DataFrame dataFrame = (DataFrame) msg;
        // 如果是心跳请求, release掉, 因为后续的业务handler关心
        if (dataFrame.getCmd() == DataFrame.CMD_HEART_BEAT) {
            ctx.writeAndFlush(DataFrame.getHeartBeatDataFrame());
            ReferenceCountUtil.release(msg);
        } else {// 向后传递消息,让业务handler处理
            ctx.fireChannelRead(msg);
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        if (cause instanceof ReadTimeoutException){
            // 断开客户端连接
            ctx.close();
            return;
        }
        super.exceptionCaught(ctx, cause);
    }
}

服务器业务处理Handler

package com.king.netty.core.server;

import com.king.netty.core.DataFrame;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * @author King
 * @date 2021/7/14
 */
public class ServerBusinessHandler extends ChannelInboundHandlerAdapter {

    public static final Logger logger = LoggerFactory.getLogger(ServerBusinessHandler.class);

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        DataFrame dataFrame = (DataFrame) msg;
        logger.debug("receive message: " + dataFrame);

        // 返回客户端数据
        DataFrame response = doBusiness(dataFrame);
        ctx.writeAndFlush(response);
    }

    private DataFrame doBusiness(DataFrame dataFrame){
        // 处理自己的业务
        // todo
        // 响应客户端
        return new DataFrame(dataFrame.getCmd(), "java is the best language".getBytes());
    }
}

协议定义代码

package com.king.netty.core;

import lombok.Data;

/**
 * @author King
 * @date 2021/7/14
 */
@Data
public class DataFrame {

    public static final byte CMD_HEART_BEAT = 0;
    public static final byte CMD_AUTHORIZATION = 1;
    public static final byte CMD_GET_INFO = 2;

    /**
     * 帧     头	        长 度	    命 令	    参 数	            校验和
     * 0x55 0xAA	        2byte	    1byte	    0~1476bytes	        2bytes
     *
     * 长度 = 命令字 + 参数 + 校验和 ,不包括帧头和长度字节;
     * 校验和 = 帧头 + 长度 + 命令字 + 参数的字节累加和。
     *
     */

    public static final byte[] HEADER = new byte[] {0b01010101, (byte) 0b10101010};

    private byte cmd;

    private byte[] params;

    private int crc;

    public DataFrame(byte cmd, byte[] params, int crc) {
        this.cmd = cmd;
        this.params = params;
        this.crc = crc;
    }

    public DataFrame(byte cmd, byte[] params) {
        this.cmd = cmd;
        this.params = params;
        this.crc = getCrc();
    }

    public boolean checkCrc(){
        return getCrc() == this.crc;
    }

    public int getLength() {
        // 长度 = 命令字 + 参数 + 校验和 ,不包括帧头和长度字节;
        return  1 + params.length + 2;
    }

    public int getCrc(){
        // 校验和 = 帧头 + 长度 + 命令字 + 参数的字节累加和。
        int crc = 0;
        // 帧头
        crc += 0b01010101;
        crc += 0b10101010;
        // 长度
        crc += getLength();
        // 参数和
        for (byte b: params){
            crc += (b & 0xFF);
        }
        return crc;
    }

    public static DataFrame getHeartBeatDataFrame(){
        return new DataFrame(DataFrame.CMD_HEART_BEAT, new byte[]{});
    }

    public static DataFrame getAuthorizationDataFrame(){
        String msg = "{\"username\":\"test\", \"password\":\"abcdef\"}";
        return new DataFrame(DataFrame.CMD_AUTHORIZATION, msg.getBytes());
    }

    @Override
    public String toString() {
        return "DataFrame{" +
                "cmd=" + cmd +
                ", params=" + new String(params) +
                '}';
    }
}

协议解码器

package com.king.netty.core;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;

import java.util.List;

/**
 * @author King
 * @date 2021/7/14
 */
public class DataFrameDecoder extends ByteToMessageDecoder {

    /**
     * 帧     头	        长 度	    命 令	    参 数	            校验和
     * 0x55 0xAA	        2byte	    1byte	    0~1476bytes	        2bytes
     *
     * 长度 = 命令字 + 参数 + 校验和 ,不包括帧头和长度字节;
     * 校验和 = 帧头 + 长度 + 命令字 + 参数的字节累加和。
     *
     */

    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        // 长度位
        int length = in.readShort();
        // 命令位
        byte cmd = in.readByte();
        // 参数
        byte[] params = new byte[length-3];
        in.readBytes(params);
        // 校验和
        int crc = in.readShort();
        DataFrame dataFrame = new DataFrame(cmd, params, crc);
        // 计算校验和
        if (dataFrame.checkCrc()){
            // 将解析后的数据加入到list中,传递给后续的channelHandler
            out.add(dataFrame);
        };
    }
}

协议编码器

package com.king.netty.core;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;

/**
 * @author King
 * @date 2021/7/14
 */
public class DataFrameEncoder extends MessageToByteEncoder<DataFrame> {

    @Override
    protected void encode(ChannelHandlerContext ctx, DataFrame msg, ByteBuf out) throws Exception {
        // 写出帧头
        out.writeBytes(DataFrame.HEADER);
        // 写出长度
        out.writeShort(msg.getLength());
        // 写出命令
        out.writeByte(msg.getCmd());
        // 参 数
        out.writeBytes(msg.getParams());
        // 校验和
        out.writeShort(msg.getCrc());
    }
}