1.代理生成类
总:
Dispatcher
MyProxy
分:
package io.test.rpc;
import java.util.concurrent.ConcurrentHashMap;
public class Dispatcher {

    /** The single, eagerly created instance; the private constructor prevents any other. */
    private static Dispatcher dispatcher = new Dispatcher();

    /**
     * Registered service implementations keyed by name. Callers in this project use the
     * service interface's fully-qualified class name as the key.
     */
    private static ConcurrentHashMap<String, Object> invokeMap = new ConcurrentHashMap<>();

    private Dispatcher() {
    }

    /**
     * Returns the process-wide dispatcher.
     *
     * @return the singleton instance
     */
    public static Dispatcher getInstance() {
        return dispatcher;
    }

    /**
     * Registers (or replaces) the implementation stored under {@code key}.
     *
     * @param key lookup key, e.g. an interface's fully-qualified name
     * @param obj implementation to associate with {@code key}
     */
    public void register(String key, Object obj) {
        invokeMap.put(key, obj);
    }

    /**
     * Looks up a registered implementation.
     *
     * @param key lookup key used at registration
     * @return the registered object, or {@code null} when nothing is registered under {@code key}
     */
    public Object get(String key) {
        return invokeMap.get(key);
    }
}
package io.test.proxy;
import io.test.rpc.Dispatcher;
import io.test.rpc.protocol.MyContent;
import io.test.rpc.transport.ClientFactory;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.concurrent.CompletableFuture;
public class MyProxy {

    /** Local service registry; a registered implementation short-circuits the network call. */
    static Dispatcher dispatcher = Dispatcher.getInstance();

    /**
     * Creates a dynamic proxy for {@code interfaceInfo}.
     *
     * <p>Every call looks up the interface's fully-qualified name in the Dispatcher. If an
     * implementation is registered in this JVM it is invoked directly. Otherwise the call is
     * packaged into a MyContent, sent via ClientFactory.transport, and the calling thread blocks
     * until the response arrives.
     *
     * <p>Exceptions thrown by a local implementation are rethrown unwrapped. Reflective lookup
     * failures are no longer swallowed; a checked one not declared by the interface reaches the
     * caller as an UndeclaredThrowableException.
     *
     * @param interfaceInfo the service interface to proxy
     * @param <T> the interface type
     * @return a proxy implementing {@code interfaceInfo}
     */
    @SuppressWarnings("unchecked") // Proxy.newProxyInstance implements exactly interfaceInfo
    public static <T> T proxyGet(Class<T> interfaceInfo) {
        ClassLoader classLoader = interfaceInfo.getClassLoader();
        String name = interfaceInfo.getName();
        Class<?>[] interfaces = {interfaceInfo};
        return (T) Proxy.newProxyInstance(classLoader, interfaces, new InvocationHandler() {
            @Override
            public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                Object local = dispatcher.get(name);
                if (local == null) {
                    return invokeRemote(name, method, args);
                }
                return invokeLocal(local, method, args);
            }
        });
    }

    /**
     * Sends the call to the server and blocks until its result is available.
     *
     * @param name   service key (the interface's fully-qualified name)
     * @param method the interface method being called
     * @param args   call arguments, may be {@code null} for no-arg methods
     * @return the remote result
     * @throws Exception if waiting is interrupted or the transport future fails
     */
    private static Object invokeRemote(String name, Method method, Object[] args) throws Exception {
        MyContent content = new MyContent();
        content.setArgs(args);
        content.setName(name);
        content.setMethodName(method.getName());
        content.setParameterTypes(method.getParameterTypes());
        CompletableFuture<?> transport = ClientFactory.transport(content);
        // Blocking: waits until the response with the same request ID completes the future.
        return transport.get();
    }

    /**
     * Invokes the locally registered implementation (local call) and propagates what it throws.
     *
     * @param target the registered implementation
     * @param method the interface method being called
     * @param args   call arguments
     * @return the implementation's return value
     * @throws Throwable the exception thrown by the implementation itself, or a reflective
     *                   lookup/access failure
     */
    private static Object invokeLocal(Object target, Method method, Object[] args) throws Throwable {
        // Hook point for metrics/statistics plugins.
        Method m = target.getClass().getMethod(method.getName(), method.getParameterTypes());
        try {
            return m.invoke(target, args);
        } catch (java.lang.reflect.InvocationTargetException e) {
            // Surface the implementation's real exception instead of the reflective wrapper.
            throw e.getCause();
        }
    }
}
2.Client发送端
总:
ClientFactory
ClientPool
ClientResponses
ServerDecode
分:
package io.test.rpc.transport;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.test.rpc.ResponseMappingCallback;
import io.test.rpc.protocol.MyContent;
import io.test.rpc.protocol.MyHeader;
import io.test.utils.SerDerUtil;
import lombok.extern.slf4j.Slf4j;
import java.net.InetSocketAddress;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
@Slf4j
public class ClientFactory {

    private ClientFactory() {}

    private static final ClientFactory clientFactory;

    static {
        clientFactory = new ClientFactory();
    }

    /**
     * Returns the process-wide factory instance.
     *
     * @return the singleton factory
     */
    public static ClientFactory getInstance() {
        return clientFactory;
    }

    /**
     * Sends one RPC request and returns a future for its result.
     *
     * <p>{@code content} and a freshly created MyHeader are Java-serialized and written as header
     * bytes followed by body bytes on a pooled connection to localhost:9090. The future is registered
     * with ResponseMappingCallback under the header's request ID and completed when the matching
     * response is decoded. It completes exceptionally if no connection could be obtained or the
     * write fails, so callers blocked on {@code get()} are released.
     *
     * @param content request body (service name, method name, parameter types, arguments)
     * @return future completed with the remote method's return value
     */
    public static CompletableFuture<Object> transport(MyContent content) {
        byte[] msgBody = SerDerUtil.ser(content);
        MyHeader myHeader = MyHeader.createMyHeader(msgBody);
        byte[] msgHeader = SerDerUtil.ser(myHeader);
        // Observe the serialized header length; Constant.headerSize must be set to this value.
        log.info("header size : {}", msgHeader.length);

        CompletableFuture<Object> res = new CompletableFuture<>();
        NioSocketChannel clientChannel = clientFactory.getClient(new InetSocketAddress("localhost", 9090));
        if (clientChannel == null) {
            // createClient returns null when interrupted; fail the future rather than NPE after
            // registering the callback (which would leave the mapping entry behind forever).
            res.completeExceptionally(new IllegalStateException("no connection available to localhost:9090"));
            return res;
        }

        long id = myHeader.getRequestID();
        // Register before writing so even a very fast response finds its future.
        ResponseMappingCallback.addCallBack(id, res);

        ByteBuf byteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(msgHeader.length + msgBody.length);
        byteBuf.writeBytes(msgHeader);
        byteBuf.writeBytes(msgBody);
        clientChannel.writeAndFlush(byteBuf).addListener(f -> {
            if (!f.isSuccess()) {
                // No response will ever arrive for this request; without this the caller hangs.
                ResponseMappingCallback.removeCallBack(id);
                res.completeExceptionally(f.cause());
            }
        });
        return res;
    }

    /** One connection pool per remote address. */
    ConcurrentHashMap<InetSocketAddress, ClientPool> boxes = new ConcurrentHashMap<>();
    Random rand = new Random();
    /** Number of connections per pool; only affects pools created after it changes. */
    int poolSize = 5;

    /**
     * Picks a random slot of the address's pool, (re)connecting that slot if it is empty or inactive.
     *
     * @param inetSocketAddress remote address
     * @return an active channel, or {@code null} if connecting was interrupted
     */
    private NioSocketChannel getClient(InetSocketAddress inetSocketAddress) {
        // computeIfAbsent creates at most one pool per address without an explicit lock.
        ClientPool clientPool = boxes.computeIfAbsent(inetSocketAddress, addr -> new ClientPool(poolSize));
        // Use the pool's real size: poolSize may have changed since this pool was created.
        int i = rand.nextInt(clientPool.clients.length);
        NioSocketChannel client = clientPool.clients[i];
        // NOTE(review): this fast-path read of the plain array is not synchronized; publication
        // relies on the slot lock below and Netty's own internal synchronization.
        if (client != null && client.isActive()) {
            return client;
        }
        synchronized (clientPool.lock[i]) {
            client = clientPool.clients[i];
            if (client == null || !client.isActive()) {
                client = createClient(inetSocketAddress);
                clientPool.clients[i] = client;
            }
            return client;
        }
    }

    /** Event loop group shared by every client connection; created on first use. */
    private NioEventLoopGroup clientWorker;

    /**
     * Lazily creates the shared client event loop group.
     *
     * <p>It is sized to poolSize so one address's pool can still spread its connections over
     * separate loop threads. Sharing one group means reconnects no longer leak a new group each time.
     *
     * @return the shared group
     */
    private synchronized NioEventLoopGroup workerGroup() {
        if (clientWorker == null) {
            clientWorker = new NioEventLoopGroup(poolSize);
        }
        return clientWorker;
    }

    /**
     * Opens a new connection whose pipeline decodes frames (ServerDecode) and routes responses to
     * their waiting futures (ClientResponses).
     *
     * @param inetSocketAddress remote address
     * @return the connected channel, or {@code null} if the calling thread was interrupted
     */
    public NioSocketChannel createClient(InetSocketAddress inetSocketAddress) {
        Bootstrap bootstrap = new Bootstrap();
        ChannelFuture future = bootstrap
                .group(workerGroup())
                .channel(NioSocketChannel.class)
                .handler(new ChannelInitializer<NioSocketChannel>() {
                    @Override
                    protected void initChannel(NioSocketChannel nioSocketChannel) throws Exception {
                        ChannelPipeline pipeline = nioSocketChannel.pipeline();
                        pipeline.addLast(new ServerDecode());
                        pipeline.addLast(new ClientResponses());
                    }
                }).connect(inetSocketAddress);
        try {
            return (NioSocketChannel) future.sync().channel();
        } catch (InterruptedException e) {
            // Preserve the interrupt for callers and do not leave a half-open connection behind.
            Thread.currentThread().interrupt();
            future.channel().close();
            log.error("", e);
            return null;
        }
    }
}
package io.test.rpc.transport;
import io.netty.channel.socket.nio.NioSocketChannel;
public class ClientPool {

    /** Connection slots; each starts empty and is filled on demand by the owning factory. */
    NioSocketChannel[] clients;

    /** One monitor per slot, so different slots can be (re)connected concurrently. */
    Object[] lock;

    /**
     * Creates a pool with {@code size} empty slots and one lock object per slot.
     *
     * @param size number of slots
     */
    ClientPool(int size) {
        clients = new NioSocketChannel[size];
        lock = new Object[size];
        int idx = 0;
        while (idx < size) {
            lock[idx] = new Object();
            idx++;
        }
    }
}
package io.test.rpc.transport;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.test.rpc.ResponseMappingCallback;
import io.test.utils.Packmsg;
public class ClientResponses extends ChannelInboundHandlerAdapter {

    /**
     * Forwards each decoded response frame to the request-ID → future registry, which completes the
     * future of the call that is waiting for it.
     *
     * @param ctx handler context (unused)
     * @param msg a Packmsg produced by the decoder earlier in the pipeline
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        ResponseMappingCallback.runCallBack((Packmsg) msg);
    }
}
package io.test.rpc.transport;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.test.rpc.protocol.MyContent;
import io.test.rpc.protocol.MyHeader;
import io.test.utils.Constant;
import io.test.utils.Packmsg;
import io.test.utils.SerDerUtil;
import java.util.List;
public class ServerDecode extends ByteToMessageDecoder {

    /**
     * Decodes as many complete frames as the buffer currently holds.
     *
     * <p>A frame is a Java-serialized MyHeader of exactly Constant.headerSize bytes, followed by a
     * Java-serialized MyContent body of header.dataLen bytes. The decoder is used on both sides:
     * flag 0x14141414 marks requests and 0x14141424 marks responses. Frames with any other flag are
     * consumed and dropped. Incomplete frames are left in the buffer until more bytes arrive.
     *
     * @param channelHandlerContext handler context (unused)
     * @param byteBuf               accumulated inbound bytes
     * @param list                  receives one Packmsg per complete frame
     */
    @Override
    protected void decode(ChannelHandlerContext channelHandlerContext, ByteBuf byteBuf, List<Object> list) throws Exception {
        while (byteBuf.readableBytes() >= Constant.headerSize) {
            byte[] headerBytes = new byte[Constant.headerSize];
            // Peek at the header without moving readerIndex, so a partial frame stays buffered.
            byteBuf.getBytes(byteBuf.readerIndex(), headerBytes);
            MyHeader myHeader = (MyHeader) SerDerUtil.read(headerBytes);

            // The whole frame (header + body) must be present, not just dataLen bytes.
            long frameLen = Constant.headerSize + myHeader.getDataLen();
            if (byteBuf.readableBytes() < frameLen) {
                break;
            }

            // skipBytes rather than readBytes(int): the latter allocates a new ByteBuf that was never released.
            byteBuf.skipBytes(Constant.headerSize);
            // NOTE(review): dataLen comes from the peer and is trusted here; consider bounding it.
            byte[] data = new byte[(int) myHeader.getDataLen()];
            byteBuf.readBytes(data);
            Object obj = SerDerUtil.read(data);

            int flag = myHeader.getFlag();
            if (flag == 0X14141414 || flag == 0X14141424) {
                list.add(new Packmsg(myHeader, (MyContent) obj));
            }
        }
    }
}
3.Server服务端
总:
ServerRequestHandler
分:
package io.test.rpc.transport;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.test.rpc.Dispatcher;
import io.test.rpc.protocol.MyContent;
import io.test.rpc.protocol.MyHeader;
import io.test.utils.Packmsg;
import io.test.utils.SerDerUtil;
import java.lang.reflect.Method;
public class ServerRequestHandler extends ChannelInboundHandlerAdapter {

    /** Registry used to find the service implementation named in each request. */
    Dispatcher dispatcher;

    public ServerRequestHandler(Dispatcher dispatcher) {
        this.dispatcher = dispatcher;
    }

    /**
     * Handles one decoded request and always writes exactly one response frame back.
     *
     * <p>The response carries the same request ID and flag 0x14141424, and its body's res is the
     * invocation result. The wire format has no error field, so a missing service or a failed
     * invocation is reported as a {@code null} result rather than leaving the client waiting forever.
     *
     * @param ctx channel context the response is written to
     * @param msg a Packmsg produced by the decoder earlier in the pipeline
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        Packmsg packmsg = (Packmsg) msg;
        // NOTE(review): the handler is added without its own executor group, so ctx.executor() is the
        // channel's event loop; the call is deferred but still runs on the I/O thread.
        ctx.executor().execute(() -> {
            Object res = invokeService(packmsg.getContent());

            MyContent resContent = new MyContent();
            resContent.setRes(res);
            byte[] contentByte = SerDerUtil.ser(resContent);

            MyHeader resHeader = new MyHeader();
            resHeader.setFlag(0X14141424);
            resHeader.setRequestID(packmsg.getHeader().getRequestID());
            resHeader.setDataLen(contentByte.length);
            byte[] headerByte = SerDerUtil.ser(resHeader);

            ByteBuf byteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(headerByte.length + contentByte.length);
            byteBuf.writeBytes(headerByte);
            byteBuf.writeBytes(contentByte);
            ctx.writeAndFlush(byteBuf);
        });
    }

    /**
     * Resolves the requested service and method reflectively and invokes it.
     *
     * @param content request body
     * @return the method's return value, or {@code null} if the service is unknown or invocation fails
     */
    private Object invokeService(MyContent content) {
        String serviceName = content.getName();
        Object service = dispatcher.get(serviceName);
        if (service == null) {
            // Previously an NPE here killed the task, no reply was sent, and the caller blocked forever.
            System.err.println("no service registered for: " + serviceName);
            return null;
        }
        try {
            Method m = service.getClass().getMethod(content.getMethodName(), content.getParameterTypes());
            return m.invoke(service, content.getArgs());
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }
}
4.Future异步
package io.test.rpc;
import io.test.utils.Packmsg;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
public class ResponseMappingCallback {

    /** Futures of in-flight requests, keyed by request ID. */
    static ConcurrentHashMap<Long, CompletableFuture> mapping = new ConcurrentHashMap<>();

    /**
     * Registers the future to complete when the response for {@code requestID} arrives.
     * An existing registration for the same ID is kept.
     *
     * @param requestID request ID from the request header
     * @param future    future to complete with the response result
     */
    public static void addCallBack(long requestID, CompletableFuture future) {
        mapping.putIfAbsent(requestID, future);
    }

    /**
     * Completes and unregisters the future waiting for this response.
     *
     * <p>remove() looks up and unregisters in one atomic step, so a duplicate or late response can
     * neither complete a future twice nor throw an NPE for an unknown ID.
     *
     * @param packmsg decoded response frame
     */
    @SuppressWarnings("unchecked") // futures are registered raw; the result is an arbitrary Object
    public static void runCallBack(Packmsg packmsg) {
        long requestID = packmsg.getHeader().getRequestID();
        CompletableFuture completableFuture = mapping.remove(requestID);
        if (completableFuture == null) {
            System.err.println("no pending request for response id: " + requestID);
            return;
        }
        completableFuture.complete(packmsg.getContent().getRes());
    }

    /**
     * Unregisters the future for {@code requestID}, if any, without completing it.
     *
     * @param requestID request ID to forget
     */
    public static void removeCallBack(long requestID) {
        mapping.remove(requestID);
    }
}
5.协议
总:
MyContent
MyHeader
分:
package io.test.rpc.protocol;
import lombok.Data;
import lombok.ToString;
import java.io.Serializable;
import java.util.UUID;
/**
 * Fixed-layout frame header: a flag, a request ID, and the body length in bytes.
 * Serialized with Java serialization; its serialized size must equal Constant.headerSize.
 *
 * @author jz
 */
@Data
@ToString
public class MyHeader implements Serializable {
    private int flag;
    private long requestID;
    private long dataLen;

    /**
     * Builds a request header for a serialized body.
     *
     * <p>The flag is 0x14141414 (the server answers with 0x14141424). The request ID is derived from
     * a random UUID. NOTE(review): Math.abs(Long.MIN_VALUE) is still negative, so the ID is not
     * strictly guaranteed non-negative.
     *
     * @param msg serialized body; only its length is used
     * @return a new header
     */
    public static MyHeader createMyHeader(byte[] msg) {
        MyHeader header = new MyHeader();
        header.setDataLen(msg.length);
        header.setFlag(0X14141414);
        header.setRequestID(Math.abs(UUID.randomUUID().getLeastSignificantBits()));
        return header;
    }
}
package io.test.rpc.protocol;
import lombok.Data;
import lombok.ToString;
import java.io.Serializable;
/**
* RPC message body, Java-serialized by SerDerUtil and sent right after a MyHeader.
* Requests carry name / methodName / parameterTypes / args (filled by MyProxy);
* responses carry only res (filled by ServerRequestHandler).
* Lombok generates the accessors, equals/hashCode and toString.
*
* @author jz
*/
@Data
@ToString
public class MyContent implements Serializable {
// Service key looked up in the Dispatcher; MyProxy uses the interface's fully-qualified name.
private String name;
// Name of the method to invoke on the service.
private String methodName;
// Declared parameter types; together with methodName they select the method via getMethod.
private Class<?>[] parameterTypes;
// Call arguments. NOTE(review): presumably each must be Serializable to cross the wire.
private Object[] args;
// Return value of the remote call; null when the server-side invocation failed.
private Object res;
}
6.服务service
package io.test.service;
public interface Car {

    /**
     * Performs the "drive" operation.
     *
     * @param msg caller-supplied text
     * @return the implementation's result string
     */
    String drive(String msg);
}
package io.test.service;
public class MyCar implements Car {

    /**
     * Returns {@code msg} prefixed with "开" ("drive").
     *
     * @param msg what to drive
     * @return the prefixed string
     */
    @Override
    public String drive(String msg) {
        String prefix = "开";
        return prefix + msg;
    }
}
package io.test.service;
public interface Fly {

    /**
     * Performs the "fly" action; nothing is returned.
     *
     * @param msg caller-supplied text
     */
    void fly(String msg);
}
package io.test.service;
public class MyFly implements Fly {

    /**
     * Prints "草丛飞!" followed by {@code msg} to standard output.
     *
     * @param msg text appended to the printed line
     */
    @Override
    public void fly(String msg) {
        String line = "草丛飞!" + msg;
        System.out.println(line);
    }
}
7.工具类
package io.test.utils;
/**
 * Protocol constants shared by the frame encoder (ClientFactory / ServerRequestHandler) and decoder.
 * Holds constants only, so it is neither instantiable nor extensible.
 */
public final class Constant {

    private Constant() {
        // constants holder; not instantiable
    }

    /**
     * Length in bytes of a Java-serialized MyHeader.
     *
     * <p>The decoder reads exactly this many bytes as the header, so it must equal the
     * "header size : ..." value logged by ClientFactory.transport. Java serialization embeds class and
     * field names, so re-check it whenever MyHeader is renamed, moved, or changes fields.
     */
    public static final int headerSize = 99;
}
package io.test.utils;
import io.test.rpc.protocol.MyContent;
import io.test.rpc.protocol.MyHeader;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
/**
* One decoded frame: a MyHeader plus its deserialized MyContent body.
* Produced by ServerDecode; consumed by ClientResponses (responses) and
* ServerRequestHandler (requests).
* Field order matters: @AllArgsConstructor generates Packmsg(header, content).
*
* @author jz
*/
@AllArgsConstructor
@NoArgsConstructor
@Data
@ToString
public class Packmsg {
// Frame header (flag, request ID, body length).
private MyHeader header;
// Deserialized frame body.
private MyContent content;
}
package io.test.utils;
import java.io.*;
/**
 * Java-serialization helpers used for both frame headers and bodies.
 */
public class SerDerUtil {

    private SerDerUtil() {
        // static helpers only
    }

    /**
     * Serializes {@code msg} with Java serialization.
     *
     * <p>Each call uses its own stream, so no global lock is needed. The output is byte-for-byte what a
     * fresh ObjectOutputStream produces, which keeps the serialized header size stable.
     *
     * @param msg object to serialize; must be Serializable (may be {@code null})
     * @return the serialized bytes, never {@code null}
     * @throws UncheckedIOException if serialization fails, e.g. a non-serializable value
     */
    public static byte[] ser(Object msg) {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (ObjectOutputStream oOut = new ObjectOutputStream(buffer)) {
            oOut.writeObject(msg);
        } catch (IOException e) {
            // Previously this printed the trace and returned null, which only failed later with an NPE.
            String type = msg == null ? "null" : msg.getClass().getName();
            throw new UncheckedIOException("failed to serialize " + type, e);
        }
        return buffer.toByteArray();
    }

    /**
     * Deserializes one object written by {@link #ser(Object)}.
     *
     * <p>SECURITY NOTE(review): callers in this project pass bytes received from the network.
     * Native Java deserialization of untrusted input can execute gadget chains; restrict accepted
     * classes (e.g. an ObjectInputFilter on Java 9+) or switch formats before exposing this beyond
     * trusted peers.
     *
     * @param data serialized bytes
     * @return the deserialized object
     * @throws IOException            if the bytes are not a valid serialization stream
     * @throws ClassNotFoundException if a serialized class is not on the classpath
     */
    public static Object read(byte[] data) throws IOException, ClassNotFoundException {
        try (ObjectInputStream objectInputStream = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return objectInputStream.readObject();
        }
    }
}
8.最终执行
package io.test;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.test.proxy.MyProxy;
import io.test.rpc.Dispatcher;
import io.test.rpc.transport.ServerDecode;
import io.test.rpc.transport.ServerRequestHandler;
import io.test.service.Car;
import io.test.service.Fly;
import io.test.service.MyCar;
import io.test.service.MyFly;
import org.junit.Test;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.concurrent.atomic.AtomicInteger;
public class RpcTest {

    /**
     * Starts the RPC server.
     *
     * <p>Registers MyCar and MyFly under their interface names, binds a Netty server on
     * localhost:9090, and blocks until the server channel closes. The event-loop group is shut down
     * on exit, including when binding fails or the wait is interrupted.
     */
    @Test
    public void startServer() {
        MyCar car = new MyCar();
        MyFly fly = new MyFly();
        Dispatcher dis = Dispatcher.getInstance();
        dis.register(Car.class.getName(), car);
        dis.register(Fly.class.getName(), fly);
        // One single-threaded group serves as both the accept (boss) and I/O (worker) loop.
        NioEventLoopGroup boss = new NioEventLoopGroup(1);
        NioEventLoopGroup worker = boss;
        ServerBootstrap sbs = new ServerBootstrap();
        try {
            ChannelFuture bind = sbs.group(boss, worker)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<NioSocketChannel>() {
                        @Override
                        protected void initChannel(NioSocketChannel ch) throws Exception {
                            System.out.println("server accept cliet port: " + ch.remoteAddress().getPort());
                            ChannelPipeline p = ch.pipeline();
                            p.addLast(new ServerDecode());
                            p.addLast(new ServerRequestHandler(dis));
                        }
                    }).bind(new InetSocketAddress("localhost", 9090));
            bind.sync().channel().closeFuture().sync();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } finally {
            boss.shutdownGracefully();
        }
    }

    /**
     * Client load test: 50 threads each call Car.drive through the dynamic proxy, print the result,
     * and the test waits for all of them to finish.
     */
    @Test
    public void get() {
        AtomicInteger num = new AtomicInteger(0);
        int size = 50;
        Thread[] threads = new Thread[size];
        for (int i = 0; i < size; i++) {
            threads[i] = new Thread(() -> {
                Car car = MyProxy.proxyGet(Car.class); // implemented by a dynamic proxy
                String arg = "宝马X" + num.incrementAndGet();
                String res = car.drive(arg);
                System.out.println("client over msg: " + res + " src arg: " + arg);
            });
        }
        for (Thread thread : threads) {
            thread.start();
        }
        // Join instead of waiting on System.in: stdin may be non-interactive under a test runner
        // (returning immediately), and a key press says nothing about whether the calls completed.
        try {
            for (Thread thread : threads) {
                thread.join();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
        }
    }
}