手写RPC框架<二>远程调用

350 阅读 · 阅读时长约 4 分钟

一、远程调用设计

![远程调用设计](image.png)

二、客户端请求

2.1 创建调用器

调用器接口

public interface Invocation {

    /**
     * Returns the name under which the target service was exported on the server.
     *
     * @return the service name
     */
    String getServiceName();

    /**
     * Returns the name of the method to call on the target service.
     *
     * @return the method name
     */
    String getMethodName();

    /**
     * Returns the declared parameter types of the target method; together with the method name
     * they identify which (possibly overloaded) method is meant.
     *
     * @return the parameter types
     */
    Class<?>[] getParameterTypes();

    /**
     * Returns the argument values to pass to the target method.
     *
     * @return the arguments
     */
    Object[] getArguments();

    /**
     * Returns how long the caller waits for the result, expressed in {@link #getTimeUnit()}.
     *
     * @return the timeout amount
     */
    long getTimeout();

    /**
     * Returns the unit in which {@link #getTimeout()} is expressed.
     *
     * @return the timeout unit
     */
    TimeUnit getTimeUnit();
}

默认的 RPC 调用实现类，将被调用方法的信息封装成 Invocation

public class RpcInvocation implements Invocation {

    /** Timeout amount used when the caller does not pass one. */
    private static final long DEFAULT_TIMEOUT = 3000;

    /** Unit of {@link #DEFAULT_TIMEOUT}. */
    private static final TimeUnit DEFAULT_TIME_UNIT = TimeUnit.MILLISECONDS;

    private String serviceName;

    private String methodName;

    private Class<?>[] parameterTypes;

    private Object[] arguments;

    private long timeout;

    private TimeUnit timeUnit;

    /**
     * Creates an invocation that waits at most 3000 milliseconds for its result.
     *
     * @param serviceName    name the target service was exported under
     * @param methodName     name of the method to call
     * @param parameterTypes declared parameter types of that method
     * @param arguments      argument values for the call
     */
    public RpcInvocation(String serviceName, String methodName, Class<?>[] parameterTypes, Object[] arguments) {
        this(serviceName, methodName, parameterTypes, arguments, DEFAULT_TIMEOUT, DEFAULT_TIME_UNIT);
    }

    /**
     * Creates an invocation with an explicit result timeout.
     *
     * @param serviceName    name the target service was exported under
     * @param methodName     name of the method to call
     * @param parameterTypes declared parameter types of that method
     * @param arguments      argument values for the call
     * @param timeout        how long to wait for the result
     * @param timeUnit       unit of {@code timeout}
     */
    public RpcInvocation(String serviceName, String methodName, Class<?>[] parameterTypes, Object[] arguments, long timeout, TimeUnit timeUnit) {
        this.serviceName = serviceName;
        this.methodName = methodName;
        this.parameterTypes = parameterTypes;
        this.arguments = arguments;
        this.timeout = timeout;
        this.timeUnit = timeUnit;
    }

    @Override
    public String getServiceName() {
        return this.serviceName;
    }

    @Override
    public String getMethodName() {
        return this.methodName;
    }

    @Override
    public Class<?>[] getParameterTypes() {
        return this.parameterTypes;
    }

    @Override
    public Object[] getArguments() {
        return this.arguments;
    }

    @Override
    public long getTimeout() {
        return this.timeout;
    }

    @Override
    public TimeUnit getTimeUnit() {
        return this.timeUnit;
    }
}

2.2 创建调用者Invoker

Invoker中存储了Netty的客户端,负责发送请求信息

/**
 * Performs a remote call described by an {@link Invocation}.
 */
public interface Invoker {

    /**
     * Sends the invocation to the remote side and produces its result.
     *
     * @param invocation description of the remote call
     * @return the result of the call
     * @throws Throwable if sending the request or waiting for the result fails
     */
    Result invoke(Invocation invocation) throws Throwable;
}
public class RpcInvoker implements Invoker {

    /** Client used to send requests to the server. */
    private final RpcClient client;

    /** Timeout handed to {@code RpcClient#request} for sending each request (milliseconds). */
    private final long timeout;

    /**
     * @param client  client used to send requests
     * @param timeout send timeout passed to the client for every request
     */
    public RpcInvoker(RpcClient client, long timeout) {
        this.client = client;
        this.timeout = timeout;
    }

    /**
     * Sends the invocation, then waits for the response for at most the invocation's own timeout.
     *
     * @param invocation the call to perform
     * @return the result wrapping the server's response
     * @throws Throwable if sending fails or waiting for the response fails
     */
    @Override
    public Result invoke(Invocation invocation) throws Throwable {
        CompletableFuture<Object> pending = client.request(invocation, timeout);
        CompletableFuture<AppResponse> typedPending = pending.thenApply(raw -> (AppResponse) raw);
        AsyncRpcResult asyncResult = new AsyncRpcResult(typedPending, invocation);
        // Wait (blocking) for the response, bounded by the invocation's timeout
        asyncResult.get(invocation.getTimeout(), invocation.getTimeUnit());
        return asyncResult;
    }

}

2.3 客户端请求

RpcClient发送请求时会为每一个请求分配唯一的请求id和异步响应结果

/**
 * Sends an invocation to the server and returns a future for its response.
 *
 * <p>The {@link DefaultFuture} for the request is created (and registered under the request's id)
 * before the request is written, so a response can always find its future.
 *
 * @param request the invocation to send
 * @param timeout send timeout in milliseconds, also recorded on the pending future
 * @return a future that is completed with the server's result
 * @throws Throwable if the request could not be written to the channel
 */
public CompletableFuture<Object> request(Invocation request, long timeout) throws Throwable {
    Request req = new Request();
    req.setId();
    req.setData(request);
    DefaultFuture taskFuture = DefaultFuture.newFuture(channel, req, timeout);
    try {
        send(req, timeout);
    } catch (Throwable t) {
        // The request never left, so no response will arrive: fail the future instead of leaving it pending forever
        taskFuture.completeExceptionally(t);
        throw t;
    }
    return taskFuture;
}


/**
 * Writes a message to the channel and waits until the write operation has finished.
 *
 * @param request the message to write
 * @param timeout maximum time to wait for the write, in milliseconds
 * @throws Exception wrapping the write failure, when the wait is interrupted, or when the write
 *                   does not finish within {@code timeout}
 */
public void send(Object request, long timeout) throws Throwable {
    boolean success;
    try {
        ChannelFuture future = channel.writeAndFlush(request);
        // Wait synchronously so failures surface here instead of being lost in the event loop
        success = future.await(timeout);
        Throwable cause = future.cause();
        if (cause != null) {
            throw cause;
        }
    } catch (InterruptedException e) {
        // Restore the interrupt flag so code further up the stack can still observe the interruption
        Thread.currentThread().interrupt();
        throw new Exception("Failed to send message :" + request +
                " to " + channel.remoteAddress() + ", cause: " + e.getMessage(), e);
    } catch (Throwable e) {
        throw new Exception("Failed to send message :" + request +
                " to " + channel.remoteAddress() + ", cause: " + e.getMessage(), e);
    }
    if (!success) {
        throw new Exception("Failed to send message :" + request + " to " + channel.remoteAddress()
                + " in timeout(" + timeout + "ms) limit");
    }
}

每一个请求都对应一个响应结果,根据请求id存储在任务列表FUTURES中

/**
 * Future for one in-flight RPC request. It is kept in a static table under its request id so that
 * the response handler can look it up by id and complete it.
 */
public class DefaultFuture extends CompletableFuture<Object> {

    /**
     * Pending requests keyed by request id. An entry is removed as soon as its future completes.
     */
    private static final Map<Long, DefaultFuture> FUTURES = new ConcurrentHashMap<>();

    /** Request id; also the key in {@link #FUTURES}. */
    private final Long id;
    /** Channel the request is sent on. */
    private final Channel channel;
    /** The request message. */
    private final Request request;
    /**
     * Timeout supplied by the caller (treated as milliseconds by the client's send).
     * NOTE(review): not enforced here; a request that never receives a response stays registered.
     */
    private final long timeout;

    private DefaultFuture(Channel channel, Request request, long timeout) {
        this.channel = channel;
        this.request = request;
        this.id = request.getId();
        this.timeout = timeout;
    }

    /**
     * Creates the future for a request and registers it under the request's id.
     *
     * @param channel channel the request is sent on
     * @param request the request message; its id becomes the lookup key
     * @param timeout timeout supplied by the caller
     * @return the registered future
     */
    public static DefaultFuture newFuture(Channel channel, Request request, long timeout) {
        final DefaultFuture future = new DefaultFuture(channel, request, timeout);
        // Register only after construction has finished, so no half-built instance is ever visible
        FUTURES.put(future.id, future);
        // Deregister once the future completes in any way; otherwise every request would stay in FUTURES forever.
        // The two-argument remove only drops the entry if it still maps to this very future.
        future.whenComplete((ignoredResult, ignoredError) -> FUTURES.remove(future.id, future));
        return future;
    }

    /**
     * Looks up the pending future for a request id.
     *
     * @param id request id carried by the response
     * @return the pending future, or {@code null} if no request with this id is pending
     */
    public static DefaultFuture getFuture(long id) {
        return FUTURES.get(id);
    }

}

2.4 处理服务端响应

服务端处理完请求后，会将请求 id 和响应结果一并返回；客户端根据 id 找到对应的异步结果，并把服务端的响应信息写入其中，这样 2.2 步骤中调用者的 get 方法就能获取到响应信息

/**
 * Client-side inbound handler: hands each server {@link Response} to the pending request future
 * registered under the same request id.
 */
public class RpcClientHandler extends SimpleChannelInboundHandler<Response> {

    private static final Logger logger = LoggerFactory.getLogger(RpcClientHandler.class);

    /**
     * Completes the pending future whose id matches the response.
     */
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, Response response) {
        DefaultFuture responseFuture = DefaultFuture.getFuture(response.getId());
        if (responseFuture == null) {
            // No pending request has this id. Throwing here would reach exceptionCaught below and
            // close the whole connection, so drop the stray response instead.
            logger.warn("No pending request for response id {} on channel {}, ignoring it",
                    response.getId(), ctx.channel());
            return;
        }
        // NOTE(review): the response status is not inspected; an error response (status/errorMessage set,
        // no result) completes the future with its null result. Confirm this is intended.
        responseFuture.complete(response.getResult());
    }

    /**
     * Forwards the exception down the pipeline, then closes the channel if it is still active.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        super.exceptionCaught(ctx, cause);
        Channel channel = ctx.channel();
        if (channel.isActive()) {
            ctx.close();
        }
    }

}

三、服务端响应

服务端根据请求的服务名称找到对应服务,通过反射调用方法,获取到方法返回结果后将请求id和返回结果响应给客户端

3.1 处理客户端请求

/**
 * Server-side inbound handler: runs each incoming {@link Request} through {@link RpcProtocol} and
 * writes back a {@link Response} carrying the same request id.
 *
 * <p>A plain ChannelInboundHandler must release reference-counted messages manually, otherwise
 * memory leaks. SimpleChannelInboundHandler releases the message via
 * ReferenceCountUtil.release(msg) after channelRead0 returns.
 * Example: https://juejin.cn/post/7224886077051551781
 *
 * <p>Marked {@code @Sharable}: the only instance state is the final {@code protocol} reference.
 */
@ChannelHandler.Sharable
public class RpcServerHandler extends SimpleChannelInboundHandler<Request> {

    private static final Logger logger = LoggerFactory.getLogger(RpcServerHandler.class);

    private final RpcProtocol protocol;

    public RpcServerHandler(RpcProtocol protocol) {
        this.protocol = protocol;
    }

    /**
     * Executes the invocation carried by the request and writes the outcome back on the same channel.
     * Requests whose data is not an {@link Invocation} are ignored.
     */
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, Request req) {
        Response res = new Response(req.getId());
        Channel channel = ctx.channel();
        Object msg = req.getData();
        if ((msg instanceof Invocation)) {
            Invocation invocation = (Invocation) msg;
            try {
                // Execute the service
                CompletableFuture<Object> future = protocol.reply(ctx, invocation);
                // Send the outcome back once it is available
                future.whenComplete((appResult, t) -> {
                    try {
                        if (t == null) {
                            res.setStatus(ResponseEnum.SUCCESS.getCode());
                            res.setResult(appResult);
                        } else {
                            res.setStatus(ResponseEnum.ASYNC_RESPONSE_ERROR.getCode());
                            res.setErrorMessage(errorMessageOf(t));
                        }
                        channel.writeAndFlush(res);
                    } catch (Exception e) {
                        logger.warn("Send result to consumer failed, channel is {}", channel, e);
                    }
                });
            } catch (Exception e) {
                res.setStatus(ResponseEnum.ASYNC_RESPONSE_ERROR.getCode());
                res.setErrorMessage(errorMessageOf(e));
                channel.writeAndFlush(res);
            }
        }
    }

    /**
     * Builds the error message reported to the client.
     *
     * <p>Exceptions thrown by the service method arrive wrapped in an InvocationTargetException
     * (whose own message is null), and async failures may be wrapped in a CompletionException, so the
     * wrappers are peeled off first. Falls back to {@code toString()} when the cause has no message.
     *
     * @param error the failure to describe
     * @return a non-null, human-readable description
     */
    private static String errorMessageOf(Throwable error) {
        Throwable root = error;
        while ((root instanceof java.lang.reflect.InvocationTargetException
                || root instanceof java.util.concurrent.CompletionException)
                && root.getCause() != null) {
            root = root.getCause();
        }
        String message = root.getMessage();
        return message != null ? message : root.toString();
    }
}

3.2 服务提供者

public class RpcProtocol {

    /**
     * Exported services. Key: service name; value: the object whose methods are invoked for that name.
     */
    protected final Map<String, Object> exporterMap = new ConcurrentHashMap<>();

    /**
     * Invokes the requested method on the exported service and wraps its return value in an
     * {@link AppResponse}.
     *
     * @param handlerContext context of the calling channel (not used by this implementation)
     * @param message        the invocation: service name, method name, parameter types and arguments
     * @return an already-completed future holding an {@link AppResponse} with the method's return value
     * @throws RouteException if no service is exported under the requested (or a null) name
     * @throws Exception      if the method cannot be found or invoked; an exception thrown by the service
     *                        method itself arrives wrapped in an InvocationTargetException
     */
    public CompletableFuture<Object> reply(ChannelHandlerContext handlerContext, Invocation message) throws Exception {
        String serviceName = message.getServiceName();
        // ConcurrentHashMap#get throws NPE for a null key; report a missing name as an unknown service instead
        Object targetService = serviceName == null ? null : exporterMap.get(serviceName);
        if (targetService == null) {
            throw new RouteException(ResponseEnum.PROVIDER_ERROR.getCode(), "Not found exported service: " + serviceName + " in " + exporterMap.keySet());
        }
        // Lookup and invocation failures are deliberately propagated to the outermost caller
        Method declaredMethod = targetService.getClass().getDeclaredMethod(message.getMethodName(), message.getParameterTypes());
        Object value = declaredMethod.invoke(targetService, message.getArguments());

        // Wrap the return value as an AppResponse
        AppResponse result = new AppResponse();
        result.setValue(value);
        return CompletableFuture.completedFuture(result);
    }

    /**
     * Exports a service object under the given name, replacing any object previously exported under it.
     *
     * @param serviceName name clients put in {@link Invocation#getServiceName()}
     * @param service     the object whose methods are invoked reflectively
     * @param <T>         type of the service object
     * @return the object previously exported under {@code serviceName}, or {@code null} if there was none
     * @throws NullPointerException if {@code serviceName} or {@code service} is null
     */
    @SuppressWarnings("unchecked") // the previous value is assumed to be of the same type as the new one
    public <T> T export(String serviceName, T service) {
        java.util.Objects.requireNonNull(serviceName, "serviceName");
        java.util.Objects.requireNonNull(service, "service");
        return (T) exporterMap.put(serviceName, service);
    }
}

四、测试

4.1 启动服务端并注册服务

public class RpcServerTest {

    private static final Logger logger = LoggerFactory.getLogger(RpcServerTest.class);

    /**
     * Starts an RPC server on port 7777 that exports a {@code RpcTestImpl} as "rpcTestImpl",
     * waits until its channel is active, then blocks forever so clients can connect.
     */
    @Test
    public void testServer() throws ExecutionException, InterruptedException {
        // Create the RPC protocol
        RpcProtocol rpcProtocol = new RpcProtocol();
        // Export the service
        rpcProtocol.export("rpcTestImpl", new RpcTestImpl());
        // Create the server
        RpcServer rpcServer = new RpcServer();
        // Start the server on a helper thread and wait for startApplication to return
        Callable<ChannelFuture> startTask = () -> rpcServer.startApplication(7777, rpcProtocol);
        java.util.concurrent.ExecutorService starter = Executors.newSingleThreadExecutor();
        ChannelFuture channelFuture;
        try {
            channelFuture = starter.submit(startTask).get();
        } finally {
            // The helper thread is only needed for this single call; release it
            starter.shutdown();
        }
        Channel channel = channelFuture.channel();
        if (channel == null) {
            throw new RuntimeException("netty server start error channel is null");
        }
        while (!channel.isActive()) {
            logger.info("NettyServer启动服务 ...");
            Thread.sleep(500);
        }
        logger.info("NettyServer启动服务完成 {}", channel.localAddress());

        // Block forever to keep the server running for clients
        CountDownLatch countDownLatch = new CountDownLatch(1);
        countDownLatch.await();
    }

}

4.2 客户端请求调用

public class RpcClientTest {

    private static final Logger logger = LoggerFactory.getLogger(RpcClientTest.class);

    /**
     * Calls a single-argument method of {@code RpcTest} on the service exported as "rpcTestImpl"
     * at 127.0.0.1:7777 and logs the returned value.
     */
    @Test
    public void testClient() throws Throwable {
        // Create the client
        RpcClient rpcClient = new RpcClient("127.0.0.1", 7777, 3);
        // Create the invoker
        RpcInvoker rpcInvoker = new RpcInvoker(rpcClient, 2000);

        // Class#getMethods() returns methods in no particular order, so pick the target by signature
        // (one parameter that accepts a String) instead of by array position
        Method method = java.util.Arrays.stream(RpcTest.class.getMethods())
                .filter(m -> m.getParameterCount() == 1 && m.getParameterTypes()[0].isAssignableFrom(String.class))
                .findFirst()
                .orElseThrow(() -> new IllegalStateException("RpcTest has no method accepting a single String"));
        RpcInvocation rpcInvocation = new RpcInvocation(
                "rpcTestImpl",
                method.getName(), method.getParameterTypes(),
                new String[]{"你好啊!!"}
        );
        // Perform the remote call
        Result result = rpcInvoker.invoke(rpcInvocation);
        logger.info("获取到的结果为:{}", result.getValue());

        // NOTE(review): this blocks the test forever even though invoke() above already waited for the
        // result; confirm whether it is still needed
        CountDownLatch countDownLatch = new CountDownLatch(1);
        countDownLatch.await();
    }


}

image.png