一、远程调用设计
二、客户端请求
2.1 创建调用器
调用器接口
/**
 * Describes one remote method call: the target service and method, the arguments to pass,
 * and how long the caller is prepared to wait for the result.
 */
public interface Invocation {

    /**
     * Returns the name under which the target service was exported.
     *
     * @return the service name
     */
    String getServiceName();

    /**
     * Returns the name of the method to call on the target service.
     *
     * @return the method name
     */
    String getMethodName();

    /**
     * Returns the declared parameter types of the target method; together with the method
     * name they identify the exact overload to call.
     *
     * @return the parameter types, in declaration order
     */
    Class<?>[] getParameterTypes();

    /**
     * Returns the argument values to pass to the target method.
     *
     * @return the arguments, in the same order as {@link #getParameterTypes()}
     */
    Object[] getArguments();

    /**
     * Returns how long the caller waits for the result, in units of {@link #getTimeUnit()}.
     *
     * @return the timeout amount
     */
    long getTimeout();

    /**
     * Returns the unit in which {@link #getTimeout()} is expressed.
     *
     * @return the timeout unit
     */
    TimeUnit getTimeUnit();
}
默认的RPC调用信息实现类,将被调用方法的信息(服务名、方法名、参数类型、参数值以及超时时间)封装成Invocation
/**
 * Default {@link Invocation}: a plain holder for the data describing one remote call.
 * When no timeout is given, the call waits 3000 milliseconds.
 */
public class RpcInvocation implements Invocation {

    /** Timeout used by the four-argument constructor. */
    private static final long DEFAULT_TIMEOUT = 3000;

    /** Unit of {@link #DEFAULT_TIMEOUT}. */
    private static final TimeUnit DEFAULT_TIME_UNIT = TimeUnit.MILLISECONDS;

    private String serviceName;
    private String methodName;
    private Class<?>[] parameterTypes;
    private Object[] arguments;
    private long timeout = DEFAULT_TIMEOUT;
    private TimeUnit timeUnit = DEFAULT_TIME_UNIT;

    /**
     * Creates an invocation that uses the default timeout of 3000 milliseconds.
     *
     * @param serviceName    exported name of the target service
     * @param methodName     name of the method to call
     * @param parameterTypes declared parameter types of that method
     * @param arguments      argument values to pass
     */
    public RpcInvocation(String serviceName, String methodName, Class<?>[] parameterTypes, Object[] arguments) {
        this(serviceName, methodName, parameterTypes, arguments, DEFAULT_TIMEOUT, DEFAULT_TIME_UNIT);
    }

    /**
     * Creates an invocation with an explicit timeout.
     *
     * @param serviceName    exported name of the target service
     * @param methodName     name of the method to call
     * @param parameterTypes declared parameter types of that method
     * @param arguments      argument values to pass
     * @param timeout        how long the caller waits for the result
     * @param timeUnit       unit of {@code timeout}
     */
    public RpcInvocation(String serviceName, String methodName, Class<?>[] parameterTypes, Object[] arguments,
                         long timeout, TimeUnit timeUnit) {
        this.serviceName = serviceName;
        this.methodName = methodName;
        this.parameterTypes = parameterTypes;
        this.arguments = arguments;
        this.timeout = timeout;
        this.timeUnit = timeUnit;
    }

    @Override
    public String getServiceName() {
        return this.serviceName;
    }

    @Override
    public String getMethodName() {
        return this.methodName;
    }

    @Override
    public Class<?>[] getParameterTypes() {
        return this.parameterTypes;
    }

    @Override
    public Object[] getArguments() {
        return this.arguments;
    }

    @Override
    public long getTimeout() {
        return this.timeout;
    }

    @Override
    public TimeUnit getTimeUnit() {
        return this.timeUnit;
    }
}
2.2 创建调用者Invoker
Invoker中持有Netty客户端(RpcClient),负责发送请求信息,并阻塞等待响应结果
/**
 * Performs an {@link Invocation} against a remote service.
 */
public interface Invoker {

    /**
     * Executes the remote call described by {@code invocation}.
     *
     * @param invocation the call to perform
     * @return the outcome of the call
     * @throws Throwable if the call cannot be performed
     */
    Result invoke(Invocation invocation) throws Throwable;
}
/**
 * {@link Invoker} that sends each invocation through an {@link RpcClient} and then waits
 * synchronously for the response.
 */
public class RpcInvoker implements Invoker {

    /** Client used to send requests to the server. */
    private final RpcClient client;

    /** Timeout handed to {@link RpcClient#request} for every request sent by this invoker. */
    private final long timeout;

    /**
     * @param client  client used to send requests
     * @param timeout timeout passed to {@link RpcClient#request} for each request
     */
    public RpcInvoker(RpcClient client, long timeout) {
        this.timeout = timeout;
        this.client = client;
    }

    /**
     * Sends {@code invocation}, then blocks until its result is available, waiting at most the
     * invocation's own timeout ({@link Invocation#getTimeout()} / {@link Invocation#getTimeUnit()}).
     *
     * @param invocation the call to perform
     * @return the result wrapping the server's response
     * @throws Throwable if sending fails, or if waiting for the result fails (see AsyncRpcResult#get)
     */
    @Override
    public Result invoke(Invocation invocation) throws Throwable {
        CompletableFuture<Object> pending = client.request(invocation, timeout);
        CompletableFuture<AppResponse> typedResponse = pending.thenApply(raw -> (AppResponse) raw);
        AsyncRpcResult asyncResult = new AsyncRpcResult(typedResponse, invocation);
        // Block here so the caller receives a result whose response has already arrived.
        asyncResult.get(invocation.getTimeout(), invocation.getTimeUnit());
        return asyncResult;
    }
}
2.3 客户端请求
RpcClient发送请求时会为每一个请求分配唯一的请求id和异步响应结果
/**
 * Sends an invocation to the server and returns a future for its response.
 *
 * @param request the invocation to send
 * @param timeout timeout in milliseconds; used for the pending future and for the send itself
 * @return a future that is completed with the server's result when the matching response arrives
 * @throws Throwable if the message cannot be written to the channel
 */
public CompletableFuture<Object> request(Invocation request, long timeout) throws Throwable {
    Request outgoing = new Request();
    outgoing.setId();
    outgoing.setData(request);
    // Register the pending future before writing, so the response can never arrive before it is
    // findable by its id.
    DefaultFuture pendingResponse = DefaultFuture.newFuture(channel, outgoing, timeout);
    // NOTE(review): if send throws, the future stays registered in DefaultFuture's table and is
    // never completed; consider cleaning it up on failure.
    send(outgoing, timeout);
    return pendingResponse;
}
/**
 * Writes a message to the channel and waits until the write has completed.
 *
 * @param request the message to write
 * @param timeout maximum time to wait for the write to complete, in milliseconds
 * @throws Throwable an {@link Exception} if the write fails, if the waiting thread is interrupted
 *                   (the interrupt flag is restored), or if the write does not complete in time
 */
public void send(Object request, long timeout) throws Throwable {
    boolean success;
    try {
        ChannelFuture future = channel.writeAndFlush(request);
        // Wait for the write result so that a send failure surfaces to the caller.
        success = future.await(timeout);
        Throwable cause = future.cause();
        if (cause != null) {
            throw cause;
        }
    } catch (InterruptedException e) {
        // Wrapping the exception would otherwise hide the interruption from callers further up.
        Thread.currentThread().interrupt();
        throw new Exception("Failed to send message :" + request +
                " to " + channel.remoteAddress() + ", cause: " + e.getMessage(), e);
    } catch (Throwable e) {
        throw new Exception("Failed to send message :" + request +
                " to " + channel.remoteAddress() + ", cause: " + e.getMessage(), e);
    }
    if (!success) {
        // await returned false: the write did not complete within the timeout.
        throw new Exception("Failed to send message :" + request + " to " + channel.remoteAddress()
                + " in timeout(" + timeout + "ms) limit");
    }
}
每一个请求都对应一个响应结果,根据请求id存储在任务列表FUTURES中
/**
 * Pending response of one request. Each instance is registered in a static table under its
 * request id, so the response handler can find it with {@link #getFuture(long)} and complete it.
 * The entry is removed again as soon as the future completes.
 */
public class DefaultFuture extends CompletableFuture<Object> {
    /**
     * In-flight requests, keyed by request id.
     */
    private static final Map<Long, DefaultFuture> FUTURES = new ConcurrentHashMap<>();
    /**
     * Request id, taken from the request.
     */
    private final Long id;
    /**
     * Channel the request is sent on.
     */
    private final Channel channel;
    /**
     * The request message.
     */
    private final Request request;
    /**
     * Timeout supplied by the caller (RpcClient#request passes the same value it uses for sending).
     * NOTE(review): not used by this class yet. A future whose response never arrives is never
     * completed and therefore never removed; consider failing it with a TimeoutException after this delay.
     */
    private final long timeout;

    private DefaultFuture(Channel channel, Request request, long timeout) {
        this.channel = channel;
        this.request = request;
        this.id = request.getId();
        this.timeout = timeout;
        // Register before the request is sent so the response handler can always find it.
        FUTURES.put(id, this);
    }

    /**
     * Creates and registers the pending response for a request.
     *
     * @param channel channel the request is sent on
     * @param request the request message; its id is the lookup key
     * @param timeout timeout supplied by the caller
     * @return the registered future
     */
    public static DefaultFuture newFuture(Channel channel, Request request, long timeout) {
        final DefaultFuture future = new DefaultFuture(channel, request, timeout);
        // Once completed in any way, the id is no longer needed. Without this, every request would
        // stay in FUTURES forever and the table would grow without bound.
        future.whenComplete((value, error) -> FUTURES.remove(future.id, future));
        return future;
    }

    /**
     * Looks up the pending response for a request id.
     *
     * @param id the request id
     * @return the pending future, or {@code null} if no uncompleted request has this id
     */
    public static DefaultFuture getFuture(long id) {
        return FUTURES.get(id);
    }
}
2.4 处理服务端响应
服务端处理完请求后会将请求id和响应结果返回;客户端根据id找到对应的异步结果(DefaultFuture),并将服务端响应写入其中,此时2.2步骤中调用者阻塞等待的get方法就能获取到响应信息
/**
 * Client-side handler that receives {@link Response}s from the server and completes the pending
 * {@link DefaultFuture} registered under the same request id.
 */
public class RpcClientHandler extends SimpleChannelInboundHandler<Response> {
    private static final Logger logger = LoggerFactory.getLogger(RpcClientHandler.class);

    /**
     * Completes the pending request that matches the response id with the response result.
     * A response for which no request is pending is logged and dropped.
     * NOTE(review): the response status is not inspected. Error responses built by
     * RpcServerHandler carry no result, so they complete the future with whatever result they hold.
     */
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, Response response) {
        DefaultFuture responseFuture = DefaultFuture.getFuture(response.getId());
        if (responseFuture == null) {
            // Nothing is waiting for this id (unknown, duplicate or already completed). Without this
            // check the NullPointerException would reach exceptionCaught, which closes the channel.
            logger.warn("No pending request for response id {}, channel is {}", response.getId(), ctx.channel());
            return;
        }
        responseFuture.complete(response.getResult());
    }

    /**
     * Passes the exception on to the next handler, then closes the channel if it is still active.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        super.exceptionCaught(ctx, cause);
        Channel channel = ctx.channel();
        if (channel.isActive()) {
            ctx.close();
        }
    }
}
三、服务端响应
服务端根据请求的服务名称找到对应服务,通过反射调用方法,获取到方法返回结果后将请求id和返回结果响应给客户端
3.1 客户端请求处理
/**
 * Server-side handler. It executes each incoming {@link Request} through {@link RpcProtocol} and
 * writes back a {@link Response} carrying the same request id.
 *
 * <p>It extends {@link SimpleChannelInboundHandler}, which releases the inbound message
 * ({@code ReferenceCountUtil.release(msg)}) after {@code channelRead0} returns. A plain
 * {@code ChannelInboundHandler} would have to release it manually to avoid leaking memory.
 * Example: https://juejin.cn/post/7224886077051551781
 *
 * <p>The handler is {@code @Sharable}; its only state is the final {@code protocol} reference.
 */
@ChannelHandler.Sharable
public class RpcServerHandler extends SimpleChannelInboundHandler<Request> {
    private static final Logger logger = LoggerFactory.getLogger(RpcServerHandler.class);
    private final RpcProtocol protocol;

    public RpcServerHandler(RpcProtocol protocol) {
        this.protocol = protocol;
    }

    /**
     * Executes the invocation carried by {@code req} and replies with either its result
     * (SUCCESS status) or an error message (ASYNC_RESPONSE_ERROR status).
     * Requests whose data is not an {@link Invocation} get no reply.
     */
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, Request req) {
        Response res = new Response(req.getId());
        Channel channel = ctx.channel();
        Object msg = req.getData();
        if ((msg instanceof Invocation)) {
            Invocation invocation = (Invocation) msg;
            try {
                // Execute the service method.
                CompletableFuture<Object> future = protocol.reply(ctx, invocation);
                // Reply once the result is available.
                future.whenComplete((appResult, t) -> {
                    try {
                        if (t == null) {
                            res.setStatus(ResponseEnum.SUCCESS.getCode());
                            res.setResult(appResult);
                        } else {
                            res.setStatus(ResponseEnum.ASYNC_RESPONSE_ERROR.getCode());
                            res.setErrorMessage(describeFailure(t));
                        }
                        channel.writeAndFlush(res);
                    } catch (Exception e) {
                        logger.warn("Send result to consumer failed, channel is " + channel + ", msg is " + e);
                    }
                });
            } catch (Exception e) {
                // Lookup or invocation failed synchronously (see RpcProtocol#reply).
                res.setStatus(ResponseEnum.ASYNC_RESPONSE_ERROR.getCode());
                res.setErrorMessage(describeFailure(e));
                channel.writeAndFlush(res);
            }
        }
    }

    /**
     * Builds the error message sent to the client for a failed invocation.
     * The InvocationTargetException thrown by Method.invoke has a null message, and a
     * CompletionException only repeats its cause, so these wrappers are unwrapped first.
     * If the root failure has no message either, its class name is used so the client
     * never receives a null error message.
     */
    private static String describeFailure(Throwable failure) {
        Throwable root = failure;
        while ((root instanceof java.lang.reflect.InvocationTargetException
                || root instanceof java.util.concurrent.CompletionException)
                && root.getCause() != null) {
            root = root.getCause();
        }
        String message = root.getMessage();
        return message != null ? message : root.getClass().getName();
    }
}
3.2 服务提供者
/**
 * Registry of exported services. It executes incoming invocations on them via reflection.
 */
public class RpcProtocol {
    /**
     * Exported services: key = service name, value = service instance.
     */
    protected final Map<String, Object> exporterMap = new ConcurrentHashMap<>();

    /**
     * Invokes the method described by {@code message} on the service exported under its service name.
     *
     * @param handlerContext channel context of the request (currently unused)
     * @param message        the invocation to execute
     * @return an already-completed future holding an AppResponse with the method's return value
     * @throws RouteException if no service is exported under the requested name
     * @throws Exception      if the method cannot be found or accessed, or if it throws (wrapped in
     *                        {@link java.lang.reflect.InvocationTargetException}). These failures are
     *                        deliberately left to the caller to turn into an error response.
     */
    public CompletableFuture<Object> reply(ChannelHandlerContext handlerContext, Invocation message) throws Exception {
        String serviceName = message.getServiceName();
        Object targetService = exporterMap.get(serviceName);
        if (targetService == null) {
            throw new RouteException(ResponseEnum.PROVIDER_ERROR.getCode(), "Not found exported service: " + serviceName + " in " + exporterMap.keySet());
        }
        Method method = resolveMethod(targetService.getClass(), message.getMethodName(), message.getParameterTypes());
        Object value = method.invoke(targetService, message.getArguments());
        // The call has already finished synchronously, so wrap its value directly.
        AppResponse result = new AppResponse();
        result.setValue(value);
        return CompletableFuture.completedFuture(result);
    }

    /**
     * Finds the method to invoke. It looks first among public methods, including those inherited
     * from superclasses, which getDeclaredMethod alone misses. It then falls back to methods
     * declared directly on the class.
     */
    private static Method resolveMethod(Class<?> type, String name, Class<?>[] parameterTypes) throws NoSuchMethodException {
        try {
            return type.getMethod(name, parameterTypes);
        } catch (NoSuchMethodException notPublic) {
            return type.getDeclaredMethod(name, parameterTypes);
        }
    }

    /**
     * Exports a service under the given name, replacing any service previously exported under it.
     *
     * @param serviceName name clients use to address the service
     * @param service     the service instance
     * @return the service previously exported under this name, or {@code null}. The cast is
     *         unchecked, so a previous service of a different type fails at the caller's assignment.
     */
    @SuppressWarnings("unchecked")
    public <T> T export(String serviceName, T service) {
        return (T) exporterMap.put(serviceName, service);
    }
}
四、测试
4.1 服务端启动并注册服务
public class RpcServerTest {
    private static final Logger logger = LoggerFactory.getLogger(RpcServerTest.class);

    /**
     * Exports the "rpcTestImpl" service, starts an RPC server on port 7777, waits until its channel
     * is active, and then blocks indefinitely so the server stays up for client tests.
     */
    @Test
    public void testServer() throws ExecutionException, InterruptedException {
        // Create the protocol and export the test service under the name clients use.
        RpcProtocol rpcProtocol = new RpcProtocol();
        rpcProtocol.export("rpcTestImpl", new RpcTestImpl());
        // Create the server.
        RpcServer rpcServer = new RpcServer();
        // Start the server on a worker thread; one thread is enough for the single start-up task.
        Callable<ChannelFuture> startTask = () -> rpcServer.startApplication(7777, rpcProtocol);
        java.util.concurrent.ExecutorService starter = Executors.newSingleThreadExecutor();
        ChannelFuture channelFuture;
        try {
            Future<ChannelFuture> future = starter.submit(startTask);
            channelFuture = future.get();
        } finally {
            // The start-up task has finished (or failed); let the worker thread exit.
            starter.shutdown();
        }
        Channel channel = channelFuture.channel();
        if (null == channel) {
            throw new RuntimeException("netty server start error channel is null");
        }
        while (!channel.isActive()) {
            logger.info("NettyServer启动服务 ...");
            Thread.sleep(500);
        }
        logger.info("NettyServer启动服务完成 {}", channel.localAddress());
        // Never counted down: keeps the test (and the server) running.
        CountDownLatch countDownLatch = new CountDownLatch(1);
        countDownLatch.await();
    }
}
4.2 客户端请求调用
public class RpcClientTest {
    private static final Logger logger = LoggerFactory.getLogger(RpcClientTest.class);

    /**
     * Calls the "rpcTestImpl" service at 127.0.0.1:7777 with a single String argument, logs the
     * result, and then blocks indefinitely.
     */
    @Test
    public void testClient() throws Throwable {
        // Create the client.
        RpcClient rpcClient = new RpcClient("127.0.0.1", 7777, 3);
        // Create the invoker.
        RpcInvoker rpcInvoker = new RpcInvoker(rpcClient, 2000);
        // Class#getMethods() returns methods in no particular order, so select the target by
        // signature (one String parameter, matching the argument below) instead of by index.
        Method method = findMethodTakingSingleString(RpcTest.class);
        RpcInvocation rpcInvocation = new RpcInvocation(
                "rpcTestImpl",
                method.getName(), method.getParameterTypes(),
                new Object[]{"你好啊!!"}
        );
        // Perform the remote call.
        Result result = rpcInvoker.invoke(rpcInvocation);
        logger.info("获取到的结果为:{}", result.getValue());
        // Never counted down: keeps the test running.
        CountDownLatch countDownLatch = new CountDownLatch(1);
        countDownLatch.await();
    }

    /** Returns the first public method of {@code type} whose only parameter is a String. */
    private static Method findMethodTakingSingleString(Class<?> type) {
        for (Method candidate : type.getMethods()) {
            Class<?>[] parameterTypes = candidate.getParameterTypes();
            if (parameterTypes.length == 1 && parameterTypes[0] == String.class) {
                return candidate;
            }
        }
        throw new IllegalStateException("No public method taking a single String found on " + type.getName());
    }
}