High-Performance Netty: Developing a Private Protocol Stack

Preface

This article continues the Netty series and covers how to develop a private protocol stack. It is organized as follows:

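  1. What is a private protocol stack?
  2. What features should a private protocol stack provide?
  3. The general communication model of a private protocol stack
  4. The message format of a private protocol stack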

What is a private protocol stack?

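Communication protocols can be divided into public protocols and private protocols. HTTP and WebSocket, which we studied in the previous articles, are public protocols. They are widely known, and their standards are set by publicly trusted organizations. A private protocol is usually used inside a company or organization, or for network and user access to that organization's systems. An external user who wants to connect must implement the same non-standard protocol in order to interoperate; otherwise it cannot join the existing network.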

Features of a private protocol stack

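At minimum, a protocol stack must support message exchange and service invocation. On that basis, a Netty-based private protocol stack can provide the following features: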

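  1. High-performance asynchronous communication
  2. A message encoding/decoding framework that serializes and deserializes POJOs
  3. Whitelist-based access authentication by IP address
  4. A link validity check (heartbeats)
  5. Reconnection after the link is broken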

Communication model

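Here, the communication model describes the whole life of a connection: how a peer gains access under the protocol, how messages are exchanged, and how the connection is closed. In detail: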

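  1. The client sends a handshake request carrying valid identity credentials. In the demo below, the login request has no payload, and the server authenticates the client by its IP address.
  2. The server validates the client's identity, including the validity and legitimacy of the information, and then returns a handshake response.
  3. Once the link is established, the server can send business messages to the client, and the client can send business messages to the server.
  4. Once the link is established, the client and the server can exchange heartbeat messages.
  5. Finally, when the server exits, it closes the connection. When the client detects that the peer has closed the connection, it passively closes its own side.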
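The five steps above can be read as a small state machine on each side of the link. The sketch below is a plain-Java illustration, not part of the demo. The `LinkLifecycle` class, its `LinkState`/`Event` enums and the `next` function are hypothetical names. The demo tracks this state only implicitly, through the order of its handlers.

```java
// Hypothetical model of the link lifecycle in steps 1-5. The demo code does
// not define these types; they only illustrate the order of events.
final class LinkLifecycle {

    enum LinkState { CONNECTED, HANDSHAKING, ESTABLISHED, CLOSED }

    enum Event { SEND_LOGIN_REQ, LOGIN_OK, LOGIN_REJECTED, BUSINESS_MSG, HEARTBEAT, PEER_CLOSED }

    /** Returns the next state, or throws if the event is not allowed in the current state. */
    static LinkState next(LinkState state, Event event) {
        if (event == Event.PEER_CLOSED) {
            return LinkState.CLOSED;                 // step 5: the peer closed the connection
        }
        switch (state) {
            case CONNECTED:
                if (event == Event.SEND_LOGIN_REQ) {
                    return LinkState.HANDSHAKING;    // step 1: handshake request sent
                }
                break;
            case HANDSHAKING:
                if (event == Event.LOGIN_OK) {
                    return LinkState.ESTABLISHED;    // step 2: server accepted the client
                }
                if (event == Event.LOGIN_REJECTED) {
                    return LinkState.CLOSED;         // step 2: server rejected, client closes
                }
                break;
            case ESTABLISHED:
                if (event == Event.BUSINESS_MSG || event == Event.HEARTBEAT) {
                    return LinkState.ESTABLISHED;    // steps 3 and 4
                }
                break;
            default:
                break;
        }
        throw new IllegalStateException(event + " is not allowed in state " + state);
    }
}
```

For example, `next(CONNECTED, BUSINESS_MSG)` throws. This reflects the rule that business messages are exchanged only after a successful handshake.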

Message format

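When we studied the application-layer protocol HTTP, we saw that a request consists of three parts: the request line, the request headers and the request body. We can define a similar format for our private protocol.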

Our message format consists of two parts: a message header and a message body.
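The encoder later in this article writes the header fields in a fixed order:

- `crcCode` (int)
- `length` (int)
- `sessionID` (long)
- `type` (byte)
- `priority` (byte)
- the attachment count (int)

That is 4 + 4 + 8 + 1 + 1 + 4 = 22 bytes before any attachment or body data. The `length` field sits at offset 4, which is what the frame decoder is configured to read.

The sketch below mirrors only this fixed part with `java.nio.ByteBuffer`, so the offsets can be checked without Netty. `FixedHeader` is a hypothetical class, not one of the demo's classes.

```java
import java.nio.ByteBuffer;

// Hypothetical mirror of the fixed header part that MessageEncoder writes.
// ByteBuffer is big-endian by default, like Netty's ByteBuf.writeInt/readInt.
final class FixedHeader {
    static final int LENGTH_OFFSET = 4;   // int length, read by the frame decoder
    static final int TYPE_OFFSET = 16;    // after crcCode (4) + length (4) + sessionID (8)
    static final int SIZE = 22;           // 4 + 4 + 8 + 1 + 1 + 4

    final int crcCode;
    final int length;
    final long sessionID;
    final byte type;
    final byte priority;
    final int attachmentCount;

    FixedHeader(int crcCode, int length, long sessionID, byte type, byte priority, int attachmentCount) {
        this.crcCode = crcCode;
        this.length = length;
        this.sessionID = sessionID;
        this.type = type;
        this.priority = priority;
        this.attachmentCount = attachmentCount;
    }

    /** Writes the fields in MessageEncoder's order and returns a buffer ready for reading. */
    ByteBuffer encode() {
        ByteBuffer buf = ByteBuffer.allocate(SIZE);
        buf.putInt(crcCode).putInt(length).putLong(sessionID)
           .put(type).put(priority).putInt(attachmentCount);
        buf.flip();
        return buf;
    }

    /** Reads the fields back in MessageDecoder's order. */
    static FixedHeader decode(ByteBuffer buf) {
        // Java evaluates constructor arguments left to right, so the reads follow wire order.
        return new FixedHeader(buf.getInt(), buf.getInt(), buf.getLong(), buf.get(), buf.get(), buf.getInt());
    }
}
```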

Implementation

Because this demo is fairly complete, it involves more classes than usual. Their roles are described below.

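Class descriptions

System configuration classes

Class | Description
--- | ---
MessageType | Message type
Constant | Constants

Entity classes

Class | Description
--- | ---
Header | Message header
Message | Message (header and body)

Codec classes

Class | Description
--- | ---
ChannelBufferByteInput | Buffered byte input
ChannelBufferByteOutput | Buffered byte output
MarshallingCodecFactory | Factory for JBoss Marshalling marshallers and unmarshallers
MarshallingDecoder | Marshalling decoder
MarshallingEncoder | Marshalling encoder
MessageDecoder | Message decoder
MessageEncoder | Message encoder
TestCodec | Codec test

Server and client classes

Class | Description
--- | ---
HeartBeatRespHandler | Heartbeat response handler
LoginAuthRespHandler | Login authentication response handler
Server | Server
HeartBeatReqHandler | Heartbeat request handler
LoginAuthReqHandler | Login authentication request handler
Client | Client

Maven dependencies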
        <dependency>
            <groupId>org.jboss.marshalling</groupId>
            <artifactId>jboss-marshalling</artifactId>
            <version>2.0.9.Final</version>
        </dependency>
        <dependency>
            <groupId>org.jboss.marshalling</groupId>
            <artifactId>jboss-marshalling-serial</artifactId>
            <version>2.0.9.Final</version>
        </dependency>
        <dependency>
            <groupId>io.netty</groupId>
            <artifactId>netty-all</artifactId>
            <version>4.1.51.Final</version>
        </dependency>
        <dependency>
            <groupId>log4j</groupId>
            <artifactId>log4j</artifactId>
            <version>1.2.17</version>
        </dependency>
        <dependency>
            <groupId>commons-logging</groupId>
            <artifactId>commons-logging</artifactId>
            <version>1.1.1</version>
        </dependency>
System configuration classes

MessageType.java

public enum MessageType {
    SERVICE_REQ((byte) 0), SERVICE_RESP((byte) 1), ONE_WAY((byte) 2), LOGIN_REQ(
            (byte) 3), LOGIN_RESP((byte) 4), HEARTBEAT_REQ((byte) 5), HEARTBEAT_RESP(
            (byte) 6);

    private byte value;

    private MessageType(byte value) {
        this.value = value;
    }

    public byte value() {
        return this.value;
    }
}
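The decoder only ever sees the raw `type` byte, and the handlers compare it against `MessageType.X.value()`. If you prefer to switch on the enum itself, you can add a reverse lookup.

The `fromValue` method below is a hypothetical addition, not part of the original enum. The enum is repeated here so that the snippet compiles on its own.

```java
// The original enum plus a hypothetical reverse lookup (fromValue).
enum MessageType {
    SERVICE_REQ((byte) 0), SERVICE_RESP((byte) 1), ONE_WAY((byte) 2),
    LOGIN_REQ((byte) 3), LOGIN_RESP((byte) 4),
    HEARTBEAT_REQ((byte) 5), HEARTBEAT_RESP((byte) 6);

    private final byte value;

    MessageType(byte value) {
        this.value = value;
    }

    public byte value() {
        return value;
    }

    /** Maps a wire byte back to its constant; throws for values the protocol does not define. */
    public static MessageType fromValue(byte value) {
        for (MessageType type : values()) {
            if (type.value == value) {
                return type;
            }
        }
        throw new IllegalArgumentException("Unknown message type: " + value);
    }
}
```

With only seven constants, a linear scan over `values()` is simple and cheap. Throwing on unknown bytes surfaces protocol mismatches early instead of silently passing them down the pipeline.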

Constant.java

public class Constant {
    public static final String REMOTEIP = "127.0.0.1";
    public static final int PORT = 8080;
    public static final int LOCAL_PORT = 12088;
    public static final String LOCALIP = "127.0.0.1";
}
Entity classes

Header.java

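import java.util.HashMap;
import java.util.Map;

public final class Header {
    private int crcCode = 0xabef0101; // check code, written first in every frame
    private int length;     // message length
    private long sessionID; // session ID
    private byte type;      // message type
    private byte priority;  // priority
    private Map<String, Object> attachment = new HashMap<>(); // optional extension fields

    // ... getters and setters omitted
}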

Message.java

public class Message {
    private Header header;
    private Object body;

    // ... getters and setters omitted
}
Codec classes

ChannelBufferByteInput.java

import io.netty.buffer.ByteBuf;
import org.jboss.marshalling.ByteInput;
import java.io.IOException;

/* ByteInput implementation that reads from a Netty ByteBuf */
class ChannelBufferByteInput implements ByteInput {

    private final ByteBuf buffer;
	
    public ChannelBufferByteInput(ByteBuf buffer) {
        this.buffer = buffer;
    }

    @Override
    public void close() throws IOException {
        // nothing to do
    }

    @Override
    public int available() throws IOException {
        return buffer.readableBytes();
    }

    @Override
    public int read() throws IOException {
        if (buffer.isReadable()) {
            return buffer.readByte() & 0xff;
        }
        return -1;
    }

    @Override
    public int read(byte[] array) throws IOException {
        return read(array, 0, array.length);
    }

    @Override
    public int read(byte[] dst, int dstIndex, int length) throws IOException {
        int available = available();
        if (available == 0) {
            return -1;
        }

        length = Math.min(available, length);
        buffer.readBytes(dst, dstIndex, length);
        return length;
    }

    @Override
    public long skip(long bytes) throws IOException {
        int readable = buffer.readableBytes();
        if (readable < bytes) {
            bytes = readable;
        }
        buffer.readerIndex((int) (buffer.readerIndex() + bytes));
        return bytes;
    }

}

ChannelBufferByteOutput.java

import io.netty.buffer.ByteBuf;
import org.jboss.marshalling.ByteOutput;
import java.io.IOException;

/* ByteOutput implementation that writes to a Netty ByteBuf */
class ChannelBufferByteOutput implements ByteOutput {

    private final ByteBuf buffer;

    public ChannelBufferByteOutput(ByteBuf buffer) {
        this.buffer = buffer;
    }

    @Override
    public void close() throws IOException {
        // Nothing to do
    }

    @Override
    public void flush() throws IOException {
        // nothing to do
    }

    @Override
    public void write(int b) throws IOException {
        buffer.writeByte(b);
    }

    @Override
    public void write(byte[] bytes) throws IOException {
        buffer.writeBytes(bytes);
    }

    @Override
    public void write(byte[] bytes, int srcIndex, int length) throws IOException {
        buffer.writeBytes(bytes, srcIndex, length);
    }

    /**
     * Return the {@link ByteBuf} which contains the written content
     *
     */
    ByteBuf getBuffer() {
        return buffer;
    }
}

MarshallingCodecFactory.java

import java.io.IOException;

import org.jboss.marshalling.Marshaller;
import org.jboss.marshalling.MarshallerFactory;
import org.jboss.marshalling.Marshalling;
import org.jboss.marshalling.MarshallingConfiguration;
import org.jboss.marshalling.Unmarshaller;

public final class MarshallingCodecFactory {
    /** Creates a JBoss Marshalling Marshaller */
    protected static Marshaller buildMarshalling() throws IOException {
        // The "serial" provider comes from jboss-marshalling-serial, which must be on the classpath
        final MarshallerFactory marshallerFactory = Marshalling
            .getProvidedMarshallerFactory("serial");
        final MarshallingConfiguration configuration = new MarshallingConfiguration();
        configuration.setVersion(5);
        return marshallerFactory.createMarshaller(configuration);
    }

    /** Creates a JBoss Marshalling Unmarshaller */
    protected static Unmarshaller buildUnMarshalling() throws IOException {
        final MarshallerFactory marshallerFactory = Marshalling
            .getProvidedMarshallerFactory("serial");
        final MarshallingConfiguration configuration = new MarshallingConfiguration();
        configuration.setVersion(5);
        return marshallerFactory.createUnmarshaller(configuration);
    }
}

MarshallingDecoder.java

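import io.netty.buffer.ByteBuf;
import java.io.IOException;
import org.jboss.marshalling.ByteInput;
import org.jboss.marshalling.Unmarshaller;

public class MarshallingDecoder {

    private final Unmarshaller unmarshaller;

    public MarshallingDecoder() throws IOException {
        unmarshaller = MarshallingCodecFactory.buildUnMarshalling();
    }

    // Reads one [int length][serialized object] unit written by MarshallingEncoder
    protected Object decode(ByteBuf in) throws Exception {
        int objectSize = in.readInt();
        // Read through a slice so the unmarshaller cannot consume bytes beyond this object
        ByteBuf buf = in.slice(in.readerIndex(), objectSize);
        ByteInput input = new ChannelBufferByteInput(buf);
        try {
            unmarshaller.start(input);
            Object obj = unmarshaller.readObject();
            unmarshaller.finish();
            // Reads on the slice do not move in's readerIndex, so skip the object explicitly
            in.readerIndex(in.readerIndex() + objectSize);
            return obj;
        } finally {
            unmarshaller.close();
        }
    }
}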

MarshallingEncoder.java

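import io.netty.buffer.ByteBuf;
import java.io.IOException;
import org.jboss.marshalling.Marshaller;

// Not a ChannelHandler, so @Sharable does not apply. The Marshaller keeps
// per-stream state, and each MessageEncoder creates its own MarshallingEncoder.
public class MarshallingEncoder {

    private static final byte[] LENGTH_PLACEHOLDER = new byte[4];
    private final Marshaller marshaller;

    public MarshallingEncoder() throws IOException {
        marshaller = MarshallingCodecFactory.buildMarshalling();
    }

    protected void encode(Object msg, ByteBuf out) throws Exception {
        try {
            // Reserve 4 bytes for the length, serialize the object, then backfill the length
            int lengthPos = out.writerIndex();
            out.writeBytes(LENGTH_PLACEHOLDER);
            ChannelBufferByteOutput output = new ChannelBufferByteOutput(out);
            marshaller.start(output);
            marshaller.writeObject(msg);
            marshaller.finish();
            out.setInt(lengthPos, out.writerIndex() - lengthPos - 4);
        } finally {
            marshaller.close();
        }
    }
}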
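MarshallingEncoder and MarshallingDecoder frame each serialized object as a 4-byte length followed by the payload:

- The encoder reserves the 4 bytes first, streams the object into the buffer, then backfills the real length.
- The decoder reads the length and consumes exactly that many bytes through a slice.

The sketch below reproduces this reserve-and-backfill pattern using only the JDK. Standard `ObjectOutputStream`/`ObjectInputStream` stand in for JBoss Marshalling, and `ByteBuffer` stands in for `ByteBuf`, so the bytes are not guaranteed to match the demo's output. `LengthPrefixedCodec` is a hypothetical name.

```java
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;

// Hypothetical JDK-only version of the [int length][payload] scheme used by
// MarshallingEncoder/MarshallingDecoder (Java serialization instead of JBoss Marshalling).
final class LengthPrefixedCodec {

    /** Appends [int length][serialized object] to out and returns the payload length. */
    static int encode(Object msg, ByteBuffer out) throws IOException {
        int lengthPos = out.position();
        out.putInt(0); // 4-byte placeholder, like LENGTH_PLACEHOLDER
        // Stream straight into the buffer, like ChannelBufferByteOutput.
        // Unlike ByteBuf, ByteBuffer does not grow, so it must have enough space left.
        OutputStream sink = new OutputStream() {
            @Override public void write(int b) { out.put((byte) b); }
            @Override public void write(byte[] b, int off, int len) { out.put(b, off, len); }
        };
        try (ObjectOutputStream oos = new ObjectOutputStream(sink)) {
            oos.writeObject(msg);
        }
        int payloadLength = out.position() - lengthPos - 4;
        out.putInt(lengthPos, payloadLength); // backfill, like out.setInt(lengthPos, ...)
        return payloadLength;
    }

    /** Reads one [int length][payload] unit and moves in's position past it. */
    static Object decode(ByteBuffer in) throws IOException, ClassNotFoundException {
        int objectSize = in.getInt();
        // Restrict reads to this object, like in.slice(readerIndex, objectSize)
        ByteBuffer slice = in.slice();
        slice.limit(objectSize);
        InputStream source = new InputStream() {
            @Override public int read() { return slice.hasRemaining() ? slice.get() & 0xff : -1; }
            @Override public int read(byte[] b, int off, int len) {
                if (len == 0) return 0;
                if (!slice.hasRemaining()) return -1;
                int n = Math.min(len, slice.remaining());
                slice.get(b, off, n);
                return n;
            }
        };
        try (ObjectInputStream ois = new ObjectInputStream(source)) {
            return ois.readObject();
        } finally {
            in.position(in.position() + objectSize); // skip past the object
        }
    }
}
```

Streaming directly into the output is why the backfill is needed: the payload length is only known after serialization has finished.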

MessageDecoder.java

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

public class MessageDecoder extends LengthFieldBasedFrameDecoder {

    private final MarshallingDecoder marshallingDecoder;

    public MessageDecoder(int maxFrameLength, int lengthFieldOffset,
            int lengthFieldLength) throws IOException {
        super(maxFrameLength, lengthFieldOffset, lengthFieldLength);
        marshallingDecoder = new MarshallingDecoder();
    }

    @Override
    protected Object decode(ChannelHandlerContext ctx, ByteBuf in)
            throws Exception {
        // Let LengthFieldBasedFrameDecoder cut one complete frame out of the stream
        ByteBuf frame = (ByteBuf) super.decode(ctx, in);
        if (frame == null) {
            return null; // incomplete frame: wait for more bytes
        }
        try {
            Message message = new Message();
            Header header = new Header();
            header.setCrcCode(frame.readInt());
            header.setLength(frame.readInt());
            header.setSessionID(frame.readLong());
            header.setType(frame.readByte());
            header.setPriority(frame.readByte());

            int size = frame.readInt();
            if (size > 0) {
                Map<String, Object> attch = new HashMap<>(size);
                for (int i = 0; i < size; i++) {
                    byte[] keyArray = new byte[frame.readInt()];
                    frame.readBytes(keyArray);
                    String key = new String(keyArray, StandardCharsets.UTF_8);
                    attch.put(key, marshallingDecoder.decode(frame));
                }
                header.setAttachment(attch);
            }
            // A missing body is encoded as a single int 0, so decode only if more than 4 bytes remain
            if (frame.readableBytes() > 4) {
                message.setBody(marshallingDecoder.decode(frame));
            }
            message.setHeader(header);
            return message;
        } finally {
            // In Netty 4.1, super.decode returns a retained slice that must be released
            frame.release();
        }
    }
}
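`new MessageDecoder(1024 * 1024, 4, 4)` tells `LengthFieldBasedFrameDecoder` two things: the length field starts at offset 4, and it is 4 bytes long. The three-argument constructor keeps `lengthAdjustment` and `initialBytesToStrip` at 0. A complete frame is therefore 4 + 4 + length bytes, which is why `MessageEncoder` stores `readableBytes() - 8`.

TCP does not preserve message boundaries. A single read may deliver half a frame, or several frames at once. The sketch below reproduces only the boundary arithmetic in plain Java, so that both cases can be tested without Netty. `FrameSplitter` is a hypothetical helper. Where the real decoder raises its own exceptions for bad lengths, this sketch simply throws.

```java
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical plain-Java model of the framing done by
// LengthFieldBasedFrameDecoder(maxFrameLength, 4, 4) with lengthAdjustment = 0
// and initialBytesToStrip = 0: each frame is 4 + 4 + length bytes long.
final class FrameSplitter {
    private static final int LENGTH_FIELD_OFFSET = 4;
    private static final int LENGTH_FIELD_END = 8; // offset + 4-byte length field

    private final int maxFrameLength;
    private final ByteArrayOutputStream pending = new ByteArrayOutputStream();

    FrameSplitter(int maxFrameLength) {
        this.maxFrameLength = maxFrameLength;
    }

    /** Buffers newly received bytes and returns every frame that is now complete. */
    List<byte[]> feed(byte[] chunk) {
        pending.write(chunk, 0, chunk.length);
        byte[] data = pending.toByteArray();
        ByteBuffer view = ByteBuffer.wrap(data); // big-endian, like ByteBuf
        List<byte[]> frames = new ArrayList<>();
        int start = 0;
        while (data.length - start >= LENGTH_FIELD_END) {
            int length = view.getInt(start + LENGTH_FIELD_OFFSET);
            long frameLength = (long) length + LENGTH_FIELD_END;
            if (length < 0 || frameLength > maxFrameLength) {
                // Netty reports this as CorruptedFrameException / TooLongFrameException
                throw new IllegalStateException("Invalid frame length: " + frameLength);
            }
            if (data.length - start < frameLength) {
                break; // half a frame: wait for more bytes
            }
            frames.add(Arrays.copyOfRange(data, start, start + (int) frameLength));
            start += (int) frameLength;
        }
        pending.reset();
        pending.write(data, start, data.length - start); // keep the incomplete tail
        return frames;
    }
}
```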

MessageEncoder.java

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Map;

public final class MessageEncoder extends MessageToByteEncoder<Message> {

    private final MarshallingEncoder marshallingEncoder;

    public MessageEncoder() throws IOException {
        this.marshallingEncoder = new MarshallingEncoder();
    }

    @Override
    protected void encode(ChannelHandlerContext ctx, Message msg,
            ByteBuf sendBuf) throws Exception {
        if (msg == null || msg.getHeader() == null)
            throw new Exception("The encode message is null");
        Header header = msg.getHeader();
        sendBuf.writeInt(header.getCrcCode());
        sendBuf.writeInt(header.getLength()); // overwritten with the real length below
        sendBuf.writeLong(header.getSessionID());
        sendBuf.writeByte(header.getType());
        sendBuf.writeByte(header.getPriority());
        sendBuf.writeInt(header.getAttachment().size());
        for (Map.Entry<String, Object> param : header.getAttachment().entrySet()) {
            // Each attachment: [int key length][UTF-8 key][marshalled value]
            byte[] keyArray = param.getKey().getBytes(StandardCharsets.UTF_8);
            sendBuf.writeInt(keyArray.length);
            sendBuf.writeBytes(keyArray);
            marshallingEncoder.encode(param.getValue(), sendBuf);
        }
        if (msg.getBody() != null) {
            marshallingEncoder.encode(msg.getBody(), sendBuf);
        } else {
            sendBuf.writeInt(0); // no body
        }
        // Backfill the length field at offset 4: all bytes after crcCode (4) and length (4)
        sendBuf.setInt(4, sendBuf.readableBytes() - 8);
    }
}
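Here is a worked check of the final `setInt(4, readableBytes() - 8)` line. A heartbeat message has no attachments and no body, so the encoder writes:

- 4 + 4 + 8 + 1 + 1 + 4 = 22 header bytes
- the 4-byte `0` that stands for the missing body

That is 26 bytes in total. The encoder backfills 26 - 8 = 18 into the length field, and the frame decoder reads 8 + 18 = 26 bytes back.

The sketch below repeats the encoder's write sequence with `ByteBuffer` for messages without attachments. `NoAttachmentFrame` is hypothetical. Its `serializedBody` parameter is an assumption standing in for the bytes that JBoss Marshalling would produce (`null` for no body).

```java
import java.nio.ByteBuffer;

// Hypothetical mirror of MessageEncoder.encode for a message without attachments.
// serializedBody stands in for marshaller output; null means "no body".
final class NoAttachmentFrame {
    static ByteBuffer encode(long sessionID, byte type, byte priority, byte[] serializedBody) {
        int bodyBytes = serializedBody == null ? 4 : 4 + serializedBody.length;
        ByteBuffer buf = ByteBuffer.allocate(22 + bodyBytes);
        buf.putInt(0xabef0101);  // crcCode (Header's default value)
        buf.putInt(0);           // length placeholder, backfilled below
        buf.putLong(sessionID);
        buf.put(type);
        buf.put(priority);
        buf.putInt(0);           // attachment count: none
        if (serializedBody == null) {
            buf.putInt(0);       // the "no body" marker
        } else {
            buf.putInt(serializedBody.length); // length prefix, as MarshallingEncoder writes it
            buf.put(serializedBody);
        }
        buf.putInt(4, buf.position() - 8);     // same as sendBuf.setInt(4, readableBytes() - 8)
        buf.flip();
        return buf;
    }
}
```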
Server and client

Server: Server.java

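import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.timeout.ReadTimeoutHandler;
import java.io.IOException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class Server {

    private static final Log LOG = LogFactory.getLog(Server.class);

    public void bind() throws Exception {
        // Configure the server's NIO thread groups
        EventLoopGroup bossGroup = new NioEventLoopGroup();
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        ServerBootstrap b = new ServerBootstrap();
        b.group(bossGroup, workerGroup).channel(NioServerSocketChannel.class)
            .option(ChannelOption.SO_BACKLOG, 100)
            .handler(new LoggingHandler(LogLevel.INFO))
            .childHandler(new ChannelInitializer<SocketChannel>() {
                @Override
                public void initChannel(SocketChannel ch) throws IOException {
                    // Frames of at most 1 MB, with a 4-byte length field at offset 4
                    ch.pipeline().addLast(new MessageDecoder(1024 * 1024, 4, 4));
                    ch.pipeline().addLast(new MessageEncoder());
                    // Close the link if nothing is read for 50 seconds
                    ch.pipeline().addLast("readTimeoutHandler", new ReadTimeoutHandler(50));
                    ch.pipeline().addLast(new LoginAuthRespHandler());
                    ch.pipeline().addLast("HeartBeatHandler", new HeartBeatRespHandler());
                }
            });

        // Bind the port and wait synchronously until binding succeeds
        b.bind(Constant.REMOTEIP, Constant.PORT).sync();
        LOG.info("server start ok : " + Constant.REMOTEIP + " : " + Constant.PORT);
    }

    public static void main(String[] args) throws Exception {
        new Server().bind();
    }
}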

HeartBeatRespHandler.java

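import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

// In Netty 4.1, ChannelHandlerAdapter has no channelRead(), so inbound
// handlers extend ChannelInboundHandlerAdapter instead
public class HeartBeatRespHandler extends ChannelInboundHandlerAdapter {

    private static final Log LOG = LogFactory.getLog(HeartBeatRespHandler.class);

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg)
            throws Exception {
        Message message = (Message) msg;
        // Answer heartbeat requests; pass every other message down the pipeline
        if (message.getHeader() != null
                && message.getHeader().getType() == MessageType.HEARTBEAT_REQ.value()) {
            LOG.info("Receive client heart beat message : ---> " + message);
            Message heartBeat = buildHeartBeat();
            LOG.info("Send heart beat response message to client : ---> " + heartBeat);
            ctx.writeAndFlush(heartBeat);
        } else {
            ctx.fireChannelRead(msg);
        }
    }

    // Builds a heartbeat response message
    private Message buildHeartBeat() {
        Message message = new Message();
        Header header = new Header();
        header.setType(MessageType.HEARTBEAT_RESP.value());
        message.setHeader(header);
        return message;
    }
}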

LoginAuthRespHandler.java

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import java.net.InetSocketAddress;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class LoginAuthRespHandler extends ChannelInboundHandlerAdapter {

    private static final Log LOG = LogFactory.getLog(LoginAuthRespHandler.class);
    // Cache of logged-in nodes. It is static because a new handler is created for
    // every channel; an instance field would never see the other connections.
    private static final Map<String, Boolean> nodeCheck = new ConcurrentHashMap<>();
    private final String[] whiteList = { "127.0.0.1", "192.168.1.104" };

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg)
            throws Exception {
        Message message = (Message) msg;

        // Handle handshake (login) requests; pass every other message through
        if (message.getHeader() != null
                && message.getHeader().getType() == MessageType.LOGIN_REQ.value()) {
            String nodeIndex = ctx.channel().remoteAddress().toString();
            Message loginResp;
            if (nodeCheck.containsKey(nodeIndex)) {
                // Duplicate login: reject
                loginResp = buildResponse((byte) -1);
            } else {
                InetSocketAddress address = (InetSocketAddress) ctx.channel().remoteAddress();
                String ip = address.getAddress().getHostAddress();
                boolean isOK = false;
                for (String wip : whiteList) {
                    if (wip.equals(ip)) {
                        isOK = true;
                        break;
                    }
                }
                loginResp = isOK ? buildResponse((byte) 0) : buildResponse((byte) -1);
                if (isOK) {
                    nodeCheck.put(nodeIndex, true);
                }
            }
            LOG.info("The login response is : " + loginResp
                    + " body [" + loginResp.getBody() + "]");
            ctx.writeAndFlush(loginResp);
        } else {
            ctx.fireChannelRead(msg);
        }
    }

    private Message buildResponse(byte result) {
        Message message = new Message();
        Header header = new Header();
        header.setType(MessageType.LOGIN_RESP.value());
        message.setHeader(header);
        message.setBody(result);
        return message;
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        // Forget the node when the link closes normally, so that it can log in again
        nodeCheck.remove(ctx.channel().remoteAddress().toString());
        ctx.fireChannelInactive();
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
            throws Exception {
        cause.printStackTrace();
        nodeCheck.remove(ctx.channel().remoteAddress().toString()); // remove from the cache
        ctx.close();
        ctx.fireExceptionCaught(cause);
    }
}
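The handshake decision in `LoginAuthRespHandler` does not depend on Netty, so it can be pulled out and tested on its own. The sketch below assumes the same rules as the demo:

- whitelist by IP address
- reject a node key that is already logged in
- result byte `0` for success and `-1` for rejection

The key is assumed to be the remote address string. `LoginGate`, `login` and `logout` are hypothetical names. `ConcurrentHashMap.putIfAbsent` performs the duplicate check and the insert as one atomic step. This matters once the map is shared between connections.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical, Netty-free version of the handshake decision in LoginAuthRespHandler.
final class LoginGate {
    static final byte OK = 0;
    static final byte REJECTED = -1;

    private final Set<String> whiteList;
    private final Map<String, Boolean> loggedIn = new ConcurrentHashMap<>();

    LoginGate(String... whiteList) {
        this.whiteList = new HashSet<>(Arrays.asList(whiteList));
    }

    /** nodeIndex is the remote address key; ip is the bare host address. */
    byte login(String nodeIndex, String ip) {
        if (!whiteList.contains(ip)) {
            return REJECTED;
        }
        // putIfAbsent returns null only for the first login of this key
        return loggedIn.putIfAbsent(nodeIndex, Boolean.TRUE) == null ? OK : REJECTED;
    }

    /** Call when the link closes or fails, so that the same node can log in again. */
    void logout(String nodeIndex) {
        loggedIn.remove(nodeIndex);
    }
}
```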

Client: Client.java

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.timeout.ReadTimeoutHandler;
import java.net.InetSocketAddress;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class Client {
    private static final Log LOG = LogFactory.getLog(Client.class);
    private final ScheduledExecutorService executor = Executors.newScheduledThreadPool(1);
    private final EventLoopGroup group = new NioEventLoopGroup();

    public void connect(int port, String host) throws Exception {
        // Configure the client's NIO thread group
        try {
            Bootstrap b = new Bootstrap();
            b.group(group).channel(NioSocketChannel.class)
                    .option(ChannelOption.TCP_NODELAY, true)
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        public void initChannel(SocketChannel ch) throws Exception {
                            ch.pipeline().addLast(new MessageDecoder(1024 * 1024, 4, 4));
                            ch.pipeline().addLast("MessageEncoder", new MessageEncoder());
                            ch.pipeline().addLast("readTimeoutHandler", new ReadTimeoutHandler(50));
                            ch.pipeline().addLast("LoginAuthHandler", new LoginAuthReqHandler());
                            ch.pipeline().addLast("HeartBeatHandler", new HeartBeatReqHandler());
                        }
                    });
            // Connect from a fixed local address and wait until the connection is established
            ChannelFuture future = b.connect(
                    new InetSocketAddress(host, port),
                    new InetSocketAddress(Constant.LOCALIP, Constant.LOCAL_PORT)).sync();
            // Block until the channel is closed
            future.channel().closeFuture().sync();
        } finally {
            // The link is gone (closed, or the connection failed): wait 1 second, then reconnect
            executor.execute(() -> {
                try {
                    TimeUnit.SECONDS.sleep(1);
                    connect(port, host); // reconnect
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } catch (Exception e) {
                    e.printStackTrace();
                }
            });
        }
    }

    public static void main(String[] args) throws Exception {
        new Client().connect(Constant.PORT, Constant.REMOTEIP);
    }
}
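The client reconnects by submitting a task that sleeps for one second and then calls `connect` again. A `ScheduledExecutorService` can express that delay directly with `schedule`, instead of parking a pool thread in `sleep`.

The sketch below isolates the retry loop. `ReconnectLoop` and its `Attempt` interface are hypothetical. `Attempt.run` stands in for the demo's blocking sequence: `Bootstrap.connect(...).sync()` followed by `closeFuture().sync()`. The sketch stops after `maxAttempts` so that it terminates, whereas the demo retries forever.

```java
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical retry loop modeled on the finally block of Client.connect.
final class ReconnectLoop {

    /** Stand-in for one blocking "connect, then wait until the channel closes" cycle. */
    interface Attempt {
        void run() throws Exception;
    }

    private final ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
    private final AtomicInteger attempts = new AtomicInteger();

    /** Runs the first attempt now; each later attempt starts delayMillis after the previous one ends. */
    void start(Attempt attempt, long delayMillis, int maxAttempts) {
        executor.execute(() -> runOnce(attempt, delayMillis, maxAttempts));
    }

    private void runOnce(Attempt attempt, long delayMillis, int maxAttempts) {
        int n = attempts.incrementAndGet();
        try {
            attempt.run();
        } catch (Exception e) {
            System.err.println("attempt " + n + " failed: " + e.getMessage());
        }
        // Like the demo, retry after both failures and normal closes
        if (n < maxAttempts) {
            executor.schedule(() -> runOnce(attempt, delayMillis, maxAttempts),
                    delayMillis, TimeUnit.MILLISECONDS);
        } else {
            executor.shutdown();
        }
    }

    int attempts() {
        return attempts.get();
    }

    boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
        return executor.awaitTermination(timeout, unit);
    }
}
```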

HeartBeatReqHandler.java

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class HeartBeatReqHandler extends ChannelInboundHandlerAdapter {

    private static final Log LOG = LogFactory.getLog(HeartBeatReqHandler.class);

    private volatile ScheduledFuture<?> heartBeat;

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg)
            throws Exception {
        Message message = (Message) msg;
        // Handshake succeeded: start sending heartbeats every 5 seconds
        if (message.getHeader() != null
                && message.getHeader().getType() == MessageType.LOGIN_RESP.value()) {
            heartBeat = ctx.executor().scheduleAtFixedRate(
                    new HeartBeatTask(ctx), 0, 5000, TimeUnit.MILLISECONDS);
        } else if (message.getHeader() != null
                && message.getHeader().getType() == MessageType.HEARTBEAT_RESP.value()) {
            LOG.info("Client receive server heart beat message : ---> " + message);
        } else {
            ctx.fireChannelRead(msg);
        }
    }

    private class HeartBeatTask implements Runnable {
        private final ChannelHandlerContext ctx;

        public HeartBeatTask(final ChannelHandlerContext ctx) {
            this.ctx = ctx;
        }

        @Override
        public void run() {
            Message heartBeat = buildHeartBeat();
            LOG.info("Client send heart beat message to server : ---> " + heartBeat);
            ctx.writeAndFlush(heartBeat);
        }

        private Message buildHeartBeat() {
            Message message = new Message();
            Header header = new Header();
            header.setType(MessageType.HEARTBEAT_REQ.value());
            message.setHeader(header);
            return message;
        }
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        // Stop the heartbeat when the link closes normally, not only on exceptions
        cancelHeartBeat();
        ctx.fireChannelInactive();
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
            throws Exception {
        cause.printStackTrace();
        cancelHeartBeat();
        ctx.fireExceptionCaught(cause);
    }

    private void cancelHeartBeat() {
        if (heartBeat != null) {
            heartBeat.cancel(true);
            heartBeat = null;
        }
    }
}
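`HeartBeatReqHandler` starts a fixed-rate task when the login succeeds and cancels it when an exception is caught. The sketch below shows the same start/cancel lifecycle with a plain `ScheduledExecutorService` in place of `ctx.executor()`, so that it can be tested in isolation.

`HeartBeatTimer` is a hypothetical class. The `sendHeartBeat` runnable stands in for building a `HEARTBEAT_REQ` message and calling `ctx.writeAndFlush`. Making `start` idempotent also prevents two timers from running if a second `LOGIN_RESP` ever arrives, a case the demo handler does not check.

```java
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;

// Hypothetical stand-alone version of the heartbeat start/cancel lifecycle.
final class HeartBeatTimer {
    private final ScheduledExecutorService executor;
    private final Runnable sendHeartBeat;
    private ScheduledFuture<?> heartBeat; // guarded by "this"

    HeartBeatTimer(ScheduledExecutorService executor, Runnable sendHeartBeat) {
        this.executor = executor;
        this.sendHeartBeat = sendHeartBeat;
    }

    /** Starts sending heartbeats at a fixed rate; does nothing if already running. */
    synchronized void start(long periodMillis) {
        if (heartBeat == null) {
            heartBeat = executor.scheduleAtFixedRate(sendHeartBeat, 0, periodMillis, TimeUnit.MILLISECONDS);
        }
    }

    /** Stops the heartbeat; call it on exceptions and when the link becomes inactive. */
    synchronized void stop() {
        if (heartBeat != null) {
            heartBeat.cancel(true);
            heartBeat = null;
        }
    }

    synchronized boolean isRunning() {
        return heartBeat != null;
    }
}
```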

LoginAuthReqHandler.java

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class LoginAuthReqHandler extends ChannelInboundHandlerAdapter {

    private static final Log LOG = LogFactory.getLog(LoginAuthReqHandler.class);

    // Send the handshake (login) request as soon as the connection is active
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        ctx.writeAndFlush(buildLoginReq());
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg)
            throws Exception {
        Message message = (Message) msg;

        // For a handshake response, check whether authentication succeeded
        if (message.getHeader() != null
                && message.getHeader().getType() == MessageType.LOGIN_RESP.value()) {
            byte loginResult = (Byte) message.getBody();
            if (loginResult != (byte) 0) {
                // Handshake failed: close the connection
                ctx.close();
            } else {
                LOG.info("Login is ok : " + message);
                ctx.fireChannelRead(msg);
            }
        } else {
            ctx.fireChannelRead(msg);
        }
    }

    // Builds the login request
    private Message buildLoginReq() {
        Message message = new Message();
        Header header = new Header();
        header.setType(MessageType.LOGIN_REQ.value());
        message.setHeader(header);
        return message;
    }

    // Propagate exceptions to the next handler
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
            throws Exception {
        ctx.fireExceptionCaught(cause);
    }
}

Conclusion

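When building a private protocol stack on Netty, you need to consider many reliability features. HTTP, for example, looks simple on the surface, yet many mechanisms support it behind the scenes. A private stack like ours has to weigh performance and availability even more carefully. Open questions include whether messages should be dropped or resent when the link breaks, more complete codecs, timeout handling and custom scheduled tasks, and security authentication.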

The end.