基于Netty实现的简单RPC框架
背景:
服务提供方有两个类:MyCar和MyFly,MyCar有一个ooxx方法,MyFly有一个xxoo方法。调用方要调用MyCar的ooxx方法,并且拿到方法执行的结果。通过Netty实现一个简单的RPC框架来实现这个功能。
服务提供方
MyCar和MyFly的定义
public interface Car {
String ooxx(String msg);
}
public interface Fly {
void xxoo(String msg);
}
// RPC要模拟调用的服务
public class MyCar implements Car {
@Override
public String ooxx(String msg) {
System.out.println("server,get client arg:" + msg);
return "server res " + msg;
}
}
public class MyFly implements Fly {
@Override
public void xxoo(String msg) {
System.out.println("server,get client arg:" + msg);
}
}
服务注册中心
服务注册中心是服务提供方将自己的服务集中存放的地方。当调用方的请求来了,服务提供方可以从注册中心拿到对应的服务进行调用。
// Dispatcher类对象作为服务注册中心
public class Dispatcher {
// 服务放到map中
public static ConcurrentHashMap<String, Object> invokeMap = new ConcurrentHashMap<>();
public void register(String k, Object obj) {
invokeMap.put(k, obj);
}
public Object get(String k) {
return invokeMap.get(k);
}
}
服务提供方
public class Server {
public static void main(String[] args) {
// MyCar和MyFly作为两个服务,用于RPC的调用
MyCar car = new MyCar();
MyFly fly = new MyFly();
// Dispatcher作为服务的注册中心
Dispatcher dis = new Dispatcher();
// 将MyCar和MyFly对象注册到Dispatcher中
// RPC调用时,可以从Dispatcher中获取
dis.register(Car.class.getName(), car);
dis.register(Fly.class.getName(), fly);
// 初始化Netty服务端的NioEventLoopGroup,并且boss和worker用同一个NioEventLoopGroup
// NioEventLoopGroup的每一个线程绑定一个selector多路复用器
// 所以每个线程都会做EventLoop中的三件事
// 第一件:调用select方法,拿到准备就绪的fd集合
// 第二件:遍历集合,处理io事件
// 第三件:遍历任务集合,看是否有新的事件需要注册到多路复用器
NioEventLoopGroup boss = new NioEventLoopGroup(20);
NioEventLoopGroup worker = boss;
// 初始化Netty的服务端
ServerBootstrap sbs = new ServerBootstrap();
ChannelFuture bind = sbs.group(boss, worker)
.channel(NioServerSocketChannel.class)
// accept事件已经被Netty默默处理了
// 此处只需要将read事件触发后,需要执行的逻辑写好
.childHandler(new ChannelInitializer<NioSocketChannel>() {
@Override
protected void initChannel(NioSocketChannel ch) throws Exception {
System.out.println("server accept cliet port: " + ch.remoteAddress().getPort());
ChannelPipeline p = ch.pipeline();
// 在ChannelPipeline中放入两个对象
// read事件触发后,先调用ServerDecode对象的decode方法,解析数据
// 然后调用ServerRequestHandler对象的channelRead方法
p.addLast(new ServerDecode());
p.addLast(new ServerRequestHandler(dis));
}
}).bind(new InetSocketAddress("localhost", 9090));
try {
bind.sync().channel().closeFuture().sync();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
解码器
/**
* 解码器
* 内核为socket两端提供了recv_queue缓存区,数据会先从网卡读到缓存区中
* 当数据到达recv_queue中后,read事件触发,Netty中的EventLoop中的线程会从recv_queue中拿数据到ByteBuf中
* 但是,当线程从recv_queue中拿数据到ByteBuf中之前,可能recv_queue中已经存放了不止一条数据了,内核是无法保证recv_queue中数据的完整性的
* 所以EventLoop中线程拿到数据放到ByteBuf中后,需要我们对ByteBuf中数据进行切分,切分出一条条完整的数据
* 因此,Netty提供了解码器,实现这一功能
*/
public class ServerDecode extends ByteToMessageDecoder {
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf buf, List<Object> out) throws Exception {
// 写代码前已经测试过了,模拟数据的Header数据的长度是110个字节
// 这里判断ByteBuf中的数据是否大于等于110,如果大于,说明有可能是个完整的数据,否则一定不是完整的
// 不是完整的数据,不要进行任何处理
// 因为默认会将不完整的数据留到下一次数据到来时,和下一批数据一起处理
// 这次不完整的数据的剩余部分就在下一批数据当中
while (buf.readableBytes() > 110) {
byte[] bytes = new byte[110];
// 获取ByteBuf中的Header数据
// 但是不移动readerIndex
buf.getBytes(buf.readerIndex(), bytes);
ByteArrayInputStream in = new ByteArrayInputStream(bytes);
ObjectInputStream oin = new ObjectInputStream(in);
Myheader header = (Myheader) oin.readObject();
// 因为上面的操作并没有移动readerIndex,所以此时还是从Header头部计算可读字节数
// 判断可读字节数是否大于等于一个完整数据的字节数
if (buf.readableBytes() >= (header.getDataLen() + 110)) {
// 先移动指针到body开始的位置
buf.readBytes(110);
byte[] data = new byte[(int) header.getDataLen()];
buf.readBytes(data);
ByteArrayInputStream din = new ByteArrayInputStream(data);
ObjectInputStream doin = new ObjectInputStream(din);
// flag == 0x14141414 说明此次数据是调用方发送给服务提供方的数据
// flag == 0x14141424 说明此次数据是服务提供方返回给调用方的数据,这个数据应该包含了方法的执行结果
if (header.getFlag() == 0x14141414) {
MyContent content = (MyContent) doin.readObject();
out.add(new Packmsg(header, content));
} else if (header.getFlag() == 0x14141424) {
MyContent content = (MyContent) doin.readObject();
out.add(new Packmsg(header, content));
}
} else {
// ByteBuf中剩余的数据不是完整的数据
// 不做任何处理,跳出循环,留着数据和下一批数据一起处理
break;
}
}
}
}
ServerRequestHandler
// 处理数据的业务代码
public class ServerRequestHandler extends ChannelInboundHandlerAdapter {
Dispatcher dis;
public ServerRequestHandler(Dispatcher dis) {
this.dis = dis;
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
// 拿到解码器处理后的完整的数据
Packmsg requestPkg = (Packmsg) msg;
// 这里有几种方式可以处理数据
// 第一种是一个线程执行,直到业务代码执行结束
// 第二种是另起一个线程执行业务代码,处理io的线程可以快速返回执行其他io
// 第三种是使用EventLoop中的线程,这种方式和第一种方式其实是一样的,都是用的io线程处理
// 第四种是使用EventLoopGroup中其他EventLoop的线程,这种方式的好处是,第一不会创建出太多的线程,第二其他EventLoop中的线程可能是空闲的,可以充分利用
// 这种是使用EventLoop中的线程
ctx.executor().execute(new Runnable() {
// 这种是使用EventLoopGroup中其他EventLoop的线程
// ctx.executor().parent().next().execute(new Runnable() {
@Override
public void run() {
// 获取服务名称
String serviceName = requestPkg.content.getName();
// 获取方法名称
String method = requestPkg.content.getMethodName();
// 根据服务名称拿到服务对象
Object c = dis.get(serviceName);
// 根据服务对象拿到类对象
Class<?> clazz = c.getClass();
Object res = null;
try {
// 通过反射调用服务对象的方法
Method m = clazz.getMethod(method, requestPkg.content.parameterTypes);
res = m.invoke(c, requestPkg.content.getArgs());
} catch (NoSuchMethodException e) {
e.printStackTrace();
} catch (IllegalAccessException e) {
e.printStackTrace();
} catch (InvocationTargetException e) {
e.printStackTrace();
}
MyContent content = new MyContent();
// 将方法的返回结果放到content中,准备返回给调用方
content.setRes((String) res);
// 将响应数据序列化
byte[] contentByte = SerDerUtil.ser(content);
// 生成响应数据Header
Myheader resHeader = new Myheader();
// 设置Header中的属性值,RequestID必须是请求体中的RequestID
// 这样调用方才知道让哪个线程来处理
resHeader.setRequestID(requestPkg.header.getRequestID());
resHeader.setFlag(0x14141424);
resHeader.setDataLen(contentByte.length);
byte[] headerByte = SerDerUtil.ser(resHeader);
ByteBuf byteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(headerByte.length + contentByte.length);
byteBuf.writeBytes(headerByte);
byteBuf.writeBytes(contentByte);
ctx.writeAndFlush(byteBuf);
}
});
}
}
RPC通信的协议类
public class Packmsg {
Myheader header;
MyContent content;
public Myheader getHeader() {
return header;
}
public void setHeader(Myheader header) {
this.header = header;
}
public MyContent getContent() {
return content;
}
public void setContent(MyContent content) {
this.content = content;
}
public Packmsg(Myheader header, MyContent content) {
this.header = header;
this.content = content;
}
}
public class Myheader implements Serializable {
//通信上的协议
// 标识属性
// 32位可以设置很多信息
int flag;
// 一次通信的唯一标识
long requestID;
// 通信的业务数据的大小
long dataLen;
public int getFlag() {
return flag;
}
public void setFlag(int flag) {
this.flag = flag;
}
public long getRequestID() {
return requestID;
}
public void setRequestID(long requestID) {
this.requestID = requestID;
}
public long getDataLen() {
return dataLen;
}
public void setDataLen(long dataLen) {
this.dataLen = dataLen;
}
public static Myheader createHeader(byte[] msg) {
Myheader header = new Myheader();
// Header中的三个参数
// 数据体长度
int size = msg.length;
// 标志属性
int f = 0x14141414;
// requestID作为这次通信的唯一标识
long requestID = Math.abs(UUID.randomUUID().getLeastSignificantBits());
header.setFlag(f);
header.setDataLen(size);
header.setRequestID(requestID);
return header;
}
}
public class MyContent implements Serializable {
// 服务名称
String name;
// 方法名称
String methodName;
// 方法的参数类型
Class<?>[] parameterTypes;
// 方法的参数列表
Object[] args;
// 方法的返回值
String res;
public String getRes() {
return res;
}
public void setRes(String res) {
this.res = res;
}
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public String getMethodName() {
return methodName;
}
public void setMethodName(String methodName) {
this.methodName = methodName;
}
public Class<?>[] getParameterTypes() {
return parameterTypes;
}
public void setParameterTypes(Class<?>[] parameterTypes) {
this.parameterTypes = parameterTypes;
}
public Object[] getArgs() {
return args;
}
public void setArgs(Object[] args) {
this.args = args;
}
}
数据序列化的工具类
public class SerDerUtil {
static ByteArrayOutputStream out = new ByteArrayOutputStream();
public synchronized static byte[] ser(Object msg){
out.reset();
ObjectOutputStream oout = null;
byte[] msgBody = null;
try {
oout = new ObjectOutputStream(out);
oout.writeObject(msg);
msgBody= out.toByteArray();
} catch (IOException e) {
e.printStackTrace();
}
return msgBody;
}
}
调用方
需要调用的服务接口
public interface Car {
String ooxx(String msg);
}
动态代理类
调用方先获得Car的动态代理对象,然后调用代理对象的ooxx方法,实现RPC调用
public class MyProxy {
// 返回动态代理对象
public static <T> T proxyGet(Class<T> interfaceInfo) {
ClassLoader loader = interfaceInfo.getClassLoader();
Class<?>[] interfaces = {interfaceInfo};
return (T) Proxy.newProxyInstance(loader, interfaces, new InvocationHandler() {
// 调用代理对象的方法时,实际执行invoke方法
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
// 拿到类的全限定名
String name = interfaceInfo.getName();
// 拿到调用的实际的方法名
String methodName = method.getName();
// 拿到实际方法的参数类型
Class<?>[] parameterTypes = method.getParameterTypes();
// MyContent对象用于封装以上信息
// 在要调用其他服务的某个方法时,以上信息可以精确定位到方法
// 所以以上信息是RPC的通信协议中不可或缺的信息
MyContent content = new MyContent();
content.setArgs(args);
content.setName(name);
content.setMethodName(methodName);
content.setParameterTypes(parameterTypes);
byte[] msgBody = com.bjmashibing.system.rpcdemo.util.SerDerUtil.ser(content);
Myheader header = Myheader.createHeader(msgBody);
byte[] msgHeader = com.bjmashibing.system.rpcdemo.util.SerDerUtil.ser(header);
// 这里打印出通信头部的长度大小
System.out.println("msgHeader的长度: " + msgHeader.length);
// 从连接池中获取连接
ClientFactory factory = ClientFactory.getFactory();
NioSocketChannel clientChannel = factory.getClient(new InetSocketAddress("localhost", 9090));
// 发送数据前,准备一个ByteBuf
ByteBuf byteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(msgHeader.length + msgBody.length);
// 拿到这次通信的RequestID,以便返回响应时,可以找到对应的处理线程
long id = header.getRequestID();
// 初始化一个CompletableFuture对象
// 其get方法是堵塞的,等返回响应后,将响应数据放到CompletableFuture对象中
// 使用get方法可以拿到
CompletableFuture<String> res = new CompletableFuture<>();
// 将RequestID和CompletableFuture对象放到map中
ResponseMappingCallback.addCallBack(id, res);
byteBuf.writeBytes(msgHeader);
byteBuf.writeBytes(msgBody);
// 发送数据
ChannelFuture channelFuture = clientChannel.writeAndFlush(byteBuf);
channelFuture.sync();
// CompletableFuture对象中没有数据时就是堵塞的
// 等返回响应,CompletableFuture对象中有数据就不堵塞了
// 而且可以拿到响应数据
return res.get();
}
});
}
}
RPC通信的协议类
public class Packmsg {
Myheader header;
MyContent content;
public Myheader getHeader() {
return header;
}
public void setHeader(Myheader header) {
this.header = header;
}
public MyContent getContent() {
return content;
}
public void setContent(MyContent content) {
this.content = content;
}
public Packmsg(Myheader header, MyContent content) {
this.header = header;
this.content = content;
}
}
public class Myheader implements Serializable {
//通信上的协议
// 标识属性
// 32位可以设置很多信息
int flag;
// 一次通信的唯一标识
long requestID;
// 通信的业务数据的大小
long dataLen;
public int getFlag() {
return flag;
}
public void setFlag(int flag) {
this.flag = flag;
}
public long getRequestID() {
return requestID;
}
public void setRequestID(long requestID) {
this.requestID = requestID;
}
public long getDataLen() {
return dataLen;
}
public void setDataLen(long dataLen) {
this.dataLen = dataLen;
}
public static Myheader createHeader(byte[] msg) {
Myheader header = new Myheader();
// Header中的三个参数
// 数据体长度
int size = msg.length;
// 标志属性
int f = 0x14141414;
// requestID作为这次通信的唯一标识
long requestID = Math.abs(UUID.randomUUID().getLeastSignificantBits());
header.setFlag(f);
header.setDataLen(size);
header.setRequestID(requestID);
return header;
}
}
public class MyContent implements Serializable {
// 服务名称
String name;
// 方法名称
String methodName;
// 方法的参数类型
Class<?>[] parameterTypes;
// 方法的参数列表
Object[] args;
// 方法的返回值
String res;
public String getRes() {
return res;
}
public void setRes(String res) {
this.res = res;
}
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public String getMethodName() {
return methodName;
}
public void setMethodName(String methodName) {
this.methodName = methodName;
}
public Class<?>[] getParameterTypes() {
return parameterTypes;
}
public void setParameterTypes(Class<?>[] parameterTypes) {
this.parameterTypes = parameterTypes;
}
public Object[] getArgs() {
return args;
}
public void setArgs(Object[] args) {
this.args = args;
}
}
数据序列化的工具类
public class SerDerUtil {
static ByteArrayOutputStream out = new ByteArrayOutputStream();
public synchronized static byte[] ser(Object msg){
out.reset();
ObjectOutputStream oout = null;
byte[] msgBody = null;
try {
oout = new ObjectOutputStream(out);
oout.writeObject(msg);
msgBody= out.toByteArray();
} catch (IOException e) {
e.printStackTrace();
}
return msgBody;
}
}
调用方启动类
public class Client {
public static void main(String[] args) {
AtomicInteger num = new AtomicInteger(0);
// 准备50个线程作为客户端
int size = 50;
Thread[] threads = new Thread[size];
for (int i = 0; i < size; i++) {
threads[i] = new Thread(() -> {
// 拿到动态代理的Car类对象
Car car = MyProxy.proxyGet(Car.class);
String arg = "hello" + num.incrementAndGet();
// 调用代理对象的ooxx方法
String res = car.ooxx(arg);
System.out.println("client over msg: " + res + " src arg: " + arg);
});
}
for (Thread thread : threads) {
thread.start();
}
try {
System.in.read();
} catch (IOException e) {
e.printStackTrace();
}
}
}
客户端线程池
// 连接池工厂
public class ClientFactory {
int poolSize = 1;
NioEventLoopGroup clientWorker;
Random rand = new Random();
private ClientFactory() {
}
// 单例模式
private static final ClientFactory factory;
static {
factory = new ClientFactory();
}
public static ClientFactory getFactory() {
return factory;
}
// map中存放连接池
// key是连接的服务端的IP+PORT信息,value是连接池
ConcurrentHashMap<InetSocketAddress, ClientPool> outboxs = new ConcurrentHashMap<>();
public synchronized NioSocketChannel getClient(InetSocketAddress address) {
// 从map中获取连接池
ClientPool clientPool = outboxs.get(address);
if (clientPool == null) {
// 连接池为空
// 创建连接池
outboxs.putIfAbsent(address, new ClientPool(poolSize));
clientPool = outboxs.get(address);
}
// 随机获得一个数,小于连接池的容量
int i = rand.nextInt(poolSize);
// 从连接池中获取连接,判断连接是否活跃
if (clientPool.clients[i] != null && clientPool.clients[i].isActive()) {
// 连接可用,直接返回连接
return clientPool.clients[i];
}
// 创建连接,这里会有多个线程执行,所以要加锁
synchronized (clientPool.lock[i]) {
if (clientPool.clients[i] != null && clientPool.clients[i].isActive()) {
return clientPool.clients[i];
}
// 创建连接
return clientPool.clients[i] = create(address);
}
}
// 创建连接
private NioSocketChannel create(InetSocketAddress address) {
//基于 netty 的客户端创建方式
clientWorker = new NioEventLoopGroup(1);
Bootstrap bs = new Bootstrap();
ChannelFuture connect = bs.group(clientWorker)
.channel(NioSocketChannel.class)
.handler(new ChannelInitializer<NioSocketChannel>() {
@Override
protected void initChannel(NioSocketChannel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
// 调用方拿到服务提供方返回的响应数据也需要解码
p.addLast(new ServerDecode());
p.addLast(new ClientResponses());
}
}).connect(address);
try {
NioSocketChannel client = (NioSocketChannel) connect.sync().channel();
return client;
} catch (InterruptedException e) {
e.printStackTrace();
}
return null;
}
}
public class ClientPool {
NioSocketChannel[] clients;
Object[] lock;
ClientPool(int size) {
clients = new NioSocketChannel[size];
// 锁对象,专门加锁用的
lock = new Object[size];
for (int i = 0; i < size; i++) {
lock[i] = new Object();
}
}
}
响应处理类
public class ResponseMappingCallback {
static ConcurrentHashMap<Long, CompletableFuture> mapping = new ConcurrentHashMap<>();
// 将requestID作为key,CompletableFuture对象作为value,放到map中
public static void addCallBack(long requestID, CompletableFuture cb) {
mapping.putIfAbsent(requestID, cb);
}
// 返回响应时,调用此方法
public static void runCallBack(Packmsg msg) {
// 根据响应数据中的requestID,从map中取出对应的CompletableFuture对象
CompletableFuture cf = mapping.get(msg.header.getRequestID());
// 把响应数据放到CompletableFuture对象中
// 如此一来,CompletableFuture对象的get方法可以拿到数据,不在堵塞
cf.complete(msg.getContent().getRes());
// 这次通信结束
// 从map中将记录删除
removeCB(msg.header.getRequestID());
}
private static void removeCB(long requestID) {
mapping.remove(requestID);
}
}
// 调用方接收响应后执行的业务代码
public class ClientResponses extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
// 拿到服务提供方返回的数据
Packmsg responsepkg = (Packmsg) msg;
// 调用runCallBack方法,将数据交给对应的线程处理
ResponseMappingCallback.runCallBack(responsepkg);
}
}
解码器
/**
* 解码器
* 内核为socket两端提供了recv_queue缓存区,数据会先从网卡读到缓存区中
* 当数据到达recv_queue中后,read事件触发,Netty中的EventLoop中的线程会从recv_queue中拿数据到ByteBuf中
* 但是,当线程从recv_queue中拿数据到ByteBuf中之前,可能recv_queue中已经存放了不止一条数据了,内核是无法保证recv_queue中数据的完整性的
* 所以EventLoop中线程拿到数据放到ByteBuf中后,需要我们对ByteBuf中数据进行切分,切分出一条条完整的数据
* 因此,Netty提供了解码器,实现这一功能
*/
public class ServerDecode extends ByteToMessageDecoder {
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf buf, List<Object> out) throws Exception {
// 写代码前已经测试过了,模拟数据的Header数据的长度是110个字节
// 这里判断ByteBuf中的数据是否大于等于110,如果大于,说明有可能是个完整的数据,否则一定不是完整的
// 不是完整的数据,不要进行任何处理
// 因为默认会将不完整的数据留到下一次数据到来时,和下一批数据一起处理
// 这次不完整的数据的剩余部分就在下一批数据当中
while (buf.readableBytes() > 110) {
byte[] bytes = new byte[110];
// 获取ByteBuf中的Header数据
// 但是不移动readerIndex
buf.getBytes(buf.readerIndex(), bytes);
ByteArrayInputStream in = new ByteArrayInputStream(bytes);
ObjectInputStream oin = new ObjectInputStream(in);
Myheader header = (Myheader) oin.readObject();
// 因为上面的操作并没有移动readerIndex,所以此时还是从Header头部计算可读字节数
// 判断可读字节数是否大于等于一个完整数据的字节数
if (buf.readableBytes() >= (header.getDataLen() + 110)) {
// 先移动指针到body开始的位置
buf.readBytes(110);
byte[] data = new byte[(int) header.getDataLen()];
buf.readBytes(data);
ByteArrayInputStream din = new ByteArrayInputStream(data);
ObjectInputStream doin = new ObjectInputStream(din);
// flag == 0x14141414 说明此次数据是调用方发送给服务提供方的数据
// flag == 0x14141424 说明此次数据是服务提供方返回给调用方的数据,这个数据应该包含了方法的执行结果
if (header.getFlag() == 0x14141414) {
MyContent content = (MyContent) doin.readObject();
out.add(new Packmsg(header, content));
} else if (header.getFlag() == 0x14141424) {
MyContent content = (MyContent) doin.readObject();
out.add(new Packmsg(header, content));
}
} else {
// ByteBuf中剩余的数据不是完整的数据
// 不做任何处理,跳出循环,留着数据和下一批数据一起处理
break;
}
}
}
}