手撸一个rpc框架(实现基本的socket通信)

107 阅读5分钟

创建客户端类

要想实现一个rpc客户端,它应该实现什么方法,具备什么功能。 回到这张图

image.png

1 客户端需要去连接服务端

2 客户端需要能够发送客户端句柄

3 客户端需要接受服务端的句柄

按照这个思路,需要在Rpc_client包下创建一个接口Client,实现上述的三个方法。在Rpc_common包下定义客户端的句柄和服务端的句柄,分别是类Res和类Invocation,这两个对象代表服务端返回和客户端传输类。

Client接口

package com.suancaiyu.rpc_client.protocal;

import com.suancaiyu.rpc_common.common.model.URL;
import com.suancaiyu.rpc_common.common.model.correspondence.Base;

import java.io.IOException;

public interface Client {
    //连接客户端
    public void connect(URL url);
    //接受服务端句柄
    public Object receive() throws IOException, ClassNotFoundException;
    //发送客户端句柄
    public <T extends Base> Object send(URL url, T message) throws IOException;
}

Res类

package com.suancaiyu.rpc_common.common.model;

import java.io.Serializable;

public class Res implements Serializable {
    private static final long serialVersionUID=7340619444158003434L;

    private Throwable throwable;
    private Object resData;

    public Res(Throwable throwable, Object resData) {
        this.throwable = throwable;
        this.resData = resData;
    }

    @Override
    public String toString() {
        return "Res{" +
                "throwable=" + throwable +
                ", resData=" + resData +
                '}';
    }

    public Throwable getThrowable() {
        return throwable;
    }

    public void setThrowable(Throwable throwable) {
        this.throwable = throwable;
    }

    public Object getResData() {
        return resData;
    }

    public void setResData(Object resData) {
        this.resData = resData;
    }

    public Res(){

    }
}

Invocation类

package com.suancaiyu.rpc_common.common.model.correspondence;

import java.io.Serializable;
import java.util.Arrays;

public class Invocation extends Base implements Serializable  {

    private String interfaceName;
    private String methodName;
    private String version;

    public String getVersion() {
        return version;
    }

    public void setVersion(String version) {
        this.version = version;
    }

    @Override
    public String toString() {
        return "Invocation{" +
                "interfaceName='" + interfaceName + ''' +
                ", methodName='" + methodName + ''' +
                ", version='" + version + ''' +
                ", parameterTypes=" + Arrays.toString(parameterTypes) +
                ", parameters=" + Arrays.toString(parameters) +
                '}';
    }


    private Class[] parameterTypes;//参数类型
    private Object[] parameters;

    public String getInterfaceName() {
        return interfaceName;
    }

    public void setInterfaceName(String interfaceName) {
        this.interfaceName = interfaceName;
    }

    public String getMethodName() {
        return methodName;
    }

    public void setMethodName(String methodName) {
        this.methodName = methodName;
    }

    public Class[] getParameterTypes() {
        return parameterTypes;
    }

    public void setParameterTypes(Class[] parameterTypes) {
        this.parameterTypes = parameterTypes;
    }

    public Object[] getParameters() {
        return parameters;
    }

    public void setParameters(Object[] parameters) {
        this.parameters = parameters;
    }

    public Invocation(String interfaceName, String methodName, Class[] parameterTypes, Object[] parameters,String type,String version) {
        super(type);
        this.interfaceName = interfaceName;
        this.methodName = methodName;
        this.parameterTypes = parameterTypes;
        this.parameters = parameters;
        this.version=version;
    }
}

创建Client实现类

Rpc_client

package com.suancaiyu.rpc_client.protocal;

import com.suancaiyu.rpc_common.common.model.Res;
import com.suancaiyu.rpc_common.common.model.URL;
import com.suancaiyu.rpc_common.common.model.correspondence.Base;
import com.suancaiyu.rpc_common.exception.ConnectException;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.Socket;

public class RpcClient implements Client {



    private static Socket socket=null;
    private static ObjectOutputStream oos=null;
    private static ObjectInputStream ois=null;
    /**
     * 使用Java序列化方式
     * @param url
     * @param message
     * @return
     */
    // TODO: 2023/6/6 优化好异常处理
    @Override
    public <T extends Base> Object send(URL url, T message) throws ConnectException, IOException {

        connect(url);
        //获取输出流写入代码
        try{
            oos=new ObjectOutputStream(socket.getOutputStream());
            oos.writeObject(message);
            oos.flush();
            //获取输入流
            return receive();
        }catch (IOException e){
            throw new ConnectException(e.getMessage());
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
        finally {
            socket.close();
            ois.close();
            oos.close();
        }
    
    }
    @Override
    public void connect(URL url) {
        try {
            socket=new Socket(url.getHostname(),url.getPort());
        } catch (IOException e) {
            throw new ConnectException(e.getMessage());
        }
    }
    @Override
    public Object receive() throws IOException, ClassNotFoundException {
        ois=new ObjectInputStream(socket.getInputStream());
         Res res = (Res) ois.readObject();
         return res.getResData();
    }

}

ConnectException异常类是我自定义的异常类,这里先不贴出(还没把异常处理完善好)

服务端句柄和客户端句柄都必须实现Serializable接口,以便于序列化。因为客户端句柄分很多种类,比如服务发现使用一种句柄,服务调用使用一种句柄,我使用了一个父类(Base)来定义服务端句柄类,Invocation类是它的子类,代表服务调用句柄类。连接那个服务端需要服务端的地址和端口,封装为Url类

Base

package com.suancaiyu.rpc_common.common.model.correspondence;

import java.io.Serializable;

public class Base implements Serializable {
    private String type;

    public Base(String type) {
        this.type=type;
    }

    public String getType() {
        return type;
    }

    @Override
    public String toString() {
        return "Base{" +
                "type='" + type + ''' +
                '}';
    }

    public void setType(String type) {
        this.type = type;
    }

    public Base(){

    }
}

Url

package com.suancaiyu.rpc_common.common.model;

import java.io.Serializable;

public class URL implements Serializable {
    private String hostname;
    private Integer port;

    public URL(String hostname, Integer port) {
        this.hostname = hostname;
        this.port = port;
    }

    public String getHostname() {
        return hostname;
    }

    public void setHostname(String hostname) {
        this.hostname = hostname;
    }

    public Integer getPort() {
        return port;
    }

    public void setPort(Integer port) {
        this.port = port;
    }
}

实现Rpc_client类需要知道socket编程基础和Java序列化方式。下面代码我会实现一个简单的socket连接。

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.Socket;

public class SocketExample {
    public static void main(String[] args) {
        String hostname = "localhost"; // 连接的主机名
        int port = 8080; // 连接的端口号

        try {
            // 创建socket对象并连接到远程主机
            Socket socket = new Socket(hostname, port);

            // 获取对象输入流和对象输出流
               ObjectOutputStream outputStream = new ObjectOutputStream(socket.getOutputStream());
            ObjectInputStream inputStream = new ObjectInputStream(socket.getInputStream());

            // 发送对象到服务器
            Person person = new Person("John", 25);
            outputStream.writeObject(person);

            // 从服务器接收对象
            Person response = (Person) inputStream.readObject();
            System.out.println("服务器响应: " + response);

            // 关闭连接
            socket.close();
        } catch (IOException | ClassNotFoundException e) {
            e.printStackTrace();
        }
    }
}

实现服务端

一样的,要想实现服务端,它需要去实现什么功能

1 监听指定端口,绑定ServerSocket对象,开启服务。

2 接受客户端句柄,返回调用服务端句柄。

这里我就不创建接口来规范实现类了。

package com.suancaiyu.rpc_server.protocal;

import com.suancaiyu.rpc_common.common.model.BusinessName;
import com.suancaiyu.rpc_common.common.model.Res;
import com.suancaiyu.rpc_common.common.model.correspondence.Base;
import com.suancaiyu.rpc_common.common.model.correspondence.Invocation;
import com.suancaiyu.rpc_common.reader.YamlReader;
import com.suancaiyu.rpc_server.register.OnlineRegister;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.logging.Logger;

public class RpcServer {
    private static ServerSocket serverSocket=null;
    private static ObjectOutputStream oos=null;
    private static ObjectInputStream ois=null;
    private static Socket socket=null;

    private static final int port= Integer.parseInt(YamlReader.getRpcConfig().getPort());

    private final ExecutorService executorService = Executors.newFixedThreadPool(20); // 创建一个固定大小为10的线程池

    private final Logger logger=Logger.getLogger(RpcServer.class.getName());
    public void start() throws IOException {
        try {
            try {
                //注册服务进注册中心
                startRegister();
                //开启服务
                startHttpServe();
                 //死循环用于反复接受客户端句柄
                while (true){
                    socket = serverSocket.accept();
                    logger.info("客户端已经连接");
                    //接收到句柄后使用线程池去处理句柄并且将返回句柄写入输入流返回
                    executorService.execute(()->{
                        try {
                            Handler();
                        } catch (IOException | ClassNotFoundException | InvocationTargetException |
                                 NoSuchMethodException | InstantiationException | IllegalAccessException e) {
                            throw new RuntimeException(e);
                        }
                    });
                }
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            } finally {
                executorService.shutdown();
                ois.close();
                oos.close();
                serverSocket.close();
                socket.close();
            }

        } catch (Exception e) {
            logger.info(e.getMessage());
            throw new IOException(e.getMessage());
        }
    }


    private void startHttpServe() throws IOException {
        serverSocket=new ServerSocket(port);
    }
    private void startRegister(){
        OnlineRegister.scanService();
    }

    private void Handler() throws IOException, ClassNotFoundException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException {

        ois=new ObjectInputStream(socket.getInputStream());
        Base base = (Base) ois.readObject();
        System.out.println(base.getType());
        String type = base.getType();
        if (type.equals(BusinessName.INVOKESERVICE)){
            InvokeHandler(base);
        }
        if (type.equals(BusinessName.DISCOVERYSERVICE)){
            DiscoveryHandler(base);
        }
    }

    private void DiscoveryHandler(Base base) {
    }

    private void InvokeHandler(Base base) throws NoSuchMethodException, InstantiationException, IllegalAccessException, InvocationTargetException, IOException, ClassNotFoundException {
        Invocation invocation=(Invocation) base;
        String interfaceName = invocation.getInterfaceName();
        Class<?> service = OnlineRegister.get(interfaceName, invocation.getVersion());
        Method serviceMethod = service.getMethod(invocation.getMethodName(), invocation.getParameterTypes());
        Object serviceRes = serviceMethod.invoke(service.newInstance(), invocation.getParameters());
        logger.info("返回结果为"+serviceRes);
        Res res=new Res(null,serviceRes);
        oos=new ObjectOutputStream(socket.getOutputStream());
        oos.writeObject(res);
        oos.flush();
    }
}

因为我现在手上的代码是最后完成的版本,包括了服务注册功能和线程池处理句柄还有读取用户配置的操作,大家听个思路即可,主要操作都在start方法内部。

1 注册服务进注册中心

2 监听端口,开启服务

3 死循环反复接受客户端句柄

4 接收到句柄后使用线程池处理句柄并且写入输入流对象并且返回

总结

具体的调用业务流程就是

1 客户端建立socket连接。

2 客户端封装并且序列化Invocation对象,然后传输到服务端。

3 服务端开启服务。

4 服务端接收到句柄后反序列化成Invocation对象,解析这个序列化对象,知道了客户端需要调用那个方法。

5 服务端处理完这次请求,得到服务端句柄。

6 服务端序列化服务端句柄,并且传输到客户端。

7 客户端反序列化句柄,得到调用结果。

---未完待续

完整代码地址 派大星的海绵裤/rpc_suancaiyu (gitee.com)