基于Java NIO实现一个RPC框架

428 阅读4分钟

RPC(Remote Procedure Call,远程过程调用)框架是一种用于实现分布式应用的技术。它允许应用程序通过网络调用远程计算机上的服务,就好像这些服务是本地的一样。RPC 框架主要包括以下几个组件:

  1. 通信协议:定义客户端和服务端之间通信的格式和规范,包括数据格式、传输协议、序列化方式等。

  2. 服务注册:负责将服务注册到注册中心,以便客户端可以发现和调用服务。

  3. 序列化和反序列化:将对象转换为二进制数据或其他格式的数据,以便在网络上传输。

  4. 负载均衡:根据一定的负载均衡策略,将客户端请求分散到多个服务提供者上,以实现高可用和高并发。

  5. 远程调用:通过网络调用远程服务,实现应用程序的分布式部署。

  6. 容错处理:当服务提供者出现故障或网络异常时,能够进行容错处理,保证服务的可用性和稳定性。

常见的 RPC 框架包括 Dubbo、gRPC、Thrift 等。这些框架都提供了完整的 RPC 解决方案,可以帮助开发者快速构建分布式应用,并提供了丰富的功能和扩展点,以满足不同应用场景的需求。

下面是一个简单的基于 Java NIO 实现的 RPC 框架示例代码:

User实体:

public class User {
    private int id;
    private String name;
    public User(int id, String name) {
        this.id = id;
        this.name = name;
    }
    public int getId() {
        return id;
    }
    public String getName() {
        return name;
    }
    @Override
    public String toString() {
        return "User{" +
                "id=" + id +
                ", name='" + name + ''' +
                '}';
    }
}

服务接口:

public interface UserService {
    User getUser(int id);
    List<User> getAllUsers();
}
public interface Service {
    Response callMethod(String methodName, String[] args);
}

服务实现:

public class UserServiceImpl implements UserService, Service {
    private List<User> userList;
    public UserServiceImpl() {
        userList = new ArrayList<>();
        userList.add(new User(1, "Alice"));
        userList.add(new User(2, "Bob"));
        userList.add(new User(3, "Charlie"));
    }
    @Override
    public User getUser(int id) {
        for (User user : userList) {
            if (user.getId() == id) {
                return user;
            }
        }
        return null;
    }
    @Override
    public List<User> getAllUsers() {
        return userList;
    }

    @Override
    public Response callMethod(String methodName, String[] args) {
        switch (methodName) {
            case "getUser": {
                int id = Integer.parseInt(args[0]);
                User user = getUser(id);
                if (user == null) {
                    return new Response("No such user: " + id);
                } else {
                    return new Response(user.toString());
                }
            }
            case "getAllUsers": {
                List<User> userList = getAllUsers();
                StringBuilder sb = new StringBuilder();
                for (User user : userList) {
                    sb.append(user.toString()).append("\n");
                }
                return new Response(sb.toString());
            }
            default:
                return new Response("No such method: " + methodName);
        }
    }
}

返回体:

@Data
public class Response {
    private String status;
    private String data;
    public Response(String data) {
        this.status = "200 OK";
        this.data = data;
    }
}

服务器类:

public class RpcServer {
    // 缓冲区大小
    private static final int BUFFER_SIZE = 1024;
    // 服务端Socket通道
    private ServerSocketChannel serverSocketChannel;
    // 选择器
    private Selector selector;
    // 注册的服务
    private Map<String, Service> services = new HashMap<>();
    // 构造函数,初始化ServerSocketChannel和Selector
    public RpcServer(int port) throws IOException {
        // 打开ServerSocketChannel
        serverSocketChannel = ServerSocketChannel.open();
        // 绑定端口号
        serverSocketChannel.socket().bind(new InetSocketAddress(port));
        // 设置为非阻塞模式
        serverSocketChannel.configureBlocking(false);
        // 打开Selector
        selector = Selector.open();
        // 将ServerSocketChannel注册到Selector上,监听ACCEPT事件
        serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT);
    }
    // 注册服务
    public void registerService(String serviceName, Service service) {
        services.put(serviceName, service);
    }
    // 启动服务
    public void start() throws IOException {
        while (true) {
            // 阻塞等待事件
            selector.select();
            // 获取所有的事件
            Iterator<SelectionKey> keyIterator = selector.selectedKeys().iterator();
            while (keyIterator.hasNext()) {
                SelectionKey key = keyIterator.next();
                // 处理ACCEPT事件
                if (key.isAcceptable()) {
                    handleAccept(key);
                    // 处理READ事件
                } else if (key.isReadable()) {
                    handleRead(key);
                }
                // 处理完事件后移除该事件
                keyIterator.remove();
            }
        }
    }
    // 处理ACCEPT事件
    private void handleAccept(SelectionKey key) throws IOException {
        // 获取ServerSocketChannel
        ServerSocketChannel serverChannel = (ServerSocketChannel) key.channel();
        // 接收客户端连接
        SocketChannel socketChannel = serverChannel.accept();
        // 设置为非阻塞模式
        socketChannel.configureBlocking(false);
        // 将SocketChannel注册到Selector上,监听READ事件
        socketChannel.register(selector, SelectionKey.OP_READ);
    }
    // 处理READ事件
    private void handleRead(SelectionKey key) throws IOException {
        // 获取SocketChannel
        SocketChannel socketChannel = (SocketChannel) key.channel();
        // 分配缓冲区
        ByteBuffer buffer = ByteBuffer.allocate(BUFFER_SIZE);
        // 读取数据
        int numBytes = socketChannel.read(buffer);
        // 如果读取到末尾,关闭连接
        if (numBytes == -1) {
            key.cancel();
            socketChannel.close();
            return;
        }
        // 解析请求数据
        String requestData = new String(buffer.array(), 0, numBytes);
        String[] requestArray = requestData.split("\|");
        if (requestArray.length != 3) {
            // 如果请求格式不正确,发送错误响应
            sendError(socketChannel, "Invalid request format");
            return;
        }
        // 获取服务名、方法名和参数
        String serviceName = requestArray[0];
        String methodName = requestArray[1];
        String[] args = requestArray[2].split(",");
        // 获取对应的服务
        Service service = services.get(serviceName);
        if (service == null) {
            // 如果找不到对应的服务,发送错误响应
            sendError(socketChannel, "No such service: " + serviceName);
            return;
        }
        // 调用对应的方法,获取响应
        Response response = service.callMethod(methodName, args);
        String responseData = response.getData();
        ByteBuffer responseBuffer = ByteBuffer.allocate(responseData.getBytes().length);
        responseBuffer.put(responseData.getBytes());
        responseBuffer.flip();
        // 发送响应
        socketChannel.write(responseBuffer);
    }
    // 发送错误响应
    private void sendError(SocketChannel socketChannel, String errorMessage) throws IOException {
        Response response = new Response(errorMessage);
        String responseData = response.getData();
        ByteBuffer responseBuffer = ByteBuffer.allocate(responseData.getBytes().length);
        responseBuffer.put(responseData.getBytes());
        responseBuffer.flip();
        socketChannel.write(responseBuffer);
    }
    public static void main(String[] args) throws IOException {
        // 创建RpcServer实例
        RpcServer rpcServer = new RpcServer(8080);
        // 注册服务
        rpcServer.registerService("userService", new UserServiceImpl());
        // 启动服务
        rpcServer.start();
    }
}

客户端类:

public class RpcClient {
    private static final int BUFFER_SIZE = 1024;
    private SocketChannel socketChannel;
    public RpcClient(String host, int port) throws IOException {
        // 首先创建一个SocketChannel实例
        socketChannel = SocketChannel.open();
        // 调用SocketChannel的connect方法连接到指定的服务器地址和端口号
        socketChannel.connect(new InetSocketAddress(host, port));
        // 将SocketChannel设置为非阻塞模式,以便在读写数据时不会阻塞线程。
        socketChannel.configureBlocking(false);
    }

    public Response callService(String serviceName, String methodName, String[] args) throws IOException, InterruptedException {
        // 构造请求数据
        String requestData = serviceName + "|" + methodName + "|" + String.join(",", args);
        // 分配缓冲区
        ByteBuffer buffer = ByteBuffer.allocate(BUFFER_SIZE);
        // 将请求数据写入缓冲区
        buffer.put(requestData.getBytes());
        // 切换为读模式
        buffer.flip();
        // 将缓冲区的数据写入到 socket 通道中
        socketChannel.write(buffer);
        // 清空缓冲区
        buffer.clear();
        // 从 socket 通道中读取响应数据
        int numBytes;
        while (true) {
            numBytes = socketChannel.read(buffer);
            if (numBytes > 0) {
                break;
            } else if (numBytes == 0) {
                continue;
            } else {
                // 发生了错误,抛出异常
                throw new IOException("Error reading data from server");
            }
        }
        // 将缓冲区的数据转换为字符串
        String responseData = new String(buffer.array(), 0, numBytes);
        // 返回响应对象
        return new Response(responseData);
    }

    public static void main(String[] args) throws IOException, InterruptedException {
        RpcClient rpcClient = new RpcClient("localhost", 8080);
        Response response = rpcClient.callService("userService", "getUser", new String[]{"1"});
        System.out.println(response.getData());
    }
}

以上代码可正常运行,有疑问可以交流~~

欢迎关注公众号:程序员的思考与落地

公众号提供大量实践案例,Java入门者不容错过哦,可以陪聊!!