Springboot整合Netty实现RPC服务器

1,962 阅读6分钟

Springboot整合Netty实现RPC服务器

一、什么是RPC?

RPC(Remote Procedure Call)远程过程调用,是一种进程间的通信方式,其可以做到像调用本地方法那样调用位于远程的计算机的服务。其实现的原理过程如下:

  • 本地的进程通过接口进行本地方法调用。
  • RPC客户端将调用的接口名、接口方法、方法参数等信息利用网络通信发送给RPC服务器。
  • RPC服务器对请求进行解析,根据接口名、接口方法、方法参数等信息找到对应的方法实现,并进行本地方法调用,然后将方法调用结果响应给RPC客户端。

二、实现RPC需要解决那些问题?

1. 约定通信协议格式

RPC分为客户端与服务端,就像HTTP一样,我们需要定义交互的协议格式。主要包括三个方面:

  • 请求格式
  • 响应格式
  • 网络通信时数据的序列化方式

RPC请求

@Data
public class RpcRequest {
    /**
     * 请求ID 用来标识本次请求以匹配RPC服务器的响应
     */
    private String requestId;
    /**
     * 调用的类(接口)权限定名称
     */
    private String className;
    /**
     * 调用的方法名
     */
    private String methodName;
    /**
     * 方法参类型列表
     */
    private Class<?>[] parameterTypes;
    /**
     * 方法参数
     */
    private Object[] parameters;
}

RPC响应

@Data
public class RpcResponse {
    /**
     * 响应对应的请求ID
     */
    private String requestId;
    /**
     * 调用是否成功的标识
     */
    private boolean success = true;
    /**
     * 调用错误信息
     */
    private String errorMessage;
    /**
     * 调用结果
     */
    private Object result;
}

2. 序列化方式

序列化方式可以使用JDK自带的序列化方式或者一些第三方的序列化方式,JDK自带的由于性能较差所以不推荐。我们这里选择JSON作为序列化协议,即将请求和响应对象序列化为JSON字符串后发送到对端,对端接收到后反序列为相应的对象,这里采用阿里的 fastjson 作为JSON序列化框架。

3. TCP粘包、拆包

TCP是个“流”协议,所谓流,就是没有界限的一串数据。大家可以想想河里的流水,是连成一片的,其间并没有分界线。TCP底层并不了解上层业务数据的具体含义,它会根据TCP缓冲区的实际情况进行包的划分,所以在业务上认为,一个完整的包可能会被TCP拆分成多个包进行发送,也有可能把多个小的包封装成一个大的数据包发送,这就是所谓的TCP粘包和拆包问题。粘包和拆包需要应用层程序来解决。

我们采用在请求和响应的头部保存消息体的长度的方式解决粘包和拆包问题。请求和响应的格式如下:

 +--------+----------------+
 | Length |  Content       |
 |  4字节  |  Length个字节   |
 +--------+----------------+

4. 网络通信框架的选择

出于性能的考虑,RPC一般选择异步非阻塞的网络通信方式,JDK自带的NIO网络编程操作繁杂,Netty是一款基于NIO开发的网络通信框架,其对java NIO进行封装对外提供友好的API,并且内置了很多开箱即用的组件,如各种编码解码器。所以我们采用Netty作为RPC服务的网络通信框架。

三、RPC服务端

RPC分为客户端和服务端,它们有一个共同的服务接口API,我们首先定义一个接口 HelloService

public interface HelloService {
    String sayHello(String name);
}

然后服务端需要提供该接口的实现类,然后使用自定义的@RpcService注解标注,该注解扩展自@Component,被其标注的类可以被Spring的容器管理。

@RpcService
public class HelloServiceImp implements HelloService {
    @Override
    public String sayHello(String name) {
        return "Hello " + name;
    }
}
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Component
public @interface RpcService {
    
}

RPC服务器类

我们实现了ApplicationContextAware接口,以便从bean容器中取出@RpcService实现类,存入我们的map容器中。

@Component
@Slf4j
public class RpcServer implements ApplicationContextAware, InitializingBean {
    // RPC服务实现容器
    private Map<String, Object> rpcServices = new HashMap<>();
    @Value("${rpc.server.port}")
    private int port;

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        Map<String, Object> services = applicationContext.getBeansWithAnnotation(RpcService.class);
        for (Map.Entry<String, Object> entry : services.entrySet()) {
            Object bean = entry.getValue();
            Class<?>[] interfaces = bean.getClass().getInterfaces();
            for (Class<?> inter : interfaces) {
                rpcServices.put(inter.getName(),  bean);
            }
        }
        log.info("加载RPC服务数量:{}", rpcServices.size());
    }

    @Override
    public void afterPropertiesSet() {
        start();
    }

    private void start(){
        new Thread(() -> {
            EventLoopGroup boss = new NioEventLoopGroup(1);
            EventLoopGroup worker = new NioEventLoopGroup();
            try {
                ServerBootstrap bootstrap = new ServerBootstrap();
                bootstrap.group(boss, worker)
                        .childHandler(new ChannelInitializer<SocketChannel>() {
                            @Override
                            protected void initChannel(SocketChannel ch) throws Exception {
                                ChannelPipeline pipeline = ch.pipeline();
                                pipeline.addLast(new IdleStateHandler(0, 0, 60));
                                pipeline.addLast(new JsonDecoder());
                                pipeline.addLast(new JsonEncoder());
                                pipeline.addLast(new RpcInboundHandler(rpcServices));
                            }
                        })
                        .channel(NioServerSocketChannel.class);
                ChannelFuture future = bootstrap.bind(port).sync();
                log.info("RPC 服务器启动, 监听端口:" + port);
                future.channel().closeFuture().sync();
            }catch (Exception e){
                e.printStackTrace();
                boss.shutdownGracefully();
                worker.shutdownGracefully();
            }
        }).start();

    }
}

RpcServerInboundHandler 负责处理RPC请求

@Slf4j
public class RpcServerInboundHandler extends ChannelInboundHandlerAdapter {
    private Map<String, Object> rpcServices;

    public RpcServerInboundHandler(Map<String, Object> rpcServices){
        this.rpcServices = rpcServices;
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        log.info("客户端连接成功,{}", ctx.channel().remoteAddress());
    }

    public void channelInactive(ChannelHandlerContext ctx)   {
        log.info("客户端断开连接,{}", ctx.channel().remoteAddress());
        ctx.channel().close();
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg){
        RpcRequest rpcRequest = (RpcRequest) msg;
        log.info("接收到客户端请求, 请求接口:{}, 请求方法:{}", rpcRequest.getClassName(), rpcRequest.getMethodName());
        RpcResponse response = new RpcResponse();
        response.setRequestId(rpcRequest.getRequestId());
        Object result = null;
        try {
            result = this.handleRequest(rpcRequest);
            response.setResult(result);
        } catch (Exception e) {
            e.printStackTrace();
            response.setSuccess(false);
            response.setErrorMessage(e.getMessage());
        }
        log.info("服务器响应:{}", response);
        ctx.writeAndFlush(response);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        log.info("连接异常");
        ctx.channel().close();
        super.exceptionCaught(ctx, cause);
    }

    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof IdleStateEvent){
            IdleStateEvent event = (IdleStateEvent)evt;
            if (event.state()== IdleState.ALL_IDLE){
                log.info("客户端已超过60秒未读写数据, 关闭连接.{}",ctx.channel().remoteAddress());
                ctx.channel().close();
            }
        }else{
            super.userEventTriggered(ctx,evt);
        }
    }

    private Object handleRequest(RpcRequest rpcRequest) throws Exception{
        Object bean = rpcServices.get(rpcRequest.getClassName());
        if(bean == null){
            throw new RuntimeException("未找到对应的服务: " + rpcRequest.getClassName());
        }
        Method method = bean.getClass().getMethod(rpcRequest.getMethodName(), rpcRequest.getParameterTypes());
        method.setAccessible(true);
        return method.invoke(bean, rpcRequest.getParameters());
    }
}

四、RPC客户端

/**
 * RPC远程调用的客户端
 */
@Slf4j
@Component
public class RpcClient {
    @Value("${rpc.remote.ip}")
    private String remoteIp;

    @Value("${rpc.remote.port}")
    private int port;

    private Bootstrap bootstrap;

    // 储存调用结果
    private final Map<String, SynchronousQueue<RpcResponse>> results = new ConcurrentHashMap<>();

    public RpcClient(){

    }

    @PostConstruct
    public void init(){
        bootstrap = new Bootstrap().remoteAddress(remoteIp, port);
        NioEventLoopGroup worker = new NioEventLoopGroup(1);
        bootstrap.group(worker)
                .channel(NioSocketChannel.class)
                .handler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    protected void initChannel(SocketChannel channel) throws Exception {
                        ChannelPipeline pipeline = channel.pipeline();
                        pipeline.addLast(new IdleStateHandler(0, 0, 10));
                        pipeline.addLast(new JsonEncoder());
                        pipeline.addLast(new JsonDecoder());
                        pipeline.addLast(new RpcClientInboundHandler(results));
                    }
                });
    }

    public RpcResponse send(RpcRequest rpcRequest) {
        RpcResponse rpcResponse = null;
        rpcRequest.setRequestId(UUID.randomUUID().toString());
        Channel channel = null;
        try {
            channel = bootstrap.connect().sync().channel();
            log.info("连接建立, 发送请求:{}", rpcRequest);
            channel.writeAndFlush(rpcRequest);
            SynchronousQueue<RpcResponse> queue = new SynchronousQueue<>();
            results.put(rpcRequest.getRequestId(), queue);
            // 阻塞等待获取响应
            rpcResponse = queue.take();
            results.remove(rpcRequest.getRequestId());
        } catch (InterruptedException e) {
            e.printStackTrace();
        } finally {
            if(channel != null && channel.isActive()){
                channel.close();
            }
        }
        return rpcResponse;
    }
}

RpcClientInboundHandler负责处理服务端的响应

@Slf4j
public class RpcClientInboundHandler extends ChannelInboundHandlerAdapter {
    private Map<String, SynchronousQueue<RpcResponse>> results;

    public RpcClientInboundHandler(Map<String, SynchronousQueue<RpcResponse>> results){
        this.results = results;
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        RpcResponse rpcResponse = (RpcResponse) msg;
        log.info("收到服务器响应:{}", rpcResponse);
        if(!rpcResponse.isSuccess()){
            throw new RuntimeException("调用结果异常,异常信息:" + rpcResponse.getErrorMessage());
        }
        // 取出结果容器,将response放进queue中
        SynchronousQueue<RpcResponse> queue = results.get(rpcResponse.getRequestId());
        queue.put(rpcResponse);
    }

    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof IdleStateEvent){
            IdleStateEvent event = (IdleStateEvent)evt;
            if (event.state() == IdleState.ALL_IDLE){
                log.info("发送心跳包");
                RpcRequest request = new RpcRequest();
                request.setMethodName("heartBeat");
                ctx.channel().writeAndFlush(request);
            }
        }else{
            super.userEventTriggered(ctx, evt);
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause){
        log.info("异常:{}", cause.getMessage());
        ctx.channel().close();
    }
}

接口代理

为了使客户端像调用本地方法一样调用远程服务,我们需要对接口进行动态代理。

代理类实现

@Component
public class RpcProxy implements InvocationHandler {

    @Autowired
    private RpcClient rpcClient;

    @Override
    public Object invoke(Object proxy, Method method, Object[] args){
        RpcRequest rpcRequest = new RpcRequest();
        rpcRequest.setClassName(method.getDeclaringClass().getName());
        rpcRequest.setMethodName(method.getName());
        rpcRequest.setParameters(args);
        rpcRequest.setParameterTypes(method.getParameterTypes());

        RpcResponse rpcResponse = rpcClient.send(rpcRequest);
        return rpcResponse.getResult();
    }
}

实现FactoryBean接口,将生产动态代理类纳入 Spring 容器管理。

public class RpcFactoryBean<T> implements FactoryBean<T> {
    private Class<T> interfaceClass;

    @Autowired
    private RpcProxy rpcProxy;

    public RpcFactoryBean(Class<T> interfaceClass){
        this.interfaceClass = interfaceClass;
    }

    @Override
    public T getObject(){
        return (T) Proxy.newProxyInstance(interfaceClass.getClassLoader(), new Class[]{interfaceClass}, rpcProxy);
    }

    @Override
    public Class<?> getObjectType() {
        return interfaceClass;
    }
}

自定义类路径扫描器,扫描包下的RPC接口,动态生产代理类,纳入 Spring 容器管理

public class RpcScanner extends ClassPathBeanDefinitionScanner {

    public RpcScanner(BeanDefinitionRegistry registry) {
        super(registry);
    }

    @Override
    protected Set<BeanDefinitionHolder> doScan(String... basePackages) {
        Set<BeanDefinitionHolder> beanDefinitionHolders = super.doScan(basePackages);
        for (BeanDefinitionHolder beanDefinitionHolder : beanDefinitionHolders) {
            GenericBeanDefinition beanDefinition = (GenericBeanDefinition)beanDefinitionHolder.getBeanDefinition();
            beanDefinition.getConstructorArgumentValues().addGenericArgumentValue(beanDefinition.getBeanClassName());
            beanDefinition.setBeanClassName(RpcFactoryBean.class.getName());
        }
        return beanDefinitionHolders;
    }

    @Override
    protected boolean isCandidateComponent(MetadataReader metadataReader) throws IOException {
        return true;
    }

    @Override
    protected boolean isCandidateComponent(AnnotatedBeanDefinition beanDefinition) {
        return beanDefinition.getMetadata().isInterface() && beanDefinition.getMetadata().isIndependent();
    }
}
@Component
public class RpcBeanDefinitionRegistryPostProcessor implements BeanDefinitionRegistryPostProcessor {
    @Override
    public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
        RpcScanner rpcScanner = new RpcScanner(registry);
        // 传入RPC接口所在的包名
        rpcScanner.scan("com.ygd.rpc.common.service");
    }

    @Override
    public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
    }
}

JSON编解码器

/**
 * 将 RpcRequest 编码成字节序列发送
 * 消息格式: Length + Content
 * Length使用int存储,标识消息体的长度
 *
 * +--------+----------------+
 * | Length |  Content       |
 * |  4字节 |   Length个字节  |
 * +--------+----------------+
 */
public class JsonEncoder extends MessageToByteEncoder<RpcRequest> {
    @Override
    protected void encode(ChannelHandlerContext ctx, RpcRequest rpcRequest, ByteBuf out){
        byte[] bytes = JSON.toJSONBytes(rpcRequest);
        // 将消息体的长度写入消息头部
        out.writeInt(bytes.length);
        // 写入消息体
        out.writeBytes(bytes);
    }
}
/**
 * 将响应消息解码成 RpcResponse
 */
public class JsonDecoder extends LengthFieldBasedFrameDecoder {

    public JsonDecoder(){
        super(Integer.MAX_VALUE, 0, 4, 0, 4);
    }

    @Override
    protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
        ByteBuf msg = (ByteBuf) super.decode(ctx, in);
        byte[] bytes = new byte[msg.readableBytes()];
        msg.readBytes(bytes);
        RpcResponse rpcResponse = JSON.parseObject(bytes, RpcResponse.class);
        return rpcResponse;
    }
}

测试

我们编写一个Controller进行测试

@RestController
@RequestMapping("/hello")
public class HelloController {
    @Autowired
    private HelloService helloService;
    @GetMapping("/sayHello")
    public String hello(String name){
        return helloService.sayHello(name);
    }
}

通过 PostMan调用 controller 接口 http://localhost:9998/hello/sayHello?name=小明

响应: Hello 小明

总结

本文实现了一个简易的、具有基本概念的RPC,主要涉及的知识点如下:

  • 网络通信及通信协议的编码、解码
  • Java对象的序列化及反序列化
  • 通信链路心跳检测
  • Java反射
  • JDK动态代理

项目完整代码详见:github.com/yinguodong/…