用netty实现一个rpc

405 阅读7分钟

        以前在做项目的时候接触到了rpc,感觉很有意思,但是那个框架用了rabbitmq来进行channel的管理。正好前几天看到了netty,一个高效的JAVA NIO框架,所以萌生了想自己写一个rpc框架的想法。

RPC原理简介

        RPC(Remote Procedure Call)指远程过程调用。它的作用大概可以这么描述一下:B 程序员 程序想要调用 A 程序的某个函数,但是由于 A 与 B 是两个独立的项目,B 不可能直接调用 A 中的任何一个类里的任何一个函数。这时 RPC 就能起到它的作用了。         为了完成 B 程序的需求, A 程序 对 B 程序进行规定,如果 B 想要调用 A 的方法,需要给 A 一个规定的数据格式,然后 A 在本地执行完 B 所想要使用的函数后将结果 封装成一个规定好的数据格式后发送给 B。这样 B 就达到了不拷贝 A 的代码的情况下完成其所需要的业务功能。

rpc调用流程
rpc调用流程

RPC具体实现

通信框架的选择

        一个rpc底层应该支持io/nio,这种实现方法大致有两种,一是通过代码完全有自己实现,但是这种方法对技术要求比较高,而且容易出现隐藏的bug,另一种就是利用现有的开源框架,Netty 是个不错的选择,它是一个利用 Java 的高级网络的能力,隐藏其背后的复杂性而提供一个易于使用的 API 的客户端/服务器框架。它能大大的简化我们的开发流程,使得代码更加牢靠。在这次的RPC中,我们使用 Netty 来作为连接客户端与服务端的桥梁。

数据的序列化与反序列化

        一个好的rpc不应该受到语言的限制,所以client端到server端的数据交换格式应该有一个良好的定义,比如json、xml。现在这方面成熟的框架有很多比如Thrift、Protobuf等等。我们不用自己去定义以及实现一个交换格式,这些成熟的框架都是久经考验的。在本例中由于是抱着学习的目的,本人采用java自带的序列化与反序列化方法。

服务的注册与发现

        client的端想要调用服务端的某个方法,需要得知这个方法的某些信息,而现在问题就来了,得知这个信息的时候是由写 A 程序的人去直接告诉 B 程序的人,还是由 B 程序主动去发现 A 的服务。很明显,第一种方法很不牢靠,若是采用这种方法的话,A服务的每次改动都要通知到 B ,B也要每次根据 A服务的改变,而重写自己的代码。相比之下,第二种方法更显得可行。         其实实现服务的注册与发现的方法也有很多,比如zookeeper,redis等等。大致原理就是,A 服务将自己暴露出的方法信息存在zookeeper或者redis上,每次更改由A主动通知或由 B 去zookeeper或redis上自动去拉取最新的信息。对于zookeeper来说存储方法信息的是一个个固定的节点,对于reids来说就是一个key值。用zookeeper还解决了在分布式的部署方案下,某个服务down机的问题。因为zookeeper与生俱来的容灾能力(比如leader选举),可以确保服务注册表的高可用性。在本例中,我并未实现服务的注册于发现。

client与server

        client端与server端有各自需要处理的发送格式与接受格式。对于client端来说需要封装好自己所要请求的方法信息发送给server端,并等待server端返回的结果。server端则是接收client的请求数据,处理完成后返回给client端结果数据。         其实RPC在调用的时候应该让调用者像调用本地服务一般的去完成业务逻辑。这种实现在java中就应该用代理来实现。

关键代码

数据交换格式定义

client请求格式 利用java自带的序列化方法要继承Serializable方法并且要实现无参构造方法。

package com.example.nettyrpcfirst.netty.entity;

import java.io.Serializable;

/**
 * @auther lichaobao
 * @date 2018/9/21
 * @QQ 1527563274
 */
public class Request implements Serializable {
    private static final long serialVersionUID  = -1L;
    private String clientId;
    private String className;
    private String methodName;
    private Class[] paramterTypes;
    private Object[] parameters;


    public Request(String clientId, String className, String methodName, Class[] paramterTypes, Object[] parameters) {
        this.clientId = clientId;
        this.className = className;
        this.methodName = methodName;
        this.paramterTypes = paramterTypes;
        this.parameters = parameters;
    }

    public Request() {
    }
//getter and setter
}

server 响应数据格式 具体要求与client端相同

package com.example.nettyrpcfirst.netty.entity;

import java.io.Serializable;

/**
 * @auther lichaobao
 * @date 2018/9/21
 * @QQ 1527563274
 */
public class Response implements Serializable {
    private static final long serialVersionUID = -1L;
    private String clientId;
    private Throwable err;
    private Object result;


    public Response() {
    }
// getter and setter
}

注意 clientId字段的设置是为了保证返回的数据是自己想要的。

server 实现

netty不熟悉的可以去官网写写几个例子

package com.example.nettyrpcfirst.netty.server;



/**
 * @auther lichaobao
 * @date 2018/9/21
 * @QQ 1527563274
 */
public class ServerHandler extends SimpleChannelInboundHandler {
    Logger logger = LoggerFactory.getLogger(ServerHandler.class);

    private final Map<String, Object> services;

    public ServerHandler(Map<String, Object> services) {
        this.services = services;
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        logger.info("{} has created a channel",ctx.channel().remoteAddress());
    }


    @Override
    protected void channelRead0(ChannelHandlerContext channelHandlerContext, Object o) throws Exception {
       Runnable r = () ->{
           Request request = (Request) o;
           Response response = new Response();
           response.setClientId(request.getClientId());
           try {
               Object service = services.get(request.getClassName());
               FastClass serviceFastClass = FastClass.create(service.getClass());
               FastMethod serviceFastMethod = serviceFastClass.getMethod(request.getMethodName(), request.getParamterTypes());
               Object result = serviceFastMethod.invoke(service, request.getParameters());
               response.setResult(result);
           }catch (Exception e){
               response.setErr(e);
           }
           channelHandlerContext.writeAndFlush(response).addListener(new ChannelFutureListener() {
               @Override
               public void operationComplete(ChannelFuture channelFuture) throws Exception {
                   logger.info("send response for request: "+request.getClientId());
               }
           });
       };
       Server.submit(r);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        cause.printStackTrace();
        logger.error("rpc server err occur" + cause.getMessage()+" | "+ctx.channel().remoteAddress());
        ctx.close();
    }
}

最主要的就是channelRead0方法,这里定义了在接收到客户端的数据后如何去调用本地方法,具体是用cglib代理完成。 server具体代码

package com.example.nettyrpcfirst.netty.server;



/**
 * @auther lichaobao
 * @date 2018/9/21
 * @QQ 1527563274
 */
@Configuration
public class Server implements BeanNameAware, BeanFactoryAware, ApplicationContextAware,InitializingBean {
    private static final org.slf4j.Logger logger = LoggerFactory.getLogger(Server.class);

    private Map<String,Object> services = new ConcurrentHashMap<>();

    private static ExecutorService threadPoolExecutor;

    public Server(){
    }

    /**
     * 启动netty server
     * @throws Exception
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        logger.info("afterPropertiesSet");
        EventLoopGroup bossGroup = new NioEventLoopGroup();
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        try{
            ServerBootstrap b = new ServerBootstrap();
            b.group(bossGroup,workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel socketChannel) throws Exception {
                            socketChannel.pipeline()
                                    .addLast(new ObjectDecoder(1024,ClassResolvers.cacheDisabled(this.getClass().getClassLoader())))
                                    .addLast(new ObjectEncoder())
                                    .addLast(new ServerHandler(services));
                        }
                    })
                    .option(ChannelOption.SO_BACKLOG,128)
                    .childOption(ChannelOption.SO_KEEPALIVE,true);
            ChannelFuture future = b.bind(8080).sync();
            future.channel().closeFuture().sync();
        }catch (Exception e){
            e.printStackTrace();
            logger.error("an error occur ----------->"+e.getMessage());
        }finally {
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }

    /**
     * 通过扫描所有带有@RPCServer注解的类进行注册
     * @param applicationContext
     * @throws BeansException
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
            logger.info("setApplicationContext");
            Map<String,Object> serviceBeanMap = applicationContext.getBeansWithAnnotation(RPCServer.class);
            if(!serviceBeanMap.isEmpty()){
                for (Object service :serviceBeanMap.values()){
                    String interfaceName = service.getClass().getAnnotation(RPCServer.class).value().getName();
                    logger.info("RPCService:  {}" , interfaceName);
                    this.services.put(interfaceName,service);
                }
            }
    }
    public static void submit(Runnable task){
        if(threadPoolExecutor == null){
            synchronized (RPCServer.class){
                if(threadPoolExecutor == null){
                    threadPoolExecutor = Executors.newFixedThreadPool(16);
                }
            }
        }
        threadPoolExecutor.submit(task);
    }

    @Override
    public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
        logger.info("setBeanFactory()");
    }

    @Override
    public void setBeanName(String s) {
        logger.info("setBeanName() {}", s);
    }
}

RPCServer 自定义注解

package com.example.nettyrpcfirst.netty.annoations;


/**
 * @auther lichaobao
 * @date 2018/9/21
 * @QQ 1527563274
 */
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
@Component
public @interface RPCServer {
    Class<?> value();
}

client实现

client代理类

package com.example.nettyrpcfirst.netty.client;


/**
 * @auther lichaobao
 * @date 2018/9/21
 * @QQ 1527563274
 */
public class ClientProxy {
    @SuppressWarnings("unchecked")
    public <T> T create(Class<?> interfaceClass){

        return (T)Proxy.newProxyInstance(interfaceClass.getClassLoader(),
                new Class<?>[]{interfaceClass},
                new InvocationHandler() {
                    @Override
                    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                        Request request = new Request();
                        request.setClientId(UUID.randomUUID().toString());
                        request.setClassName(method.getDeclaringClass().getName());
                        request.setMethodName(method.getName());
                        request.setParamterTypes(method.getParameterTypes());
                        request.setParameters(args);
                        Client client = new Client("127.0.0.1",8080);
                        Response response = client.send(request);
                        if(response.getErr()!=null){
                            throw response.getErr();
                        }else{
                            return response.getResult();
                        }
                    }
                });
    }
}

client 与server 连接发送数据并等待数据返回

package com.example.nettyrpcfirst.netty.client;



/**
 * @auther lichaobao
 * @date 2018/9/21
 * @QQ 1527563274
 */
public class Client extends SimpleChannelInboundHandler<Response> {
    private static org.slf4j.Logger logger = LoggerFactory.getLogger(Client.class);

    private Response response;
    private final static Object obj = new Object();
    private String host;
    private int port;
    ChannelFuture future;
    public Client(String host,int port){
        this.host = host;
        this.port = port;
    }

    /**
     * 接收到消息后唤醒线程
     * @param channelHandlerContext
     * @param response
     * @throws Exception
     */

    @Override
    protected void channelRead0(ChannelHandlerContext channelHandlerContext, Response response) throws Exception {
        this.response = response;
        synchronized (obj){
            obj.notifyAll();
        }
    }
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        logger.error("client caught exception", cause);
        ctx.close();
    }

    /**
     * 连接server端channel,发送完数据后锁定线程,等待数据返回
     * @param request
     * @return
     * @throws Exception
     */
    public Response send(Request request) throws Exception{
        EventLoopGroup eventLoopGroup = new NioEventLoopGroup();
        try{
            Bootstrap b = new Bootstrap();
            b.group(eventLoopGroup)
                    .channel(NioSocketChannel.class)
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel socketChannel) throws Exception {
                            socketChannel.pipeline()
                                    .addLast(new ObjectDecoder(1024,ClassResolvers.cacheDisabled(this.getClass().getClassLoader())))
                                    .addLast(new ObjectEncoder())
                                    .addLast(Client.this);
                        }
                    })
                    .option(ChannelOption.SO_KEEPALIVE,true);
            ChannelFuture future = b.connect(host,port).sync();
            future.channel().writeAndFlush(request).sync();
            System.out.println("2   "+Thread.currentThread().getName());
            synchronized (obj){
                System.out.println("1111111111111111");
                obj.wait();
            }
            if(response != null){
                System.out.println("3333333333333");
            }
            return response;
        }finally {
            if(future!=null){
                future.channel().closeFuture().sync();
            }
            eventLoopGroup.shutdownGracefully();
        }
    }
}

测试

package com.example.nettyrpcfirst.netty.client;

import com.example.nettyrpcfirst.netty.entity.TestService;

/**
 * @auther lichaobao
 * @date 2018/9/21
 * @QQ 1527563274
 */
public class TestMain {
    public static void main(String[] args){
        TestService testService = new ClientProxy().create(TestService.class);
        String result = testService.play();
        System.out.println("收到消息 ------------> "+result);
    }
}

不出意外的话 控制台会成功打印