So Hand-Writing a Spring Container Is This Simple!

Preface

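The goal: call ApplicationContext#getBean() with a bean name and get back the bean registered in our own container. We build a simple Spring-style ApplicationContext by hand, with a custom BeanDefination, custom @Autowire, @Component, @ComponentScan and @Scope annotations, the BeanNameAware callback interface, the BeanPostProcessor interface for pre- and post-initialization processing, and the InitializingBean initialization interface. The container also implements AOP using JDK dynamic proxies, which are interface-based.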

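Scenario

How do we build a container that makes the code below work? A configuration class is used to initialize the ApplicationContext, and beans are then fetched through ApplicationContext#getBean().

Test class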

public class Test {

    public static void main(String[] args) throws ClassNotFoundException {
        HeavenApplicationContext heavenApplicationContext = new HeavenApplicationContext(AppConfig.class);

        // Test dependency injection
        ServiceInterface bean = (ServiceInterface) heavenApplicationContext.getBean("userService");
        bean.test();

    }

}

Configuration class

@ComponentScan("com.heaven.service")
public class AppConfig {
}

Service class that receives the injection

@Component("userService")
@Scope
public class UserService implements BeanNameAware, InitializingBean, ServiceInterface {
    @Autowire
    private OrderService orderService;
    private String beanName;
    
    public void test() {
        System.out.println(orderService);
    }
    
    @Override
    public void setBeanName(String beanName) {
        this.beanName = beanName;
    }

    @Override
    public void afterPropertiesSet() {
        System.out.println("Initializing: " + this.beanName);
    }
}

Class injected as a field

@Component
public class OrderService {
}

Post-processor

// The container only registers BeanPostProcessors it finds among the @Component
// classes of the scanned package, so this class needs @Component to take effect
@Component
public class HeavenBeanPostProcessor implements BeanPostProcessor {
    @Override
    public void postProcessBeforeInitialization(String beanName, Object instance) {
        if ("userService".equals(beanName)) {
            System.out.println("postProcessBeforeInitialization");
        }
    }

    @Override
    public void postProcessAfterInitialization(String beanName, Object instance) {
        if ("userService".equals(beanName)) {
            System.out.println("postProcessAfterInitialization");
        }
    }
}

How It Works

The hand-written container

HeavenApplicationContext

import java.beans.Introspector;
import java.io.File;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

public class HeavenApplicationContext {
    // bean name -> bean definition (class + scope)
    private ConcurrentHashMap<String, BeanDefination> beanDefinationMap = new ConcurrentHashMap<>();
    // bean name -> fully initialized singleton (possibly an AOP proxy)
    private ConcurrentHashMap<String, Object> singletonBeanMap = new ConcurrentHashMap<>();
    private List<BeanPostProcessor> beanPostProcessors = new ArrayList<>();

    public HeavenApplicationContext(Class<?> appConfigClass) throws ClassNotFoundException {
        // 1. Read the package to scan from @ComponentScan on the configuration class
        ComponentScan componentScanAnnotation = appConfigClass.getAnnotation(ComponentScan.class);
        String scanBeanPath = componentScanAnnotation.value();

        // 2. Resolve the package to a directory on the classpath, e.g. com/heaven/service
        ClassLoader classLoader = appConfigClass.getClassLoader();
        URL resource = classLoader.getResource(scanBeanPath.replace(".", "/"));

        // 3. Look at every .class file in that directory (top level only, no subpackages)
        File file = new File(resource.getFile());
        if (file.isDirectory()) {
            File[] files = file.listFiles();
            for (File classFile : files) {
                String fileName = classFile.getName();
                if (fileName.endsWith(".class")) {
                    // Package + file name gives the fully qualified class name, e.g.
                    // com.heaven.service + UserService.class -> com.heaven.service.UserService
                    String className = scanBeanPath + "." + fileName.substring(0, fileName.length() - ".class".length());
                    Class<?> aClass = classLoader.loadClass(className);

                    // Only classes annotated with @Component are managed by the container
                    if (aClass.isAnnotationPresent(Component.class)) {
                        if (BeanPostProcessor.class.isAssignableFrom(aClass)) {
                            try {
                                Object o = aClass.getDeclaredConstructor().newInstance();
                                beanPostProcessors.add((BeanPostProcessor) o);
                            } catch (ReflectiveOperationException e) {
                                throw new RuntimeException(e);
                            }
                        }
                        System.out.println("Registering bean class: " + aClass.getName());
                        BeanDefination beanDefination = new BeanDefination();
                        if (aClass.isAnnotationPresent(Scope.class)) {
                            Scope annotation = aClass.getAnnotation(Scope.class);
                            beanDefination.setScope(annotation.value());
                        } else {
                            beanDefination.setScope("singleton");
                        }
                        // Bean name: @Component's value, or the decapitalized simple class name
                        String beanName = aClass.getAnnotation(Component.class).value();
                        if ("".equals(beanName)) {
                            beanName = Introspector.decapitalize(aClass.getSimpleName());
                        }
                        beanDefination.setaClass(aClass);
                        beanDefinationMap.put(beanName, beanDefination);
                    }
                }
            }
        }

        // 4. Eagerly create the singletons. Going through getBean() means a singleton that was
        //    already created as a dependency of an earlier bean is not created a second time.
        for (String beanName : beanDefinationMap.keySet()) {
            if ("singleton".equals(beanDefinationMap.get(beanName).getScope())) {
                getBean(beanName);
            }
        }
    }

    private Object createBean(String beanName, BeanDefination beanDefination) {
        Class<?> aClass = beanDefination.getaClass();

        Object o;
        try {
            // Instantiate through the no-arg constructor
            o = aClass.getDeclaredConstructor().newInstance();

            // Dependency injection: each @Autowire field is resolved by its field name
            for (Field field : aClass.getDeclaredFields()) {
                if (field.isAnnotationPresent(Autowire.class)) {
                    field.setAccessible(true);
                    field.set(o, getBean(field.getName()));
                }
            }
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }

        // Aware callback: tell the bean its own name
        if (o instanceof BeanNameAware) {
            ((BeanNameAware) o).setBeanName(beanName);
        }

        // Before initialization
        for (BeanPostProcessor beanPostProcessor : beanPostProcessors) {
            beanPostProcessor.postProcessBeforeInitialization(beanName, o);
        }

        // Initialization
        if (o instanceof InitializingBean) {
            ((InitializingBean) o).afterPropertiesSet();
        }

        // After initialization
        for (BeanPostProcessor beanPostProcessor : beanPostProcessors) {
            beanPostProcessor.postProcessAfterInitialization(beanName, o);
        }

        // AOP: wrap beans implementing ServiceInterface in a JDK dynamic proxy
        if (o instanceof ServiceInterface) {
            final Object target = o;
            return Proxy.newProxyInstance(aClass.getClassLoader(), aClass.getInterfaces(),
                    new InvocationHandler() {
                        @Override
                        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                            System.out.println("Before advice");
                            // Forward the arguments and hand back the target's return value
                            Object result = method.invoke(target, args);
                            System.out.println("After advice");
                            return result;
                        }
                    });
        }
        return o;
    }

    public Object getBean(String beanName) {
        BeanDefination beanDefination = beanDefinationMap.get(beanName);
        if (beanDefination == null) {
            throw new IllegalArgumentException("No bean named '" + beanName + "'");
        }
        if ("singleton".equals(beanDefination.getScope())) {
            // Return the cached singleton, creating and caching it on first access
            Object o = singletonBeanMap.get(beanName);
            if (o == null) {
                o = createBean(beanName, beanDefination);
                singletonBeanMap.put(beanName, o);
            }
            return o;
        }
        // Any other scope (e.g. "prototype") gets a new instance on every call
        return createBean(beanName, beanDefination);
    }
}
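The AOP step in createBean() relies on java.lang.reflect.Proxy. Below is a minimal standalone sketch of the same technique; Greeter, GreeterImpl, withLogging and the log list are hypothetical names used only for illustration. The point it demonstrates: the InvocationHandler should forward the call's arguments and return the target's result, otherwise any method with parameters or a return value breaks once it sits behind the proxy.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;

final class ProxyDemo {
    // Hypothetical interface: JDK proxies can only implement interfaces
    interface Greeter {
        String greet(String name);
    }

    static final class GreeterImpl implements Greeter {
        @Override
        public String greet(String name) {
            return "Hello, " + name;
        }
    }

    // Wraps target in a proxy that records "before"/"after" around every call
    static Greeter withLogging(Greeter target, List<String> log) {
        InvocationHandler handler = (proxy, method, args) -> {
            log.add("before " + method.getName());
            Object result = method.invoke(target, args); // forward args, keep the result
            log.add("after " + method.getName());
            return result;
        };
        return (Greeter) Proxy.newProxyInstance(
                Greeter.class.getClassLoader(),
                new Class<?>[] {Greeter.class},
                handler);
    }

    public static void main(String[] args) {
        List<String> log = new ArrayList<>();
        Greeter greeter = withLogging(new GreeterImpl(), log);
        System.out.println(greeter.greet("Spring")); // Hello, Spring
        System.out.println(log);                     // [before greet, after greet]
    }
}
```

Note that every interface method goes through the handler, including toString(), equals() and hashCode() on the proxy, so "before/after" logic also runs when the proxy is printed.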

Custom annotations

@Scope

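import java.lang.annotation.*;

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface Scope {
    // "singleton" = one cached instance; any other value (e.g. "prototype") = new instance per getBean()
    String value() default "singleton";
}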
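getBean() applies a simple rule: "singleton" means create once and reuse, every other scope value means create a new instance on each call. The hypothetical MiniScopeRegistry below isolates just that rule, with a Supplier standing in for createBean(), so the difference can be checked with ==. It deliberately uses an explicit get-then-put rather than HashMap.computeIfAbsent, since the real createBean() calls getBean() recursively and computeIfAbsent does not tolerate recursive updates of the same map.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

final class ScopeDemo {
    // Hypothetical, stripped-down registry: bean name -> scope and factory
    static final class MiniScopeRegistry {
        private final Map<String, String> scopes = new HashMap<>();
        private final Map<String, Supplier<Object>> factories = new HashMap<>();
        private final Map<String, Object> singletons = new HashMap<>();

        void register(String name, String scope, Supplier<Object> factory) {
            scopes.put(name, scope);
            factories.put(name, factory);
        }

        Object getBean(String name) {
            Supplier<Object> factory = factories.get(name);
            if (factory == null) {
                throw new IllegalArgumentException("No bean named '" + name + "'");
            }
            if ("singleton".equals(scopes.get(name))) {
                // Create on first request, then keep returning the cached instance
                Object bean = singletons.get(name);
                if (bean == null) {
                    bean = factory.get();
                    singletons.put(name, bean);
                }
                return bean;
            }
            // Any other scope value, e.g. "prototype": a new instance every time
            return factory.get();
        }
    }

    public static void main(String[] args) {
        MiniScopeRegistry registry = new MiniScopeRegistry();
        registry.register("orderService", "singleton", Object::new);
        registry.register("cart", "prototype", Object::new);
        System.out.println(registry.getBean("orderService") == registry.getBean("orderService")); // true
        System.out.println(registry.getBean("cart") == registry.getBean("cart"));                 // false
    }
}
```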

@Autowire

import java.lang.annotation.*;

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface Autowire {
}
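createBean() resolves each @Autowire field by its field name, which is why the orderService field in UserService finds the bean named orderService. The hypothetical AutowireDemo below repeats that lookup against a plain Map (Java 9+ for Map.of). It also shows why @Autowire needs RUNTIME retention: with the default CLASS retention the annotation is not visible through reflection, isAnnotationPresent() returns false, and nothing gets injected.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.util.Map;

final class AutowireDemo {
    // Same shape as the article's @Autowire
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.FIELD)
    @interface Autowire {
    }

    static final class OrderService {
    }

    static final class UserService {
        @Autowire
        private OrderService orderService;
        private String notInjected; // no annotation, so left alone

        OrderService getOrderService() {
            return orderService;
        }

        String getNotInjected() {
            return notInjected;
        }
    }

    // Sets every @Autowire field to the bean registered under that field's name
    static void inject(Object bean, Map<String, Object> beansByName) {
        for (Field field : bean.getClass().getDeclaredFields()) {
            if (field.isAnnotationPresent(Autowire.class)) {
                field.setAccessible(true); // required to write a private field
                try {
                    field.set(bean, beansByName.get(field.getName()));
                } catch (IllegalAccessException e) {
                    throw new IllegalStateException(e);
                }
            }
        }
    }

    public static void main(String[] args) {
        OrderService orderService = new OrderService();
        UserService userService = new UserService();
        inject(userService, Map.of("orderService", orderService));
        System.out.println(userService.getOrderService() == orderService); // true
        System.out.println(userService.getNotInjected());                   // null
    }
}
```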

@Component

import java.lang.annotation.*;

@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface Component {
    // Bean name; empty means "derive it from the class name"
    String value() default "";
}
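When @Component has no value, the container falls back to java.beans.Introspector.decapitalize() on the simple class name. That method lower-cases only the first character, and leaves the name unchanged when the first two characters are both upper case. The hypothetical BeanNameDemo below reproduces the container's naming rule; java.beans lives in the java.desktop module, so this assumes a full JDK.

```java
import java.beans.Introspector;

final class BeanNameDemo {
    // Same rule as the container: @Component's value if set, else the decapitalized simple name
    static String beanName(String componentValue, Class<?> beanClass) {
        if (!"".equals(componentValue)) {
            return componentValue;
        }
        return Introspector.decapitalize(beanClass.getSimpleName());
    }

    // Hypothetical bean classes used only to exercise the rule
    static final class OrderService {
    }

    static final class URLService {
    }

    public static void main(String[] args) {
        System.out.println(beanName("", OrderService.class));            // orderService
        System.out.println(beanName("userService", OrderService.class)); // userService
        System.out.println(beanName("", URLService.class));              // URLService
    }
}
```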

@ComponentScan

import java.lang.annotation.*;

@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface ComponentScan {
    // Package to scan, e.g. "com.heaven.service"
    String value() default "";
}
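The container turns the @ComponentScan value into a '/'-separated resource path for ClassLoader.getResource(), then turns each .class file name back into a fully qualified class name for loadClass(). The hypothetical ScanPathDemo helpers below isolate those two string conversions, using the article's package name. Deriving the class name from the package name plus the file name, rather than by slicing an absolute file path, does not depend on the OS path separator or on where the project sits on disk.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

final class ScanPathDemo {
    @Target(ElementType.TYPE)
    @Retention(RetentionPolicy.RUNTIME)
    @interface ComponentScan {
        String value() default "";
    }

    @ComponentScan("com.heaven.service")
    static final class AppConfig {
    }

    // "com.heaven.service" -> "com/heaven/service", the form ClassLoader.getResource expects
    static String toResourcePath(String basePackage) {
        return basePackage.replace('.', '/');
    }

    // ("com.heaven.service", "UserService.class") -> "com.heaven.service.UserService"
    static String toClassName(String basePackage, String classFileName) {
        String simpleName = classFileName.substring(0, classFileName.length() - ".class".length());
        return basePackage + "." + simpleName;
    }

    public static void main(String[] args) {
        String basePackage = AppConfig.class.getAnnotation(ComponentScan.class).value();
        System.out.println(toResourcePath(basePackage));                   // com/heaven/service
        System.out.println(toClassName(basePackage, "UserService.class")); // com.heaven.service.UserService
    }
}
```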

Bean definition

BeanDefination

public class BeanDefination {
    private Class<?> aClass;
    private String scope;

    public Class<?> getaClass() {
        return aClass;
    }

    public void setaClass(Class<?> aClass) {
        this.aClass = aClass;
    }

    public void setScope(String scope) {
        this.scope = scope;
    }

    public String getScope() {
        return scope;
    }
}

Callback

BeanNameAware

public interface BeanNameAware {
    public void setBeanName(String beanName);
}
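A bean has no other way to learn the name it was registered under, so BeanNameAware works as an opt-in callback: the container checks instanceof and calls setBeanName() only on beans that implement the interface. A minimal sketch of that check, with AwareDemo and its nested types as hypothetical stand-ins:

```java
final class AwareDemo {
    interface BeanNameAware {
        void setBeanName(String beanName);
    }

    static final class UserService implements BeanNameAware {
        private String beanName;

        @Override
        public void setBeanName(String beanName) {
            this.beanName = beanName;
        }

        String getBeanName() {
            return beanName;
        }
    }

    // Only beans that opted in by implementing BeanNameAware receive the callback
    static void invokeAwareCallbacks(String beanName, Object bean) {
        if (bean instanceof BeanNameAware) {
            ((BeanNameAware) bean).setBeanName(beanName);
        }
    }

    public static void main(String[] args) {
        UserService userService = new UserService();
        invokeAwareCallbacks("userService", userService);
        invokeAwareCallbacks("plainObject", new Object()); // silently skipped
        System.out.println(userService.getBeanName());     // userService
    }
}
```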

Post-processing

BeanPostProcessor

public interface BeanPostProcessor {
    public void postProcessBeforeInitialization(String beanName, Object instance);
    public void postProcessAfterInitialization(String beanName, Object instance);
}
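In this container both processor methods return void, so a processor can inspect or mutate a bean but cannot replace it. Spring's own BeanPostProcessor instead declares Object postProcessBeforeInitialization(Object bean, String beanName) and Object postProcessAfterInitialization(Object bean, String beanName), returning the bean the container should continue with; that return value is what lets Spring's auto-proxy creators hand back a proxy. The sketch below shows that chaining idea with a hypothetical ReturningPostProcessor interface (not the article's or Spring's API); strings stand in for beans so the wrapping order is visible.

```java
import java.util.List;

final class PostProcessorChainDemo {
    // Hypothetical variant of the article's interface: each hook returns the bean to continue with
    interface ReturningPostProcessor {
        default Object beforeInit(Object bean, String beanName) {
            return bean;
        }

        default Object afterInit(Object bean, String beanName) {
            return bean;
        }
    }

    // Runs afterInit on every processor in list order, feeding each result into the next
    static Object applyAfterInit(Object bean, String beanName, List<ReturningPostProcessor> processors) {
        Object current = bean;
        for (ReturningPostProcessor processor : processors) {
            current = processor.afterInit(current, beanName);
        }
        return current;
    }

    public static void main(String[] args) {
        ReturningPostProcessor wrapA = new ReturningPostProcessor() {
            @Override
            public Object afterInit(Object bean, String beanName) {
                return "A(" + bean + ")";
            }
        };
        ReturningPostProcessor wrapB = new ReturningPostProcessor() {
            @Override
            public Object afterInit(Object bean, String beanName) {
                return "B(" + bean + ")";
            }
        };
        System.out.println(applyAfterInit("bean", "userService", List.of(wrapA, wrapB))); // B(A(bean))
    }
}
```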

Initialization

InitializingBean

public interface InitializingBean {
    public void afterPropertiesSet();
}
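The order of the callbacks in createBean() matters: UserService#afterPropertiesSet() can print this.beanName only because setBeanName() has already run. The hypothetical LifecycleOrderDemo below records each step in the same order as createBean(), with plain list entries standing in for the two BeanPostProcessor loops.

```java
import java.util.ArrayList;
import java.util.List;

final class LifecycleOrderDemo {
    interface BeanNameAware {
        void setBeanName(String beanName);
    }

    interface InitializingBean {
        void afterPropertiesSet();
    }

    // Hypothetical bean that records every callback it receives
    static final class RecordingBean implements BeanNameAware, InitializingBean {
        private final List<String> events;

        RecordingBean(List<String> events) {
            this.events = events;
            events.add("constructor");
        }

        @Override
        public void setBeanName(String beanName) {
            events.add("setBeanName(" + beanName + ")");
        }

        @Override
        public void afterPropertiesSet() {
            events.add("afterPropertiesSet");
        }
    }

    // Same order as createBean(): Aware callback, before-init, init, after-init
    static Object initialize(String beanName, Object bean, List<String> events) {
        if (bean instanceof BeanNameAware) {
            ((BeanNameAware) bean).setBeanName(beanName);
        }
        events.add("postProcessBeforeInitialization"); // stands in for the processor loop
        if (bean instanceof InitializingBean) {
            ((InitializingBean) bean).afterPropertiesSet();
        }
        events.add("postProcessAfterInitialization"); // stands in for the processor loop
        return bean;
    }

    public static void main(String[] args) {
        List<String> events = new ArrayList<>();
        initialize("userService", new RecordingBean(events), events);
        // [constructor, setBeanName(userService), postProcessBeforeInitialization,
        //  afterPropertiesSet, postProcessAfterInitialization]
        System.out.println(events);
    }
}
```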

Interface for the AOP proxy

ServiceInterface

public interface ServiceInterface {
    int a = 0;
    void test();
}
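ServiceInterface exists because a JDK dynamic proxy is a new class that implements the target's interfaces; it is not a subclass of the target. That is why the test casts getBean("userService") to ServiceInterface: a cast to UserService would throw ClassCastException. (Spring AOP can fall back to CGLIB class-based proxies for this case; this container does not.) A minimal sketch with hypothetical nested types:

```java
import java.lang.reflect.Proxy;

final class InterfaceProxyDemo {
    interface ServiceInterface {
        void test();
    }

    static final class UserService implements ServiceInterface {
        @Override
        public void test() {
        }
    }

    // Proxies every interface the target implements and forwards calls unchanged
    static Object proxy(Object target) {
        return Proxy.newProxyInstance(
                target.getClass().getClassLoader(),
                target.getClass().getInterfaces(),
                (p, method, methodArgs) -> method.invoke(target, methodArgs));
    }

    public static void main(String[] args) {
        Object bean = proxy(new UserService());
        System.out.println(bean instanceof ServiceInterface);    // true
        System.out.println(bean instanceof UserService);         // false: a cast to UserService would fail
        System.out.println(Proxy.isProxyClass(bean.getClass())); // true
    }
}
```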

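Summary

The resulting simple Spring container works like this:

  1. First, define the custom @ComponentScan annotation and put it on the configuration class to tell the container which package to scan.

  2. Initialize our ApplicationContext from the configuration class. During initialization the container loads every class under the scan path given by the configuration (including the service classes annotated with our custom @Component), turns each @Component class into a BeanDefination, and stores it in the container's beanDefinationMap. Along the way it also collects every BeanPostProcessor into the container's beanPostProcessors list. At the end of initialization, every bean whose scope is singleton is created eagerly and stored in the container's singletonBeanMap.

  3. Implement getBean(). If the requested bean is a singleton, return it from the local singletonBeanMap (creating and caching it first if it is not there yet). Otherwise, call createBean() to create a new bean and return it.

  4. Implement createBean():

    • Resolve dependencies: every field annotated with @Autowire is injected. Circular dependencies can arise here, and this container does not handle them.
    • Use the callback mechanism (our own Aware interface, BeanNameAware) to give the bean its name.
    • Run the pre-initialization logic: call postProcessBeforeInitialization() on every BeanPostProcessor collected when the container was initialized.
    • Run the initialization logic, i.e. InitializingBean#afterPropertiesSet().
    • Then run BeanPostProcessor#postProcessAfterInitialization().
    • Finally, apply AOP through a dynamic proxy and return the proxy object.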
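To make the circular-dependency gap concrete: if UserService autowired OrderService and OrderService autowired UserService, createBean() and getBean() would keep calling each other until the JVM throws StackOverflowError, because a singleton is cached only after createBean() returns. The hypothetical CycleDetectionDemo below sketches the simplest mitigation, detection rather than resolution: track the names currently being created and fail fast with a clear message when one comes round again. The dependency map stands in for @Autowire fields (Java 9+ for List.of/Map.of).

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

final class CycleDetectionDemo {
    // Hypothetical dependency graph: bean name -> names of the beans it autowires
    private final Map<String, List<String>> dependencies;
    private final Map<String, Object> singletons = new HashMap<>();
    private final Set<String> inCreation = new HashSet<>();

    CycleDetectionDemo(Map<String, List<String>> dependencies) {
        this.dependencies = dependencies;
    }

    Object getBean(String name) {
        Object existing = singletons.get(name);
        if (existing != null) {
            return existing;
        }
        // add() returns false if this name is already being created: we went round a cycle
        if (!inCreation.add(name)) {
            throw new IllegalStateException("Circular dependency involving '" + name + "'");
        }
        try {
            Object bean = new Object(); // stands in for instantiating the bean class
            for (String dependency : dependencies.getOrDefault(name, List.of())) {
                getBean(dependency); // stands in for @Autowire field injection
            }
            singletons.put(name, bean);
            return bean;
        } finally {
            inCreation.remove(name);
        }
    }

    public static void main(String[] args) {
        CycleDetectionDemo ok = new CycleDetectionDemo(Map.of(
                "userService", List.of("orderService"),
                "orderService", List.of()));
        System.out.println(ok.getBean("userService") != null); // true

        CycleDetectionDemo cyclic = new CycleDetectionDemo(Map.of(
                "a", List.of("b"),
                "b", List.of("a")));
        try {
            cyclic.getBean("a");
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage()); // Circular dependency involving 'a'
        }
    }
}
```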

Extensions

  1. Read the source of Spring's ApplicationContext#getBean() and trace how the real implementation works step by step.