手把手带你0到1摸透RPC框架轮子项目-day3注册中心

323 阅读12分钟

原视频参考地址:[全网最全的手写RPC教程] 手摸手教你写一个RPC-架构-设计-落地实现_哔哩哔哩_bilibili

项目地址:kkoneone11/kkoneoneRPC-master (github.com)

觉得对你有帮助的帮忙文章给个like和项目给个stars呀!!!

因为我的设计是从顶到底去写代码,因此这一部分的代码量会比往后的多,大家可以分几天去编写

注册中心

服务注册/发现过程,配置文件

用于提供服务注册以及服务发现功能。

提供方将服务提供给注册中心

调用方进行服务发现时给具体的接口创建代理对象,封装逻辑。服务发现(依赖注入)

代码部分

服务注册

@RpcReference 服务调用方注解 用作服务发现或者服务注册

package org.kkoneone.rpc.annotation;
​
import org.kkoneone.rpc.common.constants.FaultTolerantRules;
import org.kkoneone.rpc.common.constants.LoadBalancerRules;
​
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
​
/**
 * 服务调用方注解 表明接口要被服务发现或者服务注册使用
 * @Author:kkoneone11
 * @name:RpcReference
 * @Date:2023/12/2 12:59
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface RpcReference {
    /**
     * 版本
     * @return
     */
    String serviceVersion() default "1.0";
​
    /**
     * 超时时间
     * @return
     */
    long timeout() default 5000;
​
    /**
     * 可选的负载均衡:consistentHash,roundRobin...
     * {@link org.kkoneone.rpc.common.constants.LoadBalancerRules}
     * @return
     */
    String loadBalancer() default LoadBalancerRules.RoundRobin;
​
    /**可选的容错策略:failover,failFast,failsafe...
     * {@link org.kkoneone.rpc.common.constants.FaultTolerantRules}
     * @return
     */
    String faultTolerant() default FaultTolerantRules.FailFast;
​
    /**
     * 重试次数
     * @return
     */
    long retryCount() default 3;
}
​

@RpcService服务提供方注解

package org.kkoneone.rpc.annotation;
​
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
​
/**
 * 服务提供方
 * @Author:kkoneone11
 * @name:RpcService
 * @Date:2023/12/2 23:16
 */
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface RpcService {
    /**
     * 指定实现方,默认为实现接口中第一个
     * @return
     */
    Class<?> serviceInterface() default void.class;
​
    /**
     * 版本
     * @return
     */
    String serviceVersion() default "1.0";
}
​

LoadBalancerRules负载均衡策略

package org.kkoneone.rpc.common.constants;
​
/**
 * 负载均衡策略
 * @Author:kkoneone11
 * @name:LoadBalancerRules
 * @Date:2023/12/2 13:04
 */
public interface LoadBalancerRules {
​
    String ConsistentHash = "consistentHash";
    String RoundRobin = "roundRobin";
}
​

PropertiesUtils属性工具类

package org.kkoneone.rpc.utils;
​
import org.kkoneone.rpc.annotation.PropertiesField;
import org.kkoneone.rpc.annotation.PropertiesPrefix;
import org.springframework.core.env.Environment;
​
import java.lang.reflect.Field;
​
/**
 * @Author:kkoneone11
 * @name:PropertiesUtils
 * @Date:2023/12/2 13:16
 */
public class PropertiesUtils {
    /**
     * 根据对象中的配置匹配配置文件
     * @param o 对应类
     * @param environment 配置参数
     */
    public static void init(Object o, Environment environment){
        //通过反射获取对应类
        final Class<?> aClass = o.getClass();
        //获取类的注解前缀
        PropertiesPrefix prefixAnnotation = aClass.getAnnotation(PropertiesPrefix.class);
        if(prefixAnnotation == null){
            throw new NullPointerException(aClass + "@PropertiesPrefix 不存在");
        }
        String prefix = prefixAnnotation.value();
        // 前缀参数矫正
        if (!prefix.contains(".")){
            prefix += ".";
        }
        //遍历对象中的字段
        for(Field field : aClass.getDeclaredFields()){
            final PropertiesField fieldAnnotation = field.getAnnotation(PropertiesField.class);
            if(fieldAnnotation == null) continue;
            String fieldValue = fieldAnnotation.value();
            if(fieldValue == null || fieldValue.equals("")){
                fieldValue = convertToHyphenCase(field.getName());
            }
            try{
                //当字段访问权限为private时候设置true为告诉jvm我想访问
                field.setAccessible(true);
                final Class<?> type = field.getType();
                //拦截对应类对其塞入额外的参数
                final Object value = PropertyUtil.handle(environment, prefix + fieldValue, type);
                if(value == null)continue;
                //填充字段
                field.set(o,value);
            }catch (IllegalAccessException e){
                e.printStackTrace();
            }
            field.setAccessible(false);
        }
​
    }
​
    /**
     * 将输入字符串转换为连字符如"HelloWorld" -> "-hello-world"
     * @param input
     * @return
     */
    public static String convertToHyphenCase(String input){
        StringBuilder output = new StringBuilder();
        for(int i = 0; i < input.length(); i++){
            char c = input.charAt(i);
            if (Character.isUpperCase(c)) {
                output.append('-');
                output.append(Character.toLowerCase(c));
            } else {
                output.append(c);
            }
        }
        return output.toString();
    }
​
​
}
​

PropertyUtil

package org.kkoneone.rpc.utils;
​
import lombok.SneakyThrows;
import org.springframework.core.env.Environment;
import org.springframework.core.env.PropertyResolver;
​
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
​
/**
 * 属性工具类
 * @Author:kkoneone11
 * @name:PropertyUtil
 * @Date:2023/12/2 16:58
 */
public class PropertyUtil {
​
    private static int springBootVersion = 1;
​
    static {
        try {
            Class.forName("org.springframework.boot.bind.RelaxedPropertyResolver");
        } catch (ClassNotFoundException e) {
            springBootVersion = 2;
        }
    }
​
    /**
     * Spring Boot 1.x is compatible with Spring Boot 2.x by Using Java Reflect.
     * @param environment : the environment context
     * @param prefix : the prefix part of property key
     * @param targetClass : the target class type of result
     * @param <T> : refer to @param targetClass
     * @return T
     */
    @SuppressWarnings("unchecked")
    /**
     * 可以通过拦截器拦截对应的类进行塞入额外拓展的参数
     */
    public static <T> T handle(final Environment environment, final String prefix, final Class<T> targetClass) {
        switch (springBootVersion) {
            case 1:
                return (T) v1(environment, prefix);
            default:
                return (T) v2(environment, prefix, targetClass);
        }
    }
​
    private static Object v1(final Environment environment, final String prefix) {
        try {
            Class<?> resolverClass = Class.forName("org.springframework.boot.bind.RelaxedPropertyResolver");
            Constructor<?> resolverConstructor = resolverClass.getDeclaredConstructor(PropertyResolver.class);
            Method getSubPropertiesMethod = resolverClass.getDeclaredMethod("getSubProperties", String.class);
            Object resolverObject = resolverConstructor.newInstance(environment);
            String prefixParam = prefix.endsWith(".") ? prefix : prefix + ".";
            return getSubPropertiesMethod.invoke(resolverObject, prefixParam);
        } catch (final ClassNotFoundException | NoSuchMethodException | SecurityException | InstantiationException
                | IllegalAccessException | IllegalArgumentException | InvocationTargetException ex) {
            throw new RuntimeException(ex.getMessage(), ex);
        }
    }
​
    @SneakyThrows
    private static Object v2(final Environment environment, final String prefix, final Class<?> targetClass) {
        try {
            Class<?> binderClass = Class.forName("org.springframework.boot.context.properties.bind.Binder");
            Method getMethod = binderClass.getDeclaredMethod("get", Environment.class);
            Method bindMethod = binderClass.getDeclaredMethod("bind", String.class, Class.class);
            Object binderObject = getMethod.invoke(null, environment);
            String prefixParam = prefix.endsWith(".") ? prefix.substring(0, prefix.length() - 1) : prefix;
            Object bindResultObject = bindMethod.invoke(binderObject, prefixParam, targetClass);
            Method resultGetMethod = bindResultObject.getClass().getDeclaredMethod("get");
            return resultGetMethod.invoke(bindResultObject);
        } catch (final ClassNotFoundException | NoSuchMethodException | SecurityException | IllegalAccessException
                | IllegalArgumentException | InvocationTargetException ex) {
            return null;
        }
    }
}
​

RegistryFactory注册工厂

package org.kkoneone.rpc.registry;
​
import org.kkoneone.rpc.spi.ExtensionLoader;
​
/**
 * 注册工厂
 * @Author:kkoneone11
 * @name:RegistryFactory
 * @Date:2023/12/3 9:45
 */
public class RegistryFactory {
​
    //都从这个工厂中获取对应的注册bean
    public static RegistryService get(String registryService) throws Exception {
        return ExtensionLoader.getInstance().get(registryService);
    }
​
    /**
     * 加载SPI扩展
     * @throws Exception
     */
    public static void init() throws Exception {
        ExtensionLoader.getInstance().loadExtension(RegistryService.class);
    }
}
​

RegistryService接口。服务中心的注册接口,根据使用不同的方案如redis或者zookeeper都得去实现这个接口

服务注册时设置TTL的作用:

每个服务信息都有一个TTL,而serviceMap则是用来存放其中服务的服务信息,服务节点会进行轮询serviceMap找到对应自己的服务信息,并且查看是否服务信息超时,超时则进行续签。而当服务A挂掉的时候,A服务信息就会超时,当B来轮询发现A服务信息超时则会将A服务信息进行删除

image.png

image.png

image.png

package org.kkoneone.rpc.registry;
​
import org.kkoneone.rpc.common.ServiceMeta;
​
import java.io.IOException;
import java.util.List;
​
/**
 * 注册服务接口
 * @Author:kkoneone11
 * @name:RegistryService
 * @Date:2023/12/3 9:51
 */
public interface RegistryService {
​
    /**
     * 服务注册
     * @param serviceMeta
     * @throws Exception
     */
    void register(ServiceMeta serviceMeta) throws Exception;
    /**
     * 服务注销
     * @param serviceMeta
     * @throws Exception
     */
    void unRegister(ServiceMeta serviceMeta) throws Exception;
    /**
     * 获取 serviceName 下的所有服务
     * @param serviceName
     * @return
     */
    List<ServiceMeta> discoveries(String serviceName);
    /**
     * 关闭
     * @throws IOException
     */
    void destroy() throws IOException;
}
​

ServiceBeforeFilterHandler前置拦截器

package org.kkoneone.rpc.protocol.handler.service;
​
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import org.kkoneone.rpc.Filter.FilterConfig;
import org.kkoneone.rpc.Filter.FilterData;
import org.kkoneone.rpc.common.RpcRequest;
import org.kkoneone.rpc.common.RpcResponse;
import org.kkoneone.rpc.common.constants.MsgStatus;
import org.kkoneone.rpc.protocol.MsgHeader;
import org.kkoneone.rpc.protocol.RpcProtocol;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
​
/**
 * 前置拦截器
 * @Author:kkoneone11
 * @name:ServiceBeforeFilterHandler
 * @Date:2023/12/3 14:50
 */
public class ServiceBeforeFilterHandler extends SimpleChannelInboundHandler<RpcProtocol<RpcRequest>> {
    private Logger logger = LoggerFactory.getLogger(ServiceBeforeFilterHandler.class);
​
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcProtocol<RpcRequest> protocol) throws Exception {
        final RpcRequest request = protocol.getBody();
        final FilterData filterData = new FilterData(request);
        RpcResponse response = new RpcResponse();
        MsgHeader header = protocol.getHeader();
​
        //将数据加入过滤链 对数据做一些处理
        try{
            FilterConfig.getServiceBeforeFilterChain().doFilter(filterData);
        }catch (Exception e){
            //若失败则重新构建一个协议
            RpcProtocol<RpcResponse> resProtocol = new RpcProtocol<>();
            //请求头中塞入失败状态
            header.setStatus((byte) MsgStatus.FAILED.ordinal());
            //Response中塞错误信息
            response.setException(e);
            //协议中设置封装好的消息头和消息体
            logger.error("before process request {} error", header.getRequestId(), e);
            resProtocol.setHeader(header);
            resProtocol.setBody(response);
            //写回信息
            ctx.writeAndFlush(resProtocol);
            return;
        }
        ctx.fireChannelRead(protocol);
    }
}
​

消息状态MsgStatus

package org.kkoneone.rpc.common.constants;
​
/**
*@Author:kkoneone11
*@name:MsgStatus
*@Date:2023/12/4  13:51
*/public enum MsgStatus {
    SUCCESS,
    FAILED
}
​

RpcRequestHandler处理消费方发送数据并且调用方法

package org.kkoneone.rpc.protocol.handler.service;
​
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import org.kkoneone.rpc.common.RpcRequest;
import org.kkoneone.rpc.poll.ThreadPollFactory;
import org.kkoneone.rpc.protocol.RpcProtocol;
​
/**
 * 处理消费方发送数据并且调用方法
 * @Author:kkoneone11
 * @name:RpcRequestHandler
 * @Date:2023/12/4 14:50
 */
public class RpcRequestHandler extends SimpleChannelInboundHandler<RpcProtocol<RpcRequest>>{
​
    public RpcRequestHandler() {}
​
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcProtocol<RpcRequest> protocol) {
        ThreadPollFactory.submitRequest(ctx,protocol);
    }
}
​

线程池工厂ThreadPollFactory

package org.kkoneone.rpc.poll;
​
​
import io.netty.channel.ChannelHandlerContext;
import org.kkoneone.rpc.common.RpcRequest;
import org.kkoneone.rpc.common.RpcResponse;
import org.kkoneone.rpc.common.RpcServiceNameBuilder;
import org.kkoneone.rpc.common.constants.MsgStatus;
import org.kkoneone.rpc.common.constants.MsgType;
import org.kkoneone.rpc.protocol.MsgHeader;
import org.kkoneone.rpc.protocol.RpcProtocol;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.cglib.reflect.FastClass;
​
import java.util.Map;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
​
/**
 * 线程池工厂
 * @Author:kkoneone11
 * @name:ThreadPollFactory
 * @Date:2023/12/4 15:03
 */
public class ThreadPollFactory {
​
    private static Logger logger = LoggerFactory.getLogger(ThreadPollFactory.class);
​
    //快请求
    private static ThreadPoolExecutor fastPoll;
​
    //慢请求
    private static ThreadPoolExecutor slowPoll;
​
    //慢请求映射
    private static volatile ConcurrentHashMap<String, AtomicInteger> slowTaskMap = new ConcurrentHashMap<>();
​
    //目前可执行的核数
    private static int corSize = Runtime.getRuntime().availableProcessors();
​
    //缓存服务 该缓存放这里不太好,应该作一个统一 Config 进行管理
    private static Map<String, Object> rpcServiceMap;
​
    //静态代码块初始化数据
    static{
        slowPoll = new ThreadPoolExecutor(corSize / 2, corSize , 60L,
                TimeUnit.SECONDS,
                //线程池的任务队列,用于存放待执行的任务
                new LinkedBlockingDeque<>(2000),
                //线程工厂,用于创建新线程并且设置为守护线程
                r->{
                    Thread thread = new Thread(r);
                    thread.setName("slow poll-"+r.hashCode());
                    thread.setDaemon(true);
                    return thread;
                });
        fastPoll = new ThreadPoolExecutor(corSize, corSize*2, 60L,
                TimeUnit.SECONDS,
                new LinkedBlockingDeque<>(1000),
                r->{
                    Thread thread = new Thread(r);
                    thread.setName("fast poll-"+r.hashCode());
                    thread.setDaemon(true);
                    return thread;
                });
        startClearMonitor();
    }
​
    private ThreadPollFactory(){}
​
    public static void setRpcServiceMap(Map<String, Object> rpcMap){
        rpcServiceMap = rpcMap;
    }
​
    /**
     * 清理慢请求
     */
    private static void startClearMonitor(){
        //创建了一个单线程的定时任务执行器 5分钟后执行,然后每隔5分钟执行一次
        Executors.newSingleThreadScheduledExecutor().scheduleWithFixedDelay(()->{
            slowTaskMap.clear();
        },5,5,TimeUnit.MINUTES);
    }
​
    public static void submitRequest(ChannelHandlerContext ctx, RpcProtocol<RpcRequest> protocol){
        //取出协议体
        final RpcRequest request  = protocol.getBody();
        //拼装key 类名+方法+服务版本
        String key = request.getClassName() + request.getMethodName() + request.getServiceVersion();
        //快请求赋值给poll
        ThreadPoolExecutor poll = fastPoll;
        //看慢请求映射中是否有缓存对应的key且初始值 存在则将表明是个慢请求则poll传给slowPoll
        if(slowTaskMap.containsKey(key) && slowTaskMap.get(key).intValue() >= 10){
            poll = slowPoll;
        }
        //线程池执行任务
        poll.submit(()->{
            //组装一个新的rpc协议
            RpcProtocol<RpcResponse> resProtocol = new RpcProtocol<>();
            //取出协议头
            final MsgHeader header = protocol.getHeader();
            //新建RpcResponse
            RpcResponse response = new RpcResponse();
            long startTime = System.currentTimeMillis();
            //发送请求
            try{
                //处理返回结果
                final Object result = submit(ctx, protocol);
                //设置返回体
                response.setData(result);
                //返回数据类别
                response.setDataClass(result.getClass());
                //返回状态
                header.setStatus((byte) MsgStatus.SUCCESS.ordinal());
            }catch (Exception e){
                header.setStatus((byte) MsgStatus.FAILED.ordinal());
                response.setException(e);
                logger.error("process request {} error", header.getRequestId(), e);
            }finally {
                //计算请求耗费时长 超过1000的则加入慢请求映射中
                long cost = System.currentTimeMillis() - startTime;
                System.out.println("cost time:" + cost);
                if(cost > 1000){
                    final AtomicInteger timeOutCount = slowTaskMap.putIfAbsent(key, new AtomicInteger(1));
                    if (timeOutCount!=null){
                        timeOutCount.incrementAndGet();
                    }
                }
            }
            resProtocol.setHeader(header);
            resProtocol.setBody(response);
            logger.info("执行成功: {},{},{},{}",Thread.currentThread().getName(),request.getClassName(),request.getMethodName(),request.getServiceVersion());
            //将协议写到管道里
            ctx.fireChannelRead(resProtocol);
        });
​
    }
​
    /**
     *
     * @param ctx
     * @param protocol
     * @return
     * @throws Exception
     */
    private static Object submit(ChannelHandlerContext ctx, RpcProtocol<RpcRequest> protocol) throws Exception{
        MsgHeader header = protocol.getHeader();
        header.setMsgType((byte) MsgType.RESPONSE.ordinal());
        final RpcRequest request = protocol.getBody();
        // 执行具体业务
        return handle(request);
    }
​
    /**
     * 调用invoke方法处理请求
     * @param request
     * @return
     * @throws Exception
     */
    private static Object handle(RpcRequest request) throws Exception {
        //组装服务key
        String serviceKey = RpcServiceNameBuilder.buildServiceKey(request.getClassName(), request.getServiceVersion());
        //从缓存中获取服务信息
        Object serviceBean = rpcServiceMap.get(serviceKey);
        if(serviceBean == null){
            throw new RuntimeException(String.format("service not exist: %s:%s", request.getClassName(), request.getMethodName()));
        }
        //获取服务提供方信息并创建
        //通过反射获取类实例
        Class<?> serviceClass = serviceBean.getClass();
        //获取请求名
        String methodName = request.getMethodName();
        //获取请求的参数
        Class<?>[] parameterTypes = request.getParameterTypes();
        Object[] parameters = {request.getData()};
        FastClass fastClass = FastClass.create(serviceClass);
        int methodIndex = fastClass.getIndex(methodName, parameterTypes);
​
        // 调用invoke方法并返回结果
        return fastClass.invoke(methodIndex, serviceBean, parameters);
    }
​
}
​

ServiceAfterFilterHandler后置处理器

package org.kkoneone.rpc.protocol.handler.service;
​
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import org.kkoneone.rpc.Filter.FilterConfig;
import org.kkoneone.rpc.Filter.FilterData;
import org.kkoneone.rpc.Filter.client.ClientLogFilter;
import org.kkoneone.rpc.common.RpcResponse;
import org.kkoneone.rpc.common.constants.MsgStatus;
import org.kkoneone.rpc.protocol.MsgHeader;
import org.kkoneone.rpc.protocol.RpcProtocol;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
​
/**
 * @Author:kkoneone11
 * @name:ServiceAfterFilterHandler
 * @Date:2023/12/5 17:28
 */
public class ServiceAfterFilterHandler extends SimpleChannelInboundHandler<RpcProtocol<RpcResponse>> {
​
    private Logger logger = LoggerFactory.getLogger(ClientLogFilter.class);
​
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcProtocol<RpcResponse> protocol) {
        final FilterData filterData = new FilterData();
        filterData.setData(protocol.getBody());
        RpcResponse response = new RpcResponse();
        MsgHeader header = protocol.getHeader();
        try {
            FilterConfig.getServiceAfterFilterChain().doFilter(filterData);
        } catch (Exception e) {
            header.setStatus((byte) MsgStatus.FAILED.ordinal());
            response.setException(e);
            logger.error("after process request {} error", header.getRequestId(), e);
        }
        ctx.writeAndFlush(protocol);
    }
}
​

注册中心具体实现

RedisRegistry redis注册中心

注意:StringRedisTemplate是SpringBoot提供的高级抽象,它简化了与Redis的交互。而Jedis是一个Java Redis客户端库,它提供了直接与Redis进行交互的底层API。通常情况下,推荐使用StringRedisTemplate来与Redis交互,因为它提供了更高级的抽象和更方便的操作。

ZookeeperRegistry zk注册中心

服务发现类

核心类

服务注册类ProviderPostProcessor

实现了的接口及对应的方法:

  • BeanPostProcessor:对bean进行额外的操作

实现了对应的postProcessBeforeInitialization()接口方法对服务进行了注册

  • EnvironmentAware:XXXAware后缀的通常是可以拿到前面的前缀命名的类然后进行一些操作。而此处的EnvironmentAware就是用来读取配置文件中的一些信息然后给到rpcProperties供全局使用

其中setEnvironment()方法是用来读取配置文件中的配置信息的。其中封装了一个PropertiesUtils工具类调用其中的init()方法来将对应的配置信息封装进RpcProperties。

image-20231202093440394转存失败,建议直接上传图片文件

image.png 其中特别注意的是 PropertyUtil.handle()函数,可以通过拦截器拦截对应的类进行塞入额外拓展的参数

image.png

image.png

    ​
    import io.netty.bootstrap.ServerBootstrap;
    import io.netty.channel.ChannelFuture;
    import io.netty.channel.ChannelInitializer;
    import io.netty.channel.ChannelOption;
    import io.netty.channel.nio.NioEventLoopGroup;
    import io.netty.channel.socket.SocketChannel;
    import io.netty.channel.socket.nio.NioServerSocketChannel;
    import org.kkoneone.rpc.Filter.FilterConfig;
    import org.kkoneone.rpc.annotation.RpcService;
    import org.kkoneone.rpc.common.RpcServiceNameBuilder;
    import org.kkoneone.rpc.common.ServiceMeta;
    import org.kkoneone.rpc.config.RpcProperties;
    import org.kkoneone.rpc.poll.ThreadPollFactory;
    import org.kkoneone.rpc.protocol.codec.RpcDecoder;
    import org.kkoneone.rpc.protocol.codec.RpcEncoder;
    import org.kkoneone.rpc.protocol.handler.service.RpcRequestHandler;
    import org.kkoneone.rpc.protocol.handler.service.ServiceAfterFilterHandler;
    import org.kkoneone.rpc.protocol.handler.service.ServiceBeforeFilterHandler;
    import org.kkoneone.rpc.protocol.serialization.SerializationFactory;
    import org.kkoneone.rpc.registry.RegistryFactory;
    import org.kkoneone.rpc.registry.RegistryService;
    import org.kkoneone.rpc.router.LoadBalancerFactory;
    import org.kkoneone.rpc.utils.PropertiesUtils;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    import org.springframework.beans.BeansException;
    import org.springframework.beans.factory.InitializingBean;
    import org.springframework.beans.factory.config.BeanPostProcessor;
    import org.springframework.context.EnvironmentAware;
    import org.springframework.core.env.Environment;
    ​
    import java.util.HashMap;
    import java.util.Map;
    ​
    /**
     * 服务提供方后置处理器
     * @Author:kkoneone11
     * @name:ProviderPostProcessor
     * @Date:2023/12/2 12:45
     */
    public class ProviderPostProcessor implements InitializingBean, BeanPostProcessor, EnvironmentAware {
        private Logger logger = LoggerFactory.getLogger(ProviderPostProcessor.class);
        RpcProperties rpcProperties;
        // 此处在linux环境下改为0.0.0.0
        private static String serverAddress = "127.0.0.1";
        private final Map<String, Object> rpcServiceMap = new HashMap<>();
    ​
    ​
        /**
         * 启动RPC服务
         * @throws InterruptedException
         */
        private void startRpcServer() throws InterruptedException {
            Integer serverPort = rpcProperties.getPort();
            //设置上级事件循环组和下级事件循环组
            NioEventLoopGroup boss = new NioEventLoopGroup();
            NioEventLoopGroup worker = new NioEventLoopGroup();
            try{
                //创建客户端
                ServerBootstrap bootstrap = new ServerBootstrap();
                //传入必要参数
                bootstrap.group(boss, worker)
                        .option(ChannelOption.SO_KEEPALIVE, true)
                        .channel(NioServerSocketChannel.class)
                        .childHandler(new ChannelInitializer<SocketChannel>() {
                            @Override
                            protected void initChannel(SocketChannel socketChannel) throws Exception {
                                socketChannel.pipeline()
                                        .addLast(new RpcEncoder())
                                        .addLast(new RpcDecoder())
                                        .addLast(new ServiceBeforeFilterHandler())
                                        .addLast(new RpcRequestHandler())
                                        .addLast(new ServiceAfterFilterHandler());
                            }
                        })
                        .childOption(ChannelOption.SO_KEEPALIVE, true);
                //开启管道异步获取结果
                ChannelFuture channelFuture = bootstrap.bind(this.serverAddress, serverPort).sync();
                logger.info("server addr {} started on port {}", this.serverAddress, serverPort);
                //阻塞当前线程并保持应用程序运行,直到服务器通道关闭
                channelFuture.channel().closeFuture().sync();
                //添加一个钩子函数
                Runtime.getRuntime().addShutdownHook(new Thread(() ->
                {
                    logger.info("ShutdownHook execute start...");
                    logger.info("Netty NioEventLoopGroup shutdownGracefully...");
                    logger.info("Netty NioEventLoopGroup shutdownGracefully2...");
                    boss.shutdownGracefully();
                    worker.shutdownGracefully();
                    logger.info("ShutdownHook execute end...");
                }, "Allen-thread"));
            }finally {
                boss.shutdownGracefully();
                worker.shutdownGracefully();
            }
        }
    ​
    ​
        @Override
        public void afterPropertiesSet() throws Exception {
            Thread t = new Thread(() -> {
                try {
                    startRpcServer();
                } catch (Exception e) {
                    logger.error("start rpc server error.", e);
                }
            });
            t.setDaemon(true);
            t.start();
            SerializationFactory.init();
            RegistryFactory.init();
            LoadBalancerFactory.init();
            FilterConfig.initServiceFilter();
            ThreadPollFactory.setRpcServiceMap(rpcServiceMap);
        }
    ​
    ​
        /**
         * 服务注册
         * @param bean
         * @param beanName
         * @return
         * @throws BeansException
         */
        @Override
        public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
            //通过反射获取bean的对应类
            Class<?> beanClass = bean.getClass();
            //找到bean上带有 RpcService 注解标识的类(即需要注册的类)
            RpcService rpcService = beanClass.getAnnotation(RpcService.class);
            if(rpcService!=null){
                //可能会有多个接口,默认选择第一个接口的接口名用作服务调用方和服务提供方找到对应接口调用
                String serviceName = beanClass.getInterfaces()[0].getName();
                if(!rpcService.serviceInterface().equals(void.class)){
                    //说明rpcService对象定义了特定的服务接口则不使用第一个接口
                    serviceName = rpcService.serviceInterface().getName();
                }
                String serviceVersion = rpcService.serviceVersion();
                try{
                    // 服务注册
                    Integer servicePort = rpcProperties.getPort();
                    // 从配置文件获取注册中心来创建一个注册中心实例 ioc
                    RegistryService registryService = RegistryFactory.get(rpcProperties.getRegisterType());
                    ServiceMeta serviceMeta = new ServiceMeta();
                    // 服务提供方地址 将服务注册到注册中心上
                       //服务端口
                    serviceMeta.setServicePort(servicePort);
                       //服务地址
                    serviceMeta.setServiceAddr("127.0.0.1");
                       //服务版本
                    serviceMeta.setServiceVersion(serviceVersion);
                       //服务名字
                    serviceMeta.setServiceName(serviceName);
                    registryService.register(serviceMeta);
                    // 缓存
                    rpcServiceMap.put(RpcServiceNameBuilder.buildServiceKey(serviceMeta.getServiceName(),serviceMeta.getServiceVersion()), bean);
                    logger.info("register server {} version {}",serviceName,serviceVersion);
                }catch (Exception e){
                    logger.error("failed to register service {}",  serviceVersion, e);
                }
            }
            return bean;
        }
    ​
    ​
    ​
        /**
         * 设置配置文件
         * @param environment
         */
        @Override
        public void setEnvironment(Environment environment) {
            RpcProperties properties = RpcProperties.getInstance();
            //通过init方法将environment中的参数配置到RpcProperties中方便全局使用
            PropertiesUtils.init(properties,environment);
            rpcProperties = properties;
        }
    }
    ​

服务发现类ConsumerPostProcessor

  • BeanPostProcessor:同样实现了这个接口重写了里面的postProcessAfterInitialization()方法用作服务发现并注册为代理对象

    同样先拿到对应的bean(需要使用对应服务的服务消费方类如下方的TestService)通过反射获取所有字段,然后判断其上方是否有@RpcReference 注解(如下方的testService字段),有的则用Proxy创建一个对象并用RpcInvokerProxy给该对象配置对应的属性(如负载均衡策略等),

image.png

image.png

最后将该对象设置给该字段,就像"TestService testService = service"这样去赋值去进行注入

image.png

package org.kkoneone.rpc.consumer;
​
import org.kkoneone.rpc.Filter.FilterConfig;
import org.kkoneone.rpc.Filter.client.ClientLogFilter;
import org.kkoneone.rpc.annotation.RpcReference;
import org.kkoneone.rpc.config.RpcProperties;
import org.kkoneone.rpc.protocol.serialization.SerializationFactory;
import org.kkoneone.rpc.registry.RegistryFactory;
import org.kkoneone.rpc.router.LoadBalancerFactory;
import org.kkoneone.rpc.utils.PropertiesUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.context.EnvironmentAware;
import org.springframework.core.env.Environment;
​
import java.lang.reflect.Field;
import java.lang.reflect.Proxy;
​
/**
 * @Author:kkoneone11
 * @name:ConsumerPostProcessor
 * @Date:2023/12/5 23:58
 */
public class ConsumerPostProcessor implements BeanPostProcessor, EnvironmentAware, InitializingBean {
​
    private Logger logger = LoggerFactory.getLogger(ClientLogFilter.class);
    RpcProperties rpcProperties;
    /**
     * 初始化一些bean
     * @throws Exception
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        SerializationFactory.init();
        RegistryFactory.init();
        LoadBalancerFactory.init();
        FilterConfig.initClientFilter();
    }
​
    /**
     * 从配置文件中读取配置
     * @param environment
     */
    @Override
    public void setEnvironment(Environment environment) {
        RpcProperties properties = RpcProperties.getInstance();
        PropertiesUtils.init(properties,environment);
        rpcProperties = properties;
        logger.info("读取配置文件成功");
    }
​
​
    /**
     * 服务发现 代理层注入
     * @param bean
     * @param beanName
     * @return
     * @throws BeansException
     */
    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
        //通过反射获取所有字段
        final Field[] fields = bean.getClass().getDeclaredFields();
        //查找字段上标有@RpcReference注解的字段
        for(Field field : fields){
            if(field.isAnnotationPresent(RpcReference.class)){
                RpcReference rpcReference = field.getAnnotation(RpcReference.class);
                //获取字段的类类型
                Class<?> aClass = field.getType();
                //设置字段为可访问
                field.setAccessible(true);
                Object object = null;
                try{
                    //在运行时期创建代理对象
                    object = Proxy.newProxyInstance(
                            aClass.getClassLoader(),
                            new Class<?>[]{aClass},
                            new RpcInvokerProxy(rpcReference.serviceVersion(),rpcReference.timeout(),rpcReference.faultTolerant(),
                                    rpcReference.loadBalancer(),rpcReference.retryCount()));
                }catch (Exception e){
                    e.printStackTrace();
                }
                try{
                    //代理对象创建成功则赋值给@RpcReference进行动态代理
                    field.set(bean,object);
                    field.setAccessible(false);
                    logger.info(beanName + " field:" + field.getName() + "注入成功");
                }catch (Exception e){
                    e.printStackTrace();
                    logger.info(beanName + " field:" + field.getName() + "注入失败");
                }
            }
        }
        return bean;
    }
​
​
}
​