手把手教你写一个简单的IOC容器和DI

1,263 阅读9分钟

在Spring中IOC是个绝佳的解耦合手段,为了更好的理解我就动手自己写了一个

预备知识:

注解,反射,集合类,lambda表达式,流式API

IOC

如何把一个类注册进去呢?首先我们要让容器“发现”它,所以使用注解,声明它应当加入容器

其中的value即对应的是Spring中的Bean name

@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
public @interface Part {
    String value() default "";
}

对于工厂bean

//用于标注工厂类
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
@Part
public @interface FactoryBean {
}
//用于标注生产函数
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface Produce {
    String value() default "";
}

扫描包生成类的工具

当然,有人会说hutool的ClassScaner很好用,但是这里为了加深理解,我就自己写一个

思路就是利用文件名利用Class.forName()得到类的反射再生成实例

public static List<Object> find(String packName, ClassFilter classFilter) throws IOException {
        //获取当前路径
        Enumeration<URL> entity = Thread.currentThread().getContextClassLoader().getResources(packName);
        HashSet<String> classPaths = new HashSet<>();
        ArrayList<Object> classes = new ArrayList<>();
        //拿到处理后的路径,处理前为/..../target/classes
        //处理后为/..../target/classes
        if (entity.hasMoreElements()) {
            String path = entity.nextElement().getPath().substring(1);
            classPaths.add(path);
        }
        //这里跳转到我写的一个把路径下的.class文件生成为类名的方法,后面会讲述
        //set的元素为类名 比如Entity.Student
        Set<String> set = loadClassName(classPaths);
        for (String s : set) {
            try {
                Class<?> c = Class.forName(s);
                //利用过滤器判断需不需要生成实例
                if (classFilter.test(c)){
                    //这里为了简单使用无参构造器
                    Constructor<?> constructor = c.getConstructor();
                    constructor.setAccessible(true);
                    //将生成的实例加入返回的list集合中
                    classes.add(constructor.newInstance());
                }
            }catch (ClassNotFoundException| InstantiationException | IllegalAccessException| InvocationTargetException e) {
                throw new RuntimeException(e);
            }catch (NoSuchMethodException e){
                System.err.println(e.getMessage());
            }
        }
        return classes;
    }

到来其中的一个核心函数loadClassName

/**
     * @param classPaths 路径名集合
     * @return 类名的集合
     */
    private static Set<String> loadClassName(HashSet<String> classPaths){
        Queue<File> queue = new LinkedList<>();
        HashSet<String> classNames = new HashSet<>();
        //对每一个路径得到对应所有以.class结尾的文件
        classPaths.forEach(p -> {
            //迭代的方法,树的层次遍历
            queue.offer(new File(p));
            while (!queue.isEmpty()){
                File file = queue.poll();
                if (file.isDirectory()) {
                    File[] files = file.listFiles();
                    for (File file1 : files) {
                        queue.offer(file1);
                    }
                }else if(file.getName().endsWith(".class")){
                    //对文件名处理得到类名
                    // ..../target/classes处理完为  \....\target\classes
                    String replace = p.replace("/", "\\");
                    //对于每个.class文件都是以....\target\classes开头,去掉开头,去掉后缀就是类名了
                   String className = file.getPath()
                            .replace(replace, "")
                            .replace(".class", "").replace("\\", ".");
                    classNames.add(className);
                }
            }
        });
        return classNames;
    }

好了,现在就可以扫描包了

上面我也提到了不是所有的类都必须放到容器中,现在让我们看看这个 ClassFilter 过滤器是什么东西吧

@FunctionalInterface
public interface ClassFilter{
    boolean test(Class c);
}

是个函数式接口,这就意味着使用lambda表达式会很方便

通过这个接口我们就很容易地构造这么一个函数帮我们把所有有@Part注解的类生成好

public static<T> List<Object> findByAnnotation(String packName, Class<T> annotation) throws IOException{
        if (!annotation.isAnnotation()) {
            throw new RuntimeException("it not an annotation"+annotation.getTypeName());
        }
        ClassFilter classFilter =(c) -> c.getAnnotation(annotation) != null;
        return find(packName, classFilter);
    }

IOC容器

上面的准备工作做的差不多了

该动手写IOC容器了

思考一下在Spring中我们很容易通过bean name得到java bean,所以使用一个Map<String,Object>可以模拟一下。

这里我们在IOCContainer中添加一个变量

private Map<String,Object> context;

构造函数

public IOCContainer(String packName){
        try {
            initPart(packName);
            initFactory();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
    public IOCContainer(){
        this("");
    }

初始化含有@Part注解类,类似于@Component

/** 将所有有@Part注解(包括有被@Part注解的注解)
     *  比如说@Part,@BeanFactory,都会被读入
     *  {
     *      @Retention(RetentionPolicy.RUNTIME)
     *      @Target(ElementType.TYPE)
     *      @Part
     *   public @interface FactoryBean {
     * }
     *  }
     * @param packName 路径名在ClassScannerUtil中的函数要使用
     * @throws IOException
     * @author dreamlike_ocean
     */
public void initPart(String packName) throws IOException {
        //做一个bean name 的映射。如果@Part注解中的值不为空则使用value的值做bean name
        //如果为空就用这个 java bean的类名做bean name
        Function<Object,String> keyMapper = (o) -> {
            Class<?> aClass = o.getClass();
            Part part = aClass.getAnnotation(Part.class);
            if (part == null ||part.value().isBlank()) {
                return o.getClass().getTypeName();
            }
            return part.value();
        };
        context = new HashMap<String,Object>();
        //获取所有添加@Part注解的类实例
        List<Object> objectList = ClassScannerUtil.find(packName,(c) -> isContainPart(c,Part.class) );
        //List<Object> objectList = ClassScannerUtil.findByAnnotation(packName, Part.class);
        //先把自己注入进去
        context.put("IOCContainer", this);
        for (Object o : objectList) {
            //利用上面写好的映射函数接口 获取bean name
            String beanName = keyMapper.apply(o);
            //bean name冲突情况,直接报错
            if (context.containsKey(beanName)) {
                String msg = new StringBuilder().append("duplicate bean name: ")
                        .append(beanName)
                        .append("in")
                        .append(o.getClass())
                        .append(" and ")
                        .append(context.get(beanName).getClass()).toString();
                throw new RuntimeException(msg);
            }
            //加入容器
            context.put(beanName, o);
        }

        //帮助垃圾回收,这个复杂度为O(n),理论上objectList = null也能帮助回收
        objectList.clear();
    }

那么完成这个功能一个很核心的函数就是 isContainPart()

 /**
     * 迭代查询是否有target注解在类上,包含注解中的注解(类似于@FactoryBean 中包含 @Part )
     * @param c 要判断是否包含 @param target 注解的类
     * @param target
     * @return 是否包含
     */
    private boolean isContainPart(Class c,Class<? extends Annotation> target){
        Queue<Annotation> queue = new LinkedList<>(Arrays.asList(c.getDeclaredAnnotations()));
        while (!queue.isEmpty()) {
            Annotation poll = queue.poll();
            //判断是否相同
            if (poll.annotationType().isAssignableFrom(target)){
                return true;
            }
            Annotation[] annotations = poll.annotationType().getDeclaredAnnotations();
            for (Annotation annotation : annotations) {
                //如果像是@Documented就会导致无线死循环而且队列大小会不断变大
                //所以需要判断是否重复
                if (!isRepeat(annotation.annotationType())){
                    queue.offer(annotation);
                }
            }
        }
        return false;
    }

    /**
     * 判断一个注解上是不是被注解了自己
     * 举个例子
     * {
     *     @Documented
     *     @Retention(RetentionPolicy.RUNTIME)
     *     @Target(ElementType.ANNOTATION_TYPE)
     * public @interface Documented {
     * }
     * }
     * 此时isRepeat就会返回ture
     * @param aClass 需要被判断的注解的反射
     * @return
     */
    private boolean isRepeat(Class<? extends Annotation> aClass){
        //基础思想就是如果自己注解自己就肯定是要重复的
        Annotation[] declaredAnnotations = aClass.getDeclaredAnnotations();
        for (Annotation declaredAnnotation : declaredAnnotations) {
            if (declaredAnnotation.annotationType().isAssignableFrom(aClass))
                return true;
        }
        return false;
    }

初始化含有@BeanFactory注解类,类似于@Configuration

 @SneakyThrows
    public void initFactory(){
        //使用一个新的map避免一边遍历,一边删除导致的ConcurrentModificationException
        HashMap<String, Object> map = new HashMap<>();
        Collection<Object> beans = getBeans();
        for (Object o : beans) {
            //先把无标注的筛掉
            if (o.getClass().getAnnotation(FactoryBean.class) == null) {
                continue;
            }
            Method[] methods = o.getClass().getDeclaredMethods();
            for (Method method : methods) {
                Produce produce = method.getAnnotation(Produce.class);
                //把无标注的方法筛掉
                if (produce != null) {
                    method.setAccessible(true);
                    //老规矩,提供名称的就使用提供的,没有就使用方法名
                    Object result = invokeMethod(o, method);
                    if (produce.value().isBlank()){
                    map.put(method.getName(), result);
                    }else {
                        map.put(produce.value(),result);
                    }
                }
            }
        }
        //不建议如果出现重名则可能有bean没有添加进容器
        //context.putAll(map);
        Set<Map.Entry<String, Object>> entries = map.entrySet();
        //使用遍历set的方法避免重名,及时发现重名问题
        for (Map.Entry<String, Object> entry : entries) {
            String beanName = entry.getKey();
            Object value = entry.getValue();
            if (context.containsKey(beanName)) {
                String msg = new StringBuilder().append("duplicate bean name: ")
                        .append(beanName)
                        .append("in")
                        .append(value.getClass())
                        .append(" and ")
                        .append(context.get(beanName).getClass()).toString();
                throw new RuntimeException(msg);
            }
            context.put(beanName, value);
        }
    }

    /**
     * 通过从容器中拿出对应类型的参数,类似于setter注入
     * @param o 需要被调用方法的实例
     * @param method 需要被调用的方法
     * @return 方法返回值
     * @throws IllegalAccessException
     * @throws InvocationTargetException
     * @author dreamlike_ocean
     */
    private Object invokeMethod(Object o, Method method) throws IllegalAccessException, InvocationTargetException {
        //获取参数列表
        Class<?>[] parameterTypes = method.getParameterTypes();
        method.setAccessible(true);
        int i = method.getParameterCount();
        //为储存实参做准备
        Object[] param = new Object[i];
        //变量重用,现在它代表当前下标了
        i = 0;
        for (Class<?> parameterType : parameterTypes) {
            List<?> list = getBeanByType(parameterType);
            if (list.size() == 0) {
                throw new RuntimeException("not find " + parameterType + "。method :" + method + "class:" + o.getClass());
            }
            if (list.size() != 1) {
                throw new RuntimeException("too many");
            }
            //暂时存储实参
            param[i++] = list.get(0);
        }
        //调用对应实例的函数
        return method.invoke(o, param);
    }

对外暴露的获取Bean的api

    /**
     * 
     * @param beanName
     * @return 记得判断空指针
     * @author dreamlike_ocean
     */
    public Optional<Object> getBean(String beanName){
        return Optional.ofNullable(context.get(beanName));
    }

    /**
     * 
     * @param beanName
     * @param aclass
     * @param <T> 需要返回的类型,类型强转
     * @exception ClassCastException 类型强转可能导致无法转化的异常          
     * @return @author dreamlike_ocean
     */
    public<T> Optional<T> getBean(String beanName,Class<T> aclass){
        return Optional.ofNullable((T)context.get(beanName));
    }

    /**
     *
     * @param interfaceType
     * @param <T>
     * @return 所有继承这个接口的集合
     * @author dreamlike_ocean
     */
    public<T> List<T> getBeanByInterfaceType(Class<T> interfaceType){
        if (!interfaceType.isInterface()) {
            throw new RuntimeException("it is not an interface type:"+interfaceType.getTypeName());
        }
        return context.values().stream()
                .filter(o -> interfaceType.isAssignableFrom(o.getClass()))
                .map(o -> (T)o)
                .collect(Collectors.toList());
    }

    /**
     * 
     * @param type
     * @param <T>
     * @return 所有这个类型的集合
     * @author dreamlike_ocean
     */
    
    public<T> List<T> getBeanByType(Class<T> type){
        return context.values().stream()
                .filter(o -> type.isAssignableFrom(o.getClass()))
                .map(o -> (T)o)
                .collect(Collectors.toList());
    }

    /**
     * 
     * @return 获取所有值
     * @author dreamlike_ocean 
     */
    public Collection<Object> getBeans(){
        return context.values();
    }

    /**
     * 
     * @return 获取容器
     * @author dreamlike_ocean
     */
    public Map<String,Object> getContext(){
        return context;
    }

IOC构造函数

好了,现在基本加载bean功能已经完善,还记得上面一开始的构造函数吗?

不记得也没关系,我再来写一遍

这样就能理解为什么我要先扫描@Part注解的类了

因为@BeanFactory需要@Part做原料生产新的@Part注解类

public IOCContainer(String packName){
        try {
            initPart(packName);
            initFactory();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
    public IOCContainer(){
        this("");
    }

DI

上面我们获取的都是利用无参的构造函数得到的java bean,这和想的差的有点远,我想要的是一幅画,他却给了我一张白纸。这怎么能行!DI模块上,给他整个活!

为了区别通过类型注入还是名称注入,我写了两个注解用于区分

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface InjectByName {
    String value();
}
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD,ElementType.METHOD})
public @interface InjectByType {

}

首先DI先必须知道到底对哪个容器注入,所以通过构造函数传入一个

 private IOCContainer iocContainer;
    public DI(IOCContainer iocContainer) {
        Objects.requireNonNull(iocContainer);
        this.iocContainer = iocContainer;
    }

先是对字段的按类型注入

/**
     * 
     * @param o 需要被注入的类
     * @author dreamlike_ocean          
     */

    private void InjectFieldByType(Object o){
        try {
            //获取内部所有字段
            Field[] declaredFields = o.getClass().getDeclaredFields();
            for (Field field : declaredFields) {
                //判断当前字段是否有注解标识
                if (field.getAnnotation(InjectByType.class) != null) {
                    //防止因为private而抛出异常
                    field.setAccessible(true);
                    List list = iocContainer.getBeanByType(field.getType());
                    //如果找不到,那么注入失败
                    //这里我选择抛出异常,也可给他赋值为null
                    if(list.size() == 0){
                        throw new RuntimeException("not find "+field.getType());
                    }
                    //多于一个也注入失败,和Spring一致
                    if (list.size()!=1){
                        throw new RuntimeException("too many");
                    }
                    //正常注入
                    field.set(o, list.get(0));
                }
            }
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        }
    }

对字段按名称注入

 /**
     *
     * @param o 需要被注入的类
     * @author dreamlike_ocean
     */
    private void InjectFieldByName(Object o){
        try {
            Field[] declaredFields = o.getClass().getDeclaredFields();
            for (Field field : declaredFields) {
                InjectByName annotation = field.getAnnotation(InjectByName.class);
                if (annotation != null) {
                    field.setAccessible(true);
                    //通过注解中的bean name寻找注入的值
                    //这里optional类没有发挥它自己的函数式优势,因为我觉得在lambda表达式里面写异常处理属实不好看
                    //借用在Stack overflow看的一句话,Oracle用受检异常把lambda玩砸了
                    Object v = iocContainer.getBean(annotation.value()).get();
                    if (v != null) {
                        field.set(o, v);
                    }else{
                        //同样找不到就抛异常
                        throw new RuntimeException("not find "+field.getType());
                    }
                }
            }
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        }
    }

对函数按类型注入

 /**
     * 这个函数必须是setter函数
     * @param o 要被注入的类
     * @author dreamlike_ocean
     */
    private void InjectMethod(Object o){
        Method[] declaredMethods = o.getClass().getDeclaredMethods();
        try {
            for (Method method : declaredMethods) {
                //获取添加注解的函数
                if (method.getAnnotation(InjectByType.class) != null) {
                    //获取参数列表
                    Class<?>[] parameterTypes = method.getParameterTypes();
                    method.setAccessible(true);
                    int i = method.getParameterCount();
                    //为储存实参做准备
                    Object[] param = new Object[i];
                    //变量重用,现在它代表当前下标了
                    i=0;
                    for (Class<?> parameterType : parameterTypes) {
                        List<?> list = iocContainer.getBeanByType(parameterType);
                        if(list.size() == 0){
                            throw new RuntimeException("not find "+parameterType+"。method :"+method+"class:"+o.getClass());
                        }
                        if (list.size()!=1){
                            throw new RuntimeException("too many");
                        }
                        //暂时存储实参
                        param[i++] = list.get(0);
                    }
                    //调用对应实例的函数
                    method.invoke(o, param);
                }
            }
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        } catch (InvocationTargetException e) {
            e.printStackTrace();
        }
    }

你会发现上面都是私有方法,因为我想对外暴露一个简洁的API

  /**
     * 对字段依次进行按类型注入和按名称注入
     * 再对setter方法注入
     * @author dreamlike_ocean
     */
    public void inject(){
        iocContainer.getBeans().forEach(o -> {
            InjectFieldByType(o);
            InjectFieldByName(o);
            InjectMethod(o);
        });
    }

测试

做好了,来让我们测一测

@FactoryBean
class BeanFactory{
    public BeanFactory(){}
    @Produce
    public LoginUser register(A a,B b,C c){
        return new LoginUser();
    }
}


@Part("testA")
class A{
   // @InjectByType
    private B b;
    public A(){

    }

    public B getB() {
        return b;
    }
    @InjectByType
    public void setB(B b) {
        this.b = b;
    }
}
@Part
class B{
    private UUID uuid;
public B(){
    uuid = UUID.randomUUID();
}

    public UUID getUuid() {
        return uuid;
    }
}
@Part
class C{
    @InjectByType
    private A a;
    public C(){
    }

    public A getA() {
        return a;
    }
}

测试方法1

@Test
public void test(){
 IOCContainer container = new IOCContainer();
       DI di = new DI(container);
       di.inject();
       System.out.println(container.getBeanByType(A.class).get(0).getB().getUuid());
       System.out.println(container.getBeanByType(B.class).get(0).getUuid());
}

测试方法2

   @Test
public void test(){
       IOCContainer container = new IOCContainer();
       DI di = new DI(container);
       di.inject();
       container.getContext().forEach((k,v)-> System.out.println(k));


}

好了这就可以了

附上工程结构