手写RPC框架<一>通信层开发

135 阅读9分钟

一、通信层设计

image.png

二、实现自定义通信协议

无论是使用 Netty 还是原始的 Socket 编程,基于 TCP 通信的数据包格式均为二进制,协议指的就是客户端与服务端事先商量好的,每一个二进制数据包中每一段字节分别代表什么含义的规则。

2.1客户端与服务端通信

image.png

  1. 首先,客户端把一个 Java 对象按照通信协议转换成二进制数据包。
  2. 然后通过网络,把这段二进制数据包发送到服务端,数据的传输过程由 TCP/IP 协议负责数据的传输,与我们的应用层无关。
  3. 服务端接收到数据之后,按照协议取出二进制数据包中的相应字段,包装成 Java 对象,交给应用逻辑处理。
  4. 服务端处理完之后,如果需要吐出响应给客户端,那么按照相同的流程进行。

2.2通信协议设计

image.png

  1. 首先,第一个字段是魔数,通常情况下为固定的几个字节(我们这边规定为4个字节)。 有了这个魔数之后,服务端首先取出前面四个字节进行比对,能够在第一时间识别出这个数据包并非是遵循自定义协议的,也就是无效数据包,为了安全考虑可以直接关闭连接以节省资源。在 Java 的字节码的二进制文件中,开头的 4 个字节为0xcafebabe 用来标识这是个字节码文件
  2. 接下来一个字节为版本号,通常情况下是预留字段,用于协议升级的时候用到
  3. 第三部分,序列化算法表示如何把 Java 对象转换二进制数据以及二进制数据如何转换回 Java 对象,比如 Java 自带的序列化,json,hessian 等序列化方式。
  4. 第四部分的字段表示指令,关于指令相关的介绍,我们在前面已经讨论过,服务端或者客户端每收到一种指令都会有相应的处理逻辑。
  5. 接下来的字段为数据部分的长度,占四个字节。
  6. 最后一个部分为实际传输的数据内容

2.3 通信协议的实现

2.3.1 定义数据包

客户端与服务端的通信数据都要继承该类

public abstract class Packet {
    /**
     * 默认协议版本
     */
    private Byte version = 1;

    /**
     * 指令
     * @return
     */
    public abstract Byte getCommand();

    /**
     * 序列化算法
     * @return
     */
    public abstract Byte getSerializerAlgorithm();

    public Byte getVersion() {
        return version;
    }

    public void setVersion(Byte version) {
        this.version = version;
    }
}

2.3.2 实现序列化算法

序列化接口

public interface Serializer {

    /**
     * 序列化算法
     */
    byte getSerializerAlgorithm();

    /**
     * java 对象转换成二进制
     */
    byte[] serialize(Object object);

    /**
     * 二进制转换成 java 对象
     */
    <T> T deserialize(byte[] bytes,Class<T> clazz);

}

默认使用FastJson序列化

public class FastjsonSerializer implements Serializer {

    @Override
    public byte getSerializerAlgorithm() {
        return SerializerEnum.FAST_JSON.getSerializerType();
    }

    @Override
    public byte[] serialize(Object object) {
        return JSON.toJSONBytes(object);
    }

    @Override
    public <T> T deserialize( byte[] bytes,Class<T> clazz) {

        return JSON.parseObject(bytes, clazz);
    }
}

使用枚举创建序列化单例对象

public enum SerializerEnum {

    /**
     * fastJson
     */
    FAST_JSON((byte) 1, new FastjsonSerializer());

    /**
     * 序列化类型
     */
    private byte serializerType;

    /**
     * 序列化器
     */
    private Serializer serializer;


    SerializerEnum(byte serializerType, Serializer serializer) {
        this.serializerType = serializerType;
        this.serializer = serializer;
    }

    private static final SerializerEnum[] VALUES = SerializerEnum.values();


    public byte getSerializerType() {
        return serializerType;
    }

    public Serializer getSerializer() {
        return serializer;
    }
}

2.3.3 实现默认传输业务指令

定义指令集合

public interface Command {

    /**
     * rpc对象
     */
    Byte RPC_INVOCATION = 1;
}

实现通信数据包

public class RpcInvocation extends PacketResponse {

    //请求信息
    private String reqMsg;

    //响应信息
    private Object response;

    public void success(Object response) {
        this.response = response;
        super.success();
    }

    @Override
    public Byte getCommand() {
        return RPC_INVOCATION;
    }

    @Override
    public Byte getSerializerAlgorithm() {
        return SerializerEnum.FAST_JSON.getSerializerType();
    }

    public String getReqMsg() {
        return reqMsg;
    }

    public void setReqMsg(String reqMsg) {
        this.reqMsg = reqMsg;
    }

    public Object getResponse() {
        return response;
    }

    public void setResponse(Object response) {
        this.response = response;
    }

    @Override
    public String toString() {
        return "RpcInvocation{" +
                "reqMsg='" + reqMsg + ''' +
                ", response=" + response +
                '}';
    }
}

2.3.4 数据包的编码与解码过程

根据定义的数据包我们需要将定义的java对象转化为二进制

public class PacketCodeC {

    //魔术
    public static final int MAGIC_NUMBER = 0x12345678;
    //指令集合
    private final Map<Byte, Class<? extends Packet>> packetTypeMap = new HashMap<>();
    //序列化列表
    private final Map<Byte, Serializer> serializerHashMap = new HashMap<>();

    public PacketCodeC() {
        packetTypeMap.put(RPC_INVOCATION, RpcInvocation.class);
        for (SerializerEnum serializerEnum : SerializerEnum.values()) {
            serializerHashMap.put(serializerEnum.getSerializerType(), serializerEnum.getSerializer());
        }
    }

    public ByteBuf encode(ByteBuf byteBuf, Packet packet) {
        //获取序列化器
        Serializer serializer = serializerHashMap.get(packet.getSerializerAlgorithm());
        byte[] bytes = serializer.serialize(packet);
        //1.写入魔术
        byteBuf.writeInt(MAGIC_NUMBER);
        //2.写入版本号
        byteBuf.writeByte(packet.getVersion());
        //3.写入序列化算法
        byteBuf.writeByte(serializer.getSerializerAlgorithm());
        //4.写入业务指令
        byteBuf.writeByte(packet.getCommand());
        //5.写入数据包长度
        byteBuf.writeInt(bytes.length);
        //6.写入数据
        byteBuf.writeBytes(bytes);
        return byteBuf;
    }


    public Packet decode(ByteBuf byteBuf) {
        //1.跳过魔术
        byteBuf.skipBytes(4);
        //2.跳过版本号
        byteBuf.skipBytes(1);
        //3.序列化算法
        byte serializeAlgorithm = byteBuf.readByte();
        Serializer serializer = getSerializer(serializeAlgorithm);
        //4.业务指令
        byte command = byteBuf.readByte();
        //5.数据包长度
        int length = byteBuf.readInt();
        byte[] bytes = new byte[length];
        //6.读取数据
        byteBuf.readBytes(bytes);
        Class<? extends Packet> clazz = getRequestType(command);
        return serializer.deserialize(bytes, clazz);
    }

    private Class<? extends Packet> getRequestType(byte command) {
        return packetTypeMap.get(command);
    }

    private Serializer getSerializer(byte serializerAlgorithm) {
        return serializerHashMap.get(serializerAlgorithm);
    }
}

二、Netty服务端创建

2.1 实现拆包器,解决粘包半包现象

虽然我们使用了 Netty,但是对于操作系统来说,只认 TCP 协议,尽管我们的应用层是按照 ByteBuf 为 单位来发送数据,但是到了底层操作系统仍然是按照字节流发送数据,因此,数据到了服务端,也是按照字节流的方式读入,然后到了 Netty 应用层面,重新拼装成 ByteBuf,而这里的 ByteBuf 与客户端按顺序发送的 ByteBuf 可能是不对等的。因此,我们需要在客户端根据自定义协议来组装我们应用层的数据包,然后在服务端根据我们的应用层的协议来组装数据包,这个过程通常在服务端称为拆包,而在客户端称为粘包。

2.1.1 拆包原理

在没有 Netty 的情况下,用户如果自己需要拆包,基本原理就是不断从 TCP 缓冲区中读取数据,每次读取完都需要判断是否是一个完整的数据包

  1. 如果当前读取的数据不足以拼接成一个完整的业务数据包,那就保留该数据,继续从 TCP 缓冲区中读取,直到得到一个完整的数据包。
  2. 如果当前读到的数据加上已经读取的数据足够拼接成一个数据包,那就将已经读取的数据拼接上本次读取的数据,构成一个完整的业务数据包传递到业务逻辑,多余的数据仍然保留,以便和下次读到的数据尝试拼接。

2.1.2 Netty自带的拆包器

①固定长度的拆包器 FixedLengthFrameDecoder

如果你的应用层协议非常简单,每个数据包的长度都是固定的,比如 100,那么只需要把这个拆包器加到 pipeline 中,Netty 会把一个个长度为 100 的数据包 (ByteBuf) 传递到下一个 channelHandler。

②行拆包器 LineBasedFrameDecoder

从字面意思来看,发送端发送数据包的时候,每个数据包之间以换行符作为分隔,接收端通过 LineBasedFrameDecoder 将粘过的 ByteBuf 拆分成一个个完整的应用层数据包。

③分隔符拆包器 DelimiterBasedFrameDecoder

DelimiterBasedFrameDecoder 是行拆包器的通用版本,只不过我们可以自定义分隔符。

④基于长度域拆包器 LengthFieldBasedFrameDecoder

最后一种拆包器是最通用的一种拆包器,只要你的自定义协议中包含长度域字段,均可以使用这个拆包器来实现应用层拆包。由于我们协议中定义了数据包的长度所以可以使用该种拆包器。

2.1.3 实现基于长度域拆包器

public class Spliter extends LengthFieldBasedFrameDecoder {

    private static final int LENGTH_FIELD_OFFSET = 7;

    private static final int LENGTH_FIELD_LENGTH = 4;

    public Spliter() {
        //第一个参数表示数据包的最大长度 
        //第二个参数指的是长度域的偏移量 
        //第三个参数指的是长度域的长度
        super(Integer.MAX_VALUE, LENGTH_FIELD_OFFSET, LENGTH_FIELD_LENGTH);
    }

    @Override
    protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
        if (in.getInt(in.readerIndex()) != PacketCodeC.MAGIC_NUMBER) {
            ctx.channel().close();
            return null;
        }

        return super.decode(ctx, in);
    }
}

2.2 实现编、解码器

主要功能:字节与java对象的相互转换

@ChannelHandler.Sharable
public class PacketCodecHandler extends MessageToMessageCodec<ByteBuf, Packet> {

    private final PacketCodeC packetCode;

    public PacketCodecHandler() {
        packetCode = new PacketCodeC();
    }
    
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf byteBuf, List<Object> out) {
        //将字节转化为java对象
        out.add(packetCode.decode(byteBuf));
    }

    @Override
    protected void encode(ChannelHandlerContext ctx, Packet packet, List<Object> out) {
        try {
            ByteBuf byteBuf = ctx.channel().alloc().ioBuffer();
            //将java对象转化为字节
            packetCode.encode(byteBuf, packet);
            out.add(byteBuf);
        } finally {
            ReferenceCountUtil.release(packet);
        }

    }
}

2.3定义业务处理器,缩短事件传播路径

@ChannelHandler.Sharable
public class IMHandler extends SimpleChannelInboundHandler<Packet> {

    public static final IMHandler INSTANCE = new IMHandler();
    //业务指令集合
    private Map<Byte, SimpleChannelInboundHandler<? extends Packet>> handlerMap;

    private IMHandler() {
        handlerMap = new HashMap<>();
        handlerMap.put(RPC_INVOCATION, RpcServerHandler.INSTANCE);
    }

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, Packet packet) throws Exception {
        handlerMap.get(packet.getCommand()).channelRead(ctx, packet);
    }
}

2.4 创建Netty服务器

2.4.1 客户端业务处理器

@ChannelHandler.Sharable
public class RpcServerHandler extends SimpleChannelInboundHandler<RpcInvocation> {

    private static final Logger logger = LoggerFactory.getLogger(RpcServerHandler.class);

    public static final RpcServerHandler INSTANCE = new RpcServerHandler();

    private RpcServerHandler(){
    }

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcInvocation rpcInvocation) {
        logger.info("服务端获取到的请求消息为{}", rpcInvocation.getReqMsg());
        rpcInvocation.success("服务端已收到消息!!!");
        ctx.writeAndFlush(rpcInvocation);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        cause.printStackTrace();
        Channel channel = ctx.channel();
        if (channel.isActive()) {
            ctx.close();
        }
    }

}

2.4.2 创建服务端

public class RpcServer {

    private static final Logger logger = LoggerFactory.getLogger(RpcServer.class);

    private final EventLoopGroup boss = new NioEventLoopGroup(1);

    private final EventLoopGroup work = new NioEventLoopGroup();

    public ChannelFuture startApplication(int port) {
        ChannelFuture channelFuture = null;
        try {
            ServerBootstrap bootstrap = new ServerBootstrap();

            //服务端采用单一长连接的模式,这里所支持的最大连接数应该和机器本身的性能有关
            //连接防护的handler应该绑定在Main-Reactor上
            bootstrap.group(boss, work)
                    .channel(NioServerSocketChannel.class)
                    .option(ChannelOption.SO_BACKLOG, 128)
                    //添加初始化器
                    .childHandler(new RpcChannelInitializer());
            channelFuture = bind(bootstrap, port).sync();

        } catch (Exception e) {
            logger.error("socket server start error.", e);
        } finally {
            if (null == channelFuture || !channelFuture.isSuccess()) {
                logger.error("socket server start error.");
            }
        }
        logger.info("[startApplication] server is started!");
        return channelFuture;
    }
    /**
     * 绑定端口号
     */
    private ChannelFuture bind(final ServerBootstrap serverBootstrap, final int port) {
        return serverBootstrap.bind(port).addListener(future -> {
            if (future.isSuccess()) {
                System.out.println("端口[" + port + "]绑定成功!");
            } else {
                System.err.println("端口[" + port + "]绑定失败!");
            }
        });
    }
    
}

初始化channel

public class RpcChannelInitializer extends ChannelInitializer<NioSocketChannel> {
    
    @Override
    protected void initChannel(NioSocketChannel channel) throws Exception {
        ChannelPipeline line = channel.pipeline();
        line.addLast(new Spliter());
        line.addLast(new PacketCodecHandler());
        line.addLast(IMHandler.INSTANCE);
    }
}

三、Netty客户端创建

3.1 服务端业务处理器

/**
 * 获取服务端响应
 */
public class RpcClientHandler extends SimpleChannelInboundHandler<RpcInvocation> {

    private static final Logger logger = LoggerFactory.getLogger(RpcClientHandler.class);

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcInvocation rpcInvocation) {
        logger.info("客户端获取到的响应消息为: {}", rpcInvocation.getResponse());
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        super.exceptionCaught(ctx, cause);
        Channel channel = ctx.channel();
        if (channel.isActive()) {
            ctx.close();
        }
    }

}

3.2 创建客户端

public class RpcClient {

    private Bootstrap bootstrap = new Bootstrap();

    private int MAX_RETRY = 3;

    public static Logger logger = LoggerFactory.getLogger(RpcClient.class);

    /**
     * 初始化客户端应用
     *
     * @return
     * @throws InterruptedException
     */
    public Bootstrap initClientApplication() {

        NioEventLoopGroup workerGroup = new NioEventLoopGroup();
        this.bootstrap
                //指定线程模型
                .group(workerGroup)
                //指定 IO 类型为 NIO
                .channel(NioSocketChannel.class)
                //表示连接的超时时间,超过这个时间还是建立不上的话则代表连接失败
                .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000)
                //表示是否开启 TCP 底层心跳机制,true 为开启
                .option(ChannelOption.SO_KEEPALIVE, true)
                //表示是否开始 Nagle 算法,true 表示关闭,false 表示开启
                .option(ChannelOption.TCP_NODELAY, true)
                //IO 处理逻辑
                .handler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    public void initChannel(SocketChannel ch) {
                        //解码器
                        ch.pipeline().addLast(new Spliter());
                        ch.pipeline().addLast(new PacketCodecHandler());
                        // 收消息处理器
                        ch.pipeline().addLast(new RpcClientHandler());
                    }
                });
        return this.bootstrap;
    }

    /**
     * 客户端连接
     *
     * @param host      连接地址
     * @param port      端口号
     * @param retry     重试次数
     * @return
     */
    public ChannelFuture connect(String host, int port, int retry) throws InterruptedException {
        return this.bootstrap.connect(host, port).addListener(future -> {
            if (future.isSuccess()) {
                logger.info(new Date() + ": 连接成功");
            } else if (retry == 0) {
                logger.error("重试次数已用完,放弃连接!");
            } else {
                // 第几次重连
                int order = (MAX_RETRY - retry) + 1;
                // 本次重连的间隔
                int delay = 1 << order;
                logger.error(new Date() + ": 连接失败,第" + order + "次重连……");
                this.bootstrap.config().group().schedule(() -> connect(host, port, retry - 1), delay, TimeUnit
                        .SECONDS);
            }
        }).sync();
    }


}

四、测试

4.1启动服务端

public class RpcServerTest {

    private static final Logger logger = LoggerFactory.getLogger(RpcServerTest.class);

    @Test
    public void testServer() throws ExecutionException, InterruptedException {
        RpcServer rpcServer = new RpcServer();

        //创建客户端
        Callable<ChannelFuture> callable = new Callable<ChannelFuture>() {
            @Override
            public ChannelFuture call() throws Exception {
                //启动服务
                return rpcServer.startApplication(7777);
            }
        };
        Future<ChannelFuture> future = Executors.newFixedThreadPool(2).submit(callable);
        ChannelFuture channelFuture = future.get();
        Channel channel = channelFuture.channel();
        if (null == channel) throw new RuntimeException("netty server start error channel is null");
        while (!channel.isActive()) {
            logger.info("NettyServer启动服务 ...");
            Thread.sleep(500);
        }
        logger.info("NettyServer启动服务完成 {}", channel.localAddress());

        CountDownLatch countDownLatch = new CountDownLatch(1);
        countDownLatch.await();
    }

}

image.png

4.2 启动客户端发送消息

public class RpcClientTest {

    @Test
    public void testClient() throws InterruptedException {
        RpcClient rpcClient = new RpcClient();
        rpcClient.initClientApplication();
        ChannelFuture future = rpcClient.connect("127.0.0.1", 7777, 3);
        RpcInvocation rpcInvocation = new RpcInvocation();
        rpcInvocation.setReqMsg("你好啊!服务端");
        future.channel().writeAndFlush(rpcInvocation);
        CountDownLatch countDownLatch = new CountDownLatch(1);
        countDownLatch.await();
    }

}

image.png