用 Netty 观察生产者-消费者模式

130 阅读 · 阅读约 3 分钟

1.代理生成类

总:

Dispatcher

MyProxy

分:

package io.test.rpc;

import java.util.concurrent.ConcurrentHashMap;

/**
 * Process-wide registry of locally available service implementations.
 * Callers in this project register and look up entries by interface name
 * (e.g. {@code Car.class.getName()}).
 */
public class Dispatcher {

    /** Backing registry; static, so it is shared by every reference to the singleton. */
    private static ConcurrentHashMap<String, Object> invokeMap = new ConcurrentHashMap<>();

    /** The single instance, created eagerly at class-initialization time. */
    private static final Dispatcher dispatcher = new Dispatcher();

    private Dispatcher() {
    }

    /**
     * @return the shared Dispatcher instance
     */
    public static Dispatcher getInstance() {
        return dispatcher;
    }

    /**
     * Registers (or replaces) the implementation stored under {@code key}.
     *
     * @param key lookup name; must be non-null (ConcurrentHashMap rejects null)
     * @param obj implementation; must be non-null
     */
    public void register(String key, Object obj) {
        invokeMap.put(key, obj);
    }

    /**
     * @param key lookup name
     * @return the registered implementation, or {@code null} when none exists
     */
    public Object get(String key) {
        return invokeMap.get(key);
    }
}
package io.test.proxy;

import io.test.rpc.Dispatcher;
import io.test.rpc.protocol.MyContent;
import io.test.rpc.transport.ClientFactory;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.concurrent.CompletableFuture;

/**
 * Creates client-side proxies for service interfaces.
 * A call is executed locally when an implementation is registered in the Dispatcher under
 * the interface name; otherwise it is serialized and sent through ClientFactory.
 */
public class MyProxy {

    static Dispatcher dispatcher = Dispatcher.getInstance();

    /**
     * Returns a proxy implementing {@code interfaceInfo}.
     *
     * @param interfaceInfo the service interface to proxy
     * @param <T> the interface type
     * @return a proxy that dispatches calls locally or remotely on each invocation
     */
    @SuppressWarnings("unchecked")
    public static <T> T proxyGet(Class<T> interfaceInfo){
        ClassLoader classLoader = interfaceInfo.getClassLoader();
        String name = interfaceInfo.getName();
        Class<?>[] interfaces = {interfaceInfo};
        return (T) Proxy.newProxyInstance(classLoader, interfaces, new InvocationHandler() {
            @Override
            public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                // The local/remote decision is made per call, so a service registered later is picked up.
                Object local = dispatcher.get(name);
                if (null == local) {
                    return invokeRemote(name, method, args);
                }
                return invokeLocal(local, method, args);
            }
        });
    }

    /** Serializes the call, sends it via ClientFactory and blocks until the response arrives. */
    private static Object invokeRemote(String name, Method method, Object[] args) throws Exception {
        MyContent content = new MyContent();
        content.setArgs(args);
        content.setName(name);
        content.setMethodName(method.getName());
        content.setParameterTypes(method.getParameterTypes());
        CompletableFuture<?> transport = ClientFactory.transport(content);
        return transport.get(); // blocking
    }

    /**
     * Invokes the same-named method on a locally registered implementation.
     * Exceptions thrown by the target are rethrown unwrapped, instead of being printed and
     * turned into a {@code null} result (which breaks primitive-returning methods).
     */
    private static Object invokeLocal(Object target, Method method, Object[] args) throws Throwable {
        // Hook point for metrics/statistics plugins around local calls.
        Method m = target.getClass().getMethod(method.getName(), method.getParameterTypes());
        try {
            return m.invoke(target, args);
        } catch (java.lang.reflect.InvocationTargetException e) {
            throw e.getCause();
        }
    }
}

2.Client发送端

总:

ClientFactory

ClientPool

ClientResponses

ServerDecode

分:

package io.test.rpc.transport;

import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.test.rpc.ResponseMappingCallback;
import io.test.rpc.protocol.MyContent;
import io.test.rpc.protocol.MyHeader;
import io.test.utils.SerDerUtil;
import lombok.extern.slf4j.Slf4j;
import java.net.InetSocketAddress;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;

@Slf4j
public class ClientFactory {

    private ClientFactory() {}

    private static final ClientFactory clientFactory;

    static{
        clientFactory = new ClientFactory();
    }

    public static ClientFactory getInstance(){
        return clientFactory;
    }

    public static CompletableFuture transport(MyContent content){
        byte[] msgBody = SerDerUtil.ser(content);
        MyHeader myHeader = MyHeader.createMyHeader(msgBody);
        byte[] msgHeader = SerDerUtil.ser(myHeader);
        // 观察发送的header长度。设置Constant.headerSize
        log.info("header size : {}" , msgHeader.length);
        NioSocketChannel clientChannel = clientFactory.getClient(new InetSocketAddress("localhost", 9090));
        ByteBuf byteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(msgHeader.length + msgBody.length);
        long id = myHeader.getRequestID();
        CompletableFuture res = new CompletableFuture<>();
        ResponseMappingCallback.addCallBack(id, res);
        byteBuf.writeBytes(msgHeader);
        byteBuf.writeBytes(msgBody);
        clientChannel.writeAndFlush(byteBuf);
        return res;
    }

    ConcurrentHashMap<InetSocketAddress, ClientPool> boxes = new ConcurrentHashMap<>();

    Random rand = new Random();

    int poolSize = 5;

    private NioSocketChannel getClient(InetSocketAddress inetSocketAddress){
        ClientPool clientPool = boxes.get(inetSocketAddress);
        if (null == clientPool){
            synchronized (boxes){
                if (null == clientPool){
                    boxes.putIfAbsent(inetSocketAddress, new ClientPool(poolSize));
                    clientPool = boxes.get(inetSocketAddress);
                }
            }
        }
        int i = rand.nextInt(poolSize);
        if (null != clientPool.clients[i] && clientPool.clients[i].isActive()) {
            return clientPool.clients[i];
        } else {
            synchronized (clientPool.lock[i]){
                if (null == clientPool.clients[i] || !clientPool.clients[i].isActive()){
                    return clientPool.clients[i] = createClient(inetSocketAddress);
                }
            }
        }
        return clientPool.clients[i];
    }

    private NioEventLoopGroup clientWorker;

    public NioSocketChannel createClient(InetSocketAddress inetSocketAddress){
        clientWorker = new NioEventLoopGroup(1);
        Bootstrap bootstrap = new Bootstrap();
        ChannelFuture future = bootstrap
                .group(clientWorker)
                .channel(NioSocketChannel.class)
                .handler(new ChannelInitializer<NioSocketChannel>() {
                    @Override
                    protected void initChannel(NioSocketChannel nioSocketChannel) throws Exception {
                        ChannelPipeline pipeline = nioSocketChannel.pipeline();
                        pipeline.addLast(new ServerDecode());
                        pipeline.addLast(new ClientResponses());
                    }
                }).connect(inetSocketAddress);
        try {
            NioSocketChannel client = (NioSocketChannel)future.sync().channel();
            return client;
        } catch (InterruptedException e) {
            log.error("", e);
            return null;
        }
    }
}
package io.test.rpc.transport;

import io.netty.channel.socket.nio.NioSocketChannel;

/**
 * Fixed-size set of client connection slots for one server address.
 * Each slot has its own monitor object, which ClientFactory synchronizes on when it
 * (re)creates the connection held in that slot.
 */
public class ClientPool {
    /** Connection slots; a slot stays {@code null} until first used. */
    NioSocketChannel[] clients;
    /** One monitor per slot, same length as {@code clients}. */
    Object[] lock;

    ClientPool(int size) {
        this.clients = new NioSocketChannel[size];
        this.lock = new Object[size];
        int slot = 0;
        while (slot < size) {
            this.lock[slot] = new Object();
            slot++;
        }
    }
}
package io.test.rpc.transport;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.test.rpc.ResponseMappingCallback;
import io.test.utils.Packmsg;

/**
 * Last inbound handler of the client pipeline: completes the pending request future
 * that matches each decoded response.
 */
public class ClientResponses extends ChannelInboundHandlerAdapter {

    /**
     * @param msg expected to be a Packmsg emitted by ServerDecode earlier in the pipeline
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        ResponseMappingCallback.runCallBack((Packmsg) msg);
    }
}
package io.test.rpc.transport;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.test.rpc.protocol.MyContent;
import io.test.rpc.protocol.MyHeader;
import io.test.utils.Constant;
import io.test.utils.Packmsg;
import io.test.utils.SerDerUtil;
import java.util.List;

/**
 * Frame decoder used on both the client and server pipelines.
 * A frame is a Java-serialized MyHeader of exactly {@code Constant.headerSize} bytes,
 * followed by a Java-serialized MyContent body of {@code header.dataLen} bytes.
 */
public class ServerDecode extends ByteToMessageDecoder {

    /** Flag written by MyHeader.createMyHeader on requests. */
    private static final int REQUEST_FLAG = 0X14141414;
    /** Flag written by ServerRequestHandler on responses. */
    private static final int RESPONSE_FLAG = 0X14141424;

    /**
     * Emits one Packmsg per complete frame in {@code byteBuf}. It stops, leaving the bytes
     * unread, when the next frame has not fully arrived yet. Frames with any other flag are
     * consumed and dropped.
     */
    @Override
    protected void decode(ChannelHandlerContext channelHandlerContext, ByteBuf byteBuf, List<Object> list) throws Exception {
        while (byteBuf.readableBytes() >= Constant.headerSize) {
            // Peek at the header without moving readerIndex, so a partial frame can be retried.
            byte[] headerBytes = new byte[Constant.headerSize];
            byteBuf.getBytes(byteBuf.readerIndex(), headerBytes);
            MyHeader myHeader = (MyHeader) SerDerUtil.read(headerBytes);
            // The whole frame (header + body) must be present, not just dataLen bytes.
            if (byteBuf.readableBytes() < Constant.headerSize + myHeader.getDataLen()) {
                break;
            }
            // skipBytes, not readBytes(int): the latter allocates a ByteBuf that was never released.
            byteBuf.skipBytes(Constant.headerSize);
            byte[] data = new byte[(int) myHeader.getDataLen()];
            byteBuf.readBytes(data);
            Object obj = SerDerUtil.read(data);
            int flag = myHeader.getFlag();
            if (flag == REQUEST_FLAG || flag == RESPONSE_FLAG) {
                list.add(new Packmsg(myHeader, (MyContent) obj));
            }
        }
    }
}

3.Server服务端

总:

ServerRequestHandler

分:

package io.test.rpc.transport;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.test.rpc.Dispatcher;
import io.test.rpc.protocol.MyContent;
import io.test.rpc.protocol.MyHeader;
import io.test.utils.Packmsg;
import io.test.utils.SerDerUtil;
import java.lang.reflect.Method;

/**
 * Server-side handler: invokes the requested method on the implementation registered in
 * the Dispatcher and writes a response frame that carries the same request id.
 */
public class ServerRequestHandler extends ChannelInboundHandlerAdapter {

    /** Flag marking a response frame (ServerDecode accepts it on the client side). */
    private static final int RESPONSE_FLAG = 0X14141424;

    Dispatcher dispatcher;

    public ServerRequestHandler(Dispatcher dispatcher){
        this.dispatcher = dispatcher;
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        Packmsg packmsg = (Packmsg) msg;
        // Submit the work as a task to the handler's executor (ctx.executor()).
        ctx.executor().execute(() -> handle(ctx, packmsg));
    }

    /**
     * Always writes a response, even when the invocation fails. Otherwise the client's
     * future would never complete.
     */
    private void handle(ChannelHandlerContext ctx, Packmsg packmsg) {
        Object res = invoke(packmsg.getContent());
        MyContent resContent = new MyContent();
        resContent.setRes(res);
        byte[] contentByte = SerDerUtil.ser(resContent);
        MyHeader resHeader = new MyHeader();
        resHeader.setFlag(RESPONSE_FLAG);
        resHeader.setRequestID(packmsg.getHeader().getRequestID());
        resHeader.setDataLen(contentByte.length);
        byte[] headerByte = SerDerUtil.ser(resHeader);
        ByteBuf byteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(headerByte.length + contentByte.length);
        byteBuf.writeBytes(headerByte);
        byteBuf.writeBytes(contentByte);
        ctx.writeAndFlush(byteBuf);
    }

    /**
     * Reflectively calls the requested method.
     *
     * @return the method's result, or {@code null} if no service is registered under the
     *         requested name or the invocation fails (the failure is printed)
     */
    private Object invoke(MyContent content) {
        String serviceName = content.getName();
        Object target = dispatcher.get(serviceName);
        if (null == target) {
            // Previously an NPE here killed the task and no response was ever sent.
            System.err.println("no service registered under name: " + serviceName);
            return null;
        }
        try {
            Method m = target.getClass().getMethod(content.getMethodName(), content.getParameterTypes());
            return m.invoke(target, content.getArgs());
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }
}

4.Future异步

package io.test.rpc;

import io.test.utils.Packmsg;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Registry of in-flight requests: maps a request id to the future that the response for
 * that id completes.
 */
public class ResponseMappingCallback {

    /** Pending requests: request id -> future awaiting the response. */
    static final ConcurrentHashMap<Long, CompletableFuture> mapping = new ConcurrentHashMap<>();

    /**
     * Registers {@code future} for {@code requestID}. If the id is already registered, the
     * existing future is kept and {@code future} is ignored.
     */
    public static void addCallBack(long requestID, CompletableFuture future){
        mapping.putIfAbsent(requestID, future);
    }

    /**
     * Completes and unregisters the future matching the response's request id.
     * Responses for unknown ids (late, duplicate, or already removed) are ignored instead of
     * throwing a NullPointerException on the I/O thread.
     */
    @SuppressWarnings("unchecked")
    public static void runCallBack(Packmsg packmsg){
        long requestID = packmsg.getHeader().getRequestID();
        // remove() both looks up and unregisters atomically.
        CompletableFuture completableFuture = mapping.remove(requestID);
        if (null == completableFuture) {
            return;
        }
        completableFuture.complete(packmsg.getContent().getRes());
    }

    /** Unregisters {@code requestID} without completing its future. */
    public static void removeCallBack(long requestID){
        mapping.remove(requestID);
    }
}

5.协议

总:

MyContent

MyHeader

分:

package io.test.rpc.protocol;

import lombok.Data;
import lombok.ToString;
import java.io.Serializable;
import java.util.UUID;

/**
 * Frame header written, Java-serialized, in front of every serialized MyContent body.
 * NOTE: its serialized length must equal Constant.headerSize (ServerDecode reads exactly
 * that many bytes), so renaming or adding fields changes the wire format.
 *
 * @author jz
 */
@Data
@ToString
public class MyHeader implements Serializable {
    private int flag;
    private long requestID;
    private long dataLen;

    /**
     * Builds a request header for an already-serialized body.
     *
     * @param msg the serialized body; its length becomes {@code dataLen}
     * @return a header with the request flag 0x14141414, a random request id and dataLen set
     */
    public static MyHeader createMyHeader(byte[] msg){
        final int bodyLength = msg.length;
        final MyHeader header = new MyHeader();
        header.setFlag(0X14141414);
        // Random id from a fresh UUID. Math.abs leaves Long.MIN_VALUE negative, but any long is a valid map key.
        header.setRequestID(Math.abs(UUID.randomUUID().getLeastSignificantBits()));
        header.setDataLen(bodyLength);
        return header;
    }
}
package io.test.rpc.protocol;

import lombok.Data;
import lombok.ToString;
import java.io.Serializable;

/**
 * RPC message body, Java-serialized by SerDerUtil.
 * Requests fill name / methodName / parameterTypes / args (MyProxy);
 * responses fill only res (ServerRequestHandler).
 *
 * @author jz
 */
@Data
@ToString
public class MyContent implements Serializable {

    // Target service name; MyProxy sets it to the proxied interface's fully-qualified name.
    private String name;

    // Name of the method to invoke on the target service.
    private String methodName;

    // Declared parameter types, used to resolve the exact method overload.
    private Class<?>[] parameterTypes;

    // Call arguments; must be Serializable to cross the wire.
    private Object[] args;

    // Invocation result (response messages only).
    private Object res;
}

6.服务service

package io.test.service;

public interface Car {

    /**
     * Drives with the given message.
     *
     * @param msg the message to drive with
     * @return the implementation's result (MyCar prefixes the message)
     */
    String drive(String msg);
}
package io.test.service;

public class MyCar implements Car{

    /**
     * Returns {@code msg} prefixed with "开".
     * A {@code null} message yields the prefix followed by "null".
     */
    @Override
    public String drive(String msg) {
        return new StringBuilder("开").append(msg).toString();
    }
}
package io.test.service;

public interface Fly {
    // Service operation with no return value; the MyFly implementation prints the message to stdout.
    void fly(String msg);
}
package io.test.service;

public class MyFly implements Fly{

    /** Prints "草丛飞!" followed by {@code msg} to standard output. */
    @Override
    public void fly(String msg) {
        String line = "草丛飞!" + msg;
        System.out.println(line);
    }
}

7.工具类

package io.test.utils;

/** Protocol constants shared by the client and server transports. Not instantiable. */
public class Constant {

    private Constant() {
    }

    /**
     * Length in bytes of a Java-serialized MyHeader, i.e. the fixed frame-header size that
     * ServerDecode reads before each body. ClientFactory.transport logs the actual
     * "header size" so this value can be checked; update it if MyHeader changes.
     */
    public static final int headerSize = 99;

}
package io.test.utils;

import io.test.rpc.protocol.MyContent;
import io.test.rpc.protocol.MyHeader;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;

/**
 * One decoded frame: the header plus its deserialized body.
 * Produced by ServerDecode and consumed by ClientResponses (client side) and
 * ServerRequestHandler (server side).
 *
 * @author jz
 */
@AllArgsConstructor
@NoArgsConstructor
@Data
@ToString
public class Packmsg {

    // Frame header (flag, request id, body length).
    private MyHeader header;

    // Deserialized body: a request or a response.
    private MyContent content;
}
package io.test.utils;

import java.io.*;

/** Java-serialization helpers used for both frame headers and bodies. */
public class SerDerUtil {

    /**
     * Serializes {@code msg} with Java serialization.
     * A fresh buffer is used per call, so no shared state or global lock is needed.
     *
     * @param msg object to serialize; must be Serializable ({@code null} is allowed)
     * @return the serialized bytes (never {@code null})
     * @throws UncheckedIOException if serialization fails, e.g. NotSerializableException;
     *         previously the error was printed and {@code null} returned, and callers then
     *         failed with a NullPointerException on {@code .length}
     */
    public static byte[] ser(Object msg){
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (ObjectOutputStream oOut = new ObjectOutputStream(buffer)) {
            oOut.writeObject(msg);
        } catch (IOException e) {
            String type = (msg == null) ? "null" : msg.getClass().getName();
            throw new UncheckedIOException("failed to serialize " + type, e);
        }
        return buffer.toByteArray();
    }

    /**
     * Deserializes one object from {@code data}.
     * SECURITY: Java native deserialization of bytes received from the network can run
     * arbitrary gadget chains. Only use it with trusted peers, or install an ObjectInputFilter.
     *
     * @param data bytes produced by {@link #ser(Object)}
     * @return the deserialized object
     */
    public static Object read(byte[] data) throws IOException, ClassNotFoundException {
        try (ObjectInputStream objectInputStream = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return objectInputStream.readObject();
        }
    }
}

8.最终执行

package io.test;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.test.proxy.MyProxy;
import io.test.rpc.Dispatcher;
import io.test.rpc.transport.ServerDecode;
import io.test.rpc.transport.ServerRequestHandler;
import io.test.service.Car;
import io.test.service.Fly;
import io.test.service.MyCar;
import io.test.service.MyFly;
import org.junit.Test;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.concurrent.atomic.AtomicInteger;

public class RpcTest {

    /**
     * Starts an RPC server on localhost:9090 exposing MyCar and MyFly (registered under their
     * interface names) and blocks until the server channel closes. The event loop group is
     * always shut down on exit.
     */
    @Test
    public void startServer(){
        MyCar car = new MyCar();
        MyFly fly = new MyFly();
        Dispatcher dis = Dispatcher.getInstance();
        dis.register(Car.class.getName(),car);
        dis.register(Fly.class.getName(),fly);
        // One single-threaded group acts as both boss (accept) and worker (I/O).
        NioEventLoopGroup boss = new NioEventLoopGroup(1);
        NioEventLoopGroup worker =  boss;
        ServerBootstrap sbs = new ServerBootstrap();
        try {
            ChannelFuture bind = sbs.group(boss, worker)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<NioSocketChannel>() {
                        @Override
                        protected void initChannel(NioSocketChannel ch) throws Exception {
                            System.out.println("server accept cliet port: "+ ch.remoteAddress().getPort());
                            ChannelPipeline p = ch.pipeline();
                            p.addLast(new ServerDecode());
                            p.addLast(new ServerRequestHandler(dis));
                        }
                    }).bind(new InetSocketAddress("localhost", 9090));
            bind.sync().channel().closeFuture().sync();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } finally {
            boss.shutdownGracefully();
        }
    }

    /**
     * Fires 50 concurrent remote Car.drive calls through the proxy and waits for all of them
     * to finish. Requires startServer to be running.
     */
    @Test
    public void get(){
        AtomicInteger num = new AtomicInteger(0);
        int size = 50;
        Thread[] threads = new Thread[size];
        for (int i = 0; i <size; i++) {
            threads[i] = new Thread(()->{
                Car car = MyProxy.proxyGet(Car.class);// dynamic proxy
                String arg = "宝马X" + num.incrementAndGet();
                String res = car.drive(arg);
                System.out.println("client over msg: " + res+" src arg: "+ arg);
            });
        }
        for (Thread thread : threads) {
            thread.start();
        }
        // Join the callers rather than blocking on System.in, which depends on an attached console.
        for (Thread thread : threads) {
            try {
                thread.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
        }
    }
}