Spring Framework: Hand-Writing an IOC Container

Spring Source Code: IOC Demo

Preface

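I've been studying the Spring source code recently. To understand how Spring IOC works, I followed a tutorial and implemented a simple IOC container myself; the code is adapted from the Bilibili uploader "楠哥教你学java". Below are the code and some of my own understanding.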

Implementation approach
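① Create a Maven project. Note that this project does not include any Spring jars, so annotations such as @Autowired and @Component do not exist and we have to define them ourselves. The only dependency is Lombok: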

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>untitled</artifactId>
    <version>1.0-SNAPSHOT</version>

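    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <!-- Assumed Java level; older maven-compiler-plugin defaults reject the diamond operator used below -->
        <maven.compiler.source>11</maven.compiler.source>
        <maven.compiler.target>11</maven.compiler.target>
    </properties>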

    <dependencies>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.18.12</version>
            <scope>provided</scope>
        </dependency>
    </dependencies>

</project>

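② Define the custom annotations @Autowired, @Value, @Component and @Qualifier. In this demo they serve purely as markers: when we later scan classes and fields, we act according to which annotations we find. For example: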

@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Autowired {
}
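
These markers only work because of @Retention(RetentionPolicy.RUNTIME): without it, reflection cannot see the annotation during scanning. A small sketch of the difference between RUNTIME and CLASS retention (RuntimeMarker, ClassOnlyMarker and RetentionDemo are names made up for this demo):

```java
import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;

// Hypothetical marker that reflection can see at runtime
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
@interface RuntimeMarker {}

// Hypothetical marker kept in the .class file but invisible to reflection
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.CLASS)
@interface ClassOnlyMarker {}

class RetentionDemo {
    @RuntimeMarker
    private String a;

    @ClassOnlyMarker
    private String b;

    // Returns true if the named field of RetentionDemo carries the annotation at runtime
    static boolean hasMarker(String fieldName, Class<? extends Annotation> type) {
        try {
            Field field = RetentionDemo.class.getDeclaredField(fieldName);
            return field.getAnnotation(type) != null;
        } catch (NoSuchFieldException e) {
            throw new IllegalArgumentException("No field " + fieldName, e);
        }
    }

    public static void main(String[] args) {
        System.out.println(hasMarker("a", RuntimeMarker.class));   // true
        System.out.println(hasMarker("b", ClassOnlyMarker.class)); // false
    }
}
```

With CLASS retention the annotation is still written to the .class file, but getAnnotation returns null, so a reflection-based scanner like the one in this article would silently skip it.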

③ Define the classes Account and Order and put the annotations on them, for example:

@Data
@Component
public class Order {
    @Value("xxx123")
    private String orderId;
    @Value("1000.5")
    private Float price;
}
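
The @Value strings such as "1000.5" have to be converted to the field's type before the setter is called. The context class later in this article does this with a switch on the type name, which leaves the value null for any type it doesn't list. A sketch of an alternative, assuming a lookup table from type to parser (ValueConverter and the set of supported types are my own choices, not from the tutorial):

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

class ValueConverter {
    // Maps a field type to a function that parses the raw @Value text into that type
    private static final Map<Class<?>, Function<String, Object>> CONVERTERS = new HashMap<>();

    static {
        CONVERTERS.put(String.class, s -> s);
        CONVERTERS.put(Integer.class, Integer::valueOf);
        CONVERTERS.put(int.class, Integer::valueOf);
        CONVERTERS.put(Long.class, Long::valueOf);
        CONVERTERS.put(long.class, Long::valueOf);
        CONVERTERS.put(Float.class, Float::valueOf);
        CONVERTERS.put(float.class, Float::valueOf);
        CONVERTERS.put(Double.class, Double::valueOf);
        CONVERTERS.put(double.class, Double::valueOf);
        CONVERTERS.put(Boolean.class, Boolean::valueOf);
        CONVERTERS.put(boolean.class, Boolean::valueOf);
    }

    // Converts the raw annotation text, failing loudly instead of silently producing null
    static Object convert(String raw, Class<?> targetType) {
        Function<String, Object> converter = CONVERTERS.get(targetType);
        if (converter == null) {
            throw new IllegalArgumentException("Unsupported @Value type: " + targetType.getName());
        }
        return converter.apply(raw);
    }

    public static void main(String[] args) {
        System.out.println(convert("1000.5", Float.class)); // 1000.5
        System.out.println(convert("22", Integer.class));   // 22
    }
}
```

Registering the primitive types too means a field declared as float rather than Float would also work, since Method.invoke unboxes the wrapper automatically.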

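④ Next comes the most important part: scanning and creating the beans. The code is fairly long, so I won't go through it piece by piece; here is the overall idea: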

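  • Scan every .class file under the given package, load each class's Class object via reflection, and read the annotations on it. If a class carries @Component, record a BeanDefinition (bean name plus Class object) in a Set.
  • Walk that Set: instantiate each class through its no-arg constructor, assign every @Value field by calling its setter (this is why Lombok is included: @Data generates the setters), and store the object in a Map keyed by bean name.
  • Walk the Set a second time: for each field marked @Autowired, look up the bean named by its @Qualifier in the Map and inject it through the setter.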
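
The injection step assigns fields by deriving the setter name from the field name and calling that setter via reflection. A minimal sketch of this step on its own (SetterInjector and its Demo bean are hypothetical; the naming rule assumes an ordinary non-boolean field, because Lombok names setters differently for boolean fields starting with "is"):

```java
import java.lang.reflect.Field;
import java.lang.reflect.Method;

class SetterInjector {

    // "orderId" -> "setOrderId" (JavaBean naming, as Lombok's @Data generates for non-boolean fields)
    static String setterName(String fieldName) {
        return "set" + Character.toUpperCase(fieldName.charAt(0)) + fieldName.substring(1);
    }

    // Finds the public setter whose parameter type equals the field's declared type and invokes it
    static void inject(Object target, String fieldName, Object value) {
        Class<?> clazz = target.getClass();
        try {
            Field field = clazz.getDeclaredField(fieldName);
            Method setter = clazz.getMethod(setterName(fieldName), field.getType());
            setter.invoke(target, value);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("Cannot inject " + fieldName, e);
        }
    }

    // Hypothetical bean with a hand-written setter, standing in for a Lombok @Data class
    public static class Demo {
        private String orderId;
        public void setOrderId(String orderId) { this.orderId = orderId; }
        public String getOrderId() { return orderId; }
    }

    public static void main(String[] args) {
        Demo demo = new Demo();
        inject(demo, "orderId", "xxx123");
        System.out.println(demo.getOrderId()); // xxx123
    }
}
```

Setters are not strictly required: calling field.setAccessible(true) and then field.set(target, value) writes a private field directly. Going through setters is simply the approach this demo takes.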

Here is the complete code:

Bean classes

package com.entity;
import annonation.Autowired;
import annonation.Component;
import annonation.Qualifier;
import annonation.Value;
import lombok.Data;

@Data
@Component
public class Account {
    @Value("1")
    private Integer id;
    @Value("张三")
    private String name;
    @Value("22")
    private Integer age;
    @Autowired
    @Qualifier("order") // Order is registered under its default bean name "order"
    private Order order;
}


package com.entity;
import annonation.Component;
import annonation.Value;
import lombok.Data;



@Data
@Component
public class Order {
    @Value("xxx123")
    private String orderId;
    @Value("1000.5")
    private Float price;
}

Annotations

package annonation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Autowired {
}



package annonation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface Component {
    String value() default "";
}



package annonation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Qualifier {
    String value() default "";
}



package annonation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Value {
    String value() default "";
}
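
When @Component's value() is left at its default empty string, the context class in this article derives the bean name by lower-casing the first letter of the simple class name. A sketch comparing that rule with java.beans.Introspector.decapitalize, which, as far as I know, is what Spring's AnnotationBeanNameGenerator uses for default names (BeanNames and its two nested classes are made up for the comparison):

```java
import java.beans.Introspector;

class BeanNames {

    // The rule used in this article: lower-case the first letter of the simple class name
    static String simpleRule(Class<?> clazz) {
        String name = clazz.getSimpleName();
        return name.substring(0, 1).toLowerCase() + name.substring(1);
    }

    // JavaBeans rule: names whose first two letters are upper case are left unchanged
    static String introspectorRule(Class<?> clazz) {
        return Introspector.decapitalize(clazz.getSimpleName());
    }

    // Hypothetical classes used only to compare the two rules
    static class Order {}
    static class URLParser {}

    public static void main(String[] args) {
        System.out.println(simpleRule(Order.class) + " " + introspectorRule(Order.class));
        System.out.println(simpleRule(URLParser.class) + " " + introspectorRule(URLParser.class));
    }
}
```

Both rules give "order" for Order; they differ only for names starting with two capitals, where the simple rule yields "uRLParser" and Introspector keeps "URLParser".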

The key utility class (finds all classes in a package)

import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.net.JarURLConnection;
import java.net.URL;
import java.net.URLDecoder;
import java.util.Enumeration;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;

public class MyTool {

    public static Set<Class<?>> getClasses(String pack){
        Set<Class<?>> classes=new LinkedHashSet<>();
        boolean recursive=true;
        String packageName=pack;
        String packageDirName=packageName.replace('.','/');
        Enumeration<URL> dirs;
        try {
            dirs=Thread.currentThread().getContextClassLoader().getResources(packageDirName);
            while(dirs.hasMoreElements()){
                URL url=dirs.nextElement();
                String protocol=url.getProtocol();
                if("file".equals(protocol)){
                    String filePath= URLDecoder.decode(url.getFile(),"UTF-8");
                    findClassesInPackageByFile(packageName,filePath,recursive,classes);
                }else if("jar".equals(protocol)){
                    JarFile jar;
                    try {
                        jar=((JarURLConnection)url.openConnection()).getJarFile();
                        Enumeration<JarEntry> entries=jar.entries();
                        findClassesInPackageByJar(packageName,entries,packageDirName,recursive,classes);
                    }catch (IOException e){
                        e.printStackTrace();
                    }
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return classes;
    }

    public static void findClassesInPackageByFile(String packageName, String packagePath, final boolean recursive, Set<Class<?>> classes){
        File dir=new File(packagePath);
        // The package path must exist and must be a directory, not a file
        if(!dir.exists()||!dir.isDirectory()){
            return;
        }
        // Filter the directory contents, keeping only .class files and (if recursive) sub-directories
        File[] dirfiles=dir.listFiles(new FileFilter() {
            @Override
            public boolean accept(File file) {
                return (recursive&&file.isDirectory())||(file.getName().endsWith(".class"));
            }
        });
        // listFiles returns null if the directory cannot be read
        if(dirfiles==null){
            return;
        }

        for (File file : dirfiles) {
            // A sub-directory is a sub-package: keep searching inside it
            if(file.isDirectory()){
                findClassesInPackageByFile(packageName+"."+file.getName(),file.getAbsolutePath(),recursive,classes);
            }
            // A .class file: strip ".class", join the rest with the package name to form the
            // fully qualified class name, then load the Class object and add it to the set
            else{
                String className=file.getName().substring(0,file.getName().length()-6);
                try {
                    classes.add(Thread.currentThread().getContextClassLoader().loadClass(packageName+'.'+className));
                } catch (ClassNotFoundException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    public static void findClassesInPackageByJar(String packageName, Enumeration<JarEntry> entries,String packageDirName,final  boolean recursive,Set<Class<?>>classes){
        while(entries.hasMoreElements()){
            JarEntry entry=entries.nextElement();
            String name=entry.getName();

            if(name.charAt(0)=='/'){
                name=name.substring(1);
            }
            if(name.startsWith(packageDirName)){
                int idx=name.lastIndexOf('/');
                if(idx!=-1){
                    packageName=name.substring(0,idx).replace('/','.');
                }
                if((idx!=-1)||recursive){
                    if(name.endsWith(".class")&&!entry.isDirectory()){
                        // Take only the part after the last '/', since packageName already holds the (sub-)package
                        String className=name.substring(idx+1,name.length()-6);
                        try {
                            classes.add(Class.forName(packageName+'.'+className));
                        } catch (ClassNotFoundException e) {
                            e.printStackTrace();
                        }
                    }
                }
            }
        }

    }
}
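
For classes packaged in a jar, each entry name is a slash-separated path that the scanner has to turn into a fully qualified class name. A standalone sketch of that conversion plus the package-membership check (JarEntryNames is a hypothetical helper; comparing against packageDirName + "/" is my own assumption, so that a sibling package such as com/entityX is not matched):

```java
class JarEntryNames {

    // "com/entity/sub/Foo.class" -> "com.entity.sub.Foo"; returns null for non-class entries
    static String toClassName(String entryName) {
        String name = entryName.startsWith("/") ? entryName.substring(1) : entryName;
        if (!name.endsWith(".class")) {
            return null;
        }
        return name.substring(0, name.length() - ".class".length()).replace('/', '.');
    }

    // True if the entry is in the package itself or, when recursive, in one of its sub-packages
    static boolean inPackage(String entryName, String packageDirName, boolean recursive) {
        String name = entryName.startsWith("/") ? entryName.substring(1) : entryName;
        if (!name.startsWith(packageDirName + "/")) {
            return false;
        }
        String rest = name.substring(packageDirName.length() + 1);
        return recursive || rest.indexOf('/') == -1;
    }

    public static void main(String[] args) {
        System.out.println(toClassName("com/entity/sub/Foo.class"));
        System.out.println(inPackage("com/entity/sub/Foo.class", "com/entity", false));
    }
}
```

Nested classes come out as names like com.entity.Foo$Bar, which Class.forName accepts as-is.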

Bean definition class

import lombok.AllArgsConstructor;
import lombok.Data;

@AllArgsConstructor
@Data
public class BeanDefinition {
    private String beanName;
    private Class beanClass;
}

Bootstrap class (this class starts the IOC container and injects the beans)

import annonation.Autowired;
import annonation.Component;
import annonation.Qualifier;
import annonation.Value;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.*;

public class MyAnnotationConfigApplicationContext {
    // Holds the beans keyed by bean name; each class gets exactly one (singleton) instance
    private Map<String, Object> ioc=new HashMap<>();
    private List<String> beanNames=new ArrayList<>();

    public MyAnnotationConfigApplicationContext(String pack) {
        Set<BeanDefinition> beanDefinitions=findBeanDefinitions(pack);
        createObject(beanDefinitions);
        autowireObject(beanDefinitions);
    }

    public Set<BeanDefinition> findBeanDefinitions(String pack){
        Set<Class<?>> classes=MyTool.getClasses(pack);
        Iterator<Class<?>> iterator=classes.iterator();
        Set<BeanDefinition> beanDefinitions=new HashSet<>();
        while(iterator.hasNext()){
            Class<?> clazz=iterator.next();
            Component component=clazz.getAnnotation(Component.class);
            if(component!=null){
                String beanName=component.value();
                if("".equals(beanName)){
                    // Default bean name: simple class name with a lower-case first letter (Order -> order)
                    String className=clazz.getSimpleName();
                    beanName=className.substring(0,1).toLowerCase()+className.substring(1);
                }
                beanDefinitions.add(new BeanDefinition(beanName,clazz));
                beanNames.add(beanName);
            }
        }
        return beanDefinitions;
    }

    public void createObject(Set<BeanDefinition> beanDefinitions){
        Iterator<BeanDefinition> iterator=beanDefinitions.iterator();
        while(iterator.hasNext()){
            BeanDefinition beanDefinition=iterator.next();
            Class clazz=beanDefinition.getBeanClass();
            String beanName=beanDefinition.getBeanName();
            try{
                Object object=clazz.getConstructor().newInstance();
                Field[] declaredFields=clazz.getDeclaredFields();
                for (Field declaredField : declaredFields) {
                    Value valueAnnotation=declaredField.getAnnotation(Value.class);
                    if(valueAnnotation!=null){
                        String value=valueAnnotation.value();
                        String fieldName=declaredField.getName();
                        String methodName="set"+fieldName.substring(0,1).toUpperCase()+fieldName.substring(1);
                        Method method=clazz.getMethod(methodName,declaredField.getType());
                        Object val=null;
                        switch (declaredField.getType().getName()){
                            case"java.lang.Integer":
                                val=Integer.parseInt(value);
                                break;
                            case"java.lang.String":
                                val=value;
                                break;
                            case "java.lang.Float":
                                val=Float.parseFloat(value);
                                break;
                        }
                        method.invoke(object,val);
                    }
                }
                ioc.put(beanName,object);
            } catch (ReflectiveOperationException e) {
                // Covers InstantiationException, InvocationTargetException,
                // NoSuchMethodException and IllegalAccessException
                e.printStackTrace();
            }
        }
    }

    public void autowireObject(Set<BeanDefinition> beanDefinitions){
        Iterator<BeanDefinition> iterator=beanDefinitions.iterator();
        while(iterator.hasNext()){
            BeanDefinition beanDefinition=iterator.next();
            Class clazz=beanDefinition.getBeanClass();
            Field[] declaredFields=clazz.getDeclaredFields();
            for (Field declaredField : declaredFields) {
                Autowired autowired=declaredField.getAnnotation(Autowired.class);
                if(autowired!=null){
                    // Only injection by name via @Qualifier is supported; an @Autowired field without it stays null
                    Qualifier qualifier=declaredField.getAnnotation(Qualifier.class);
                    if(qualifier!=null){
                        try{
                            String beanName=qualifier.value();
                            Object bean=getBean(beanName);
                            String fieldName=declaredField.getName();
                            String methodName = "set"+fieldName.substring(0, 1).toUpperCase()+fieldName.substring(1);
                            Method method = clazz.getMethod(methodName, declaredField.getType());
                            Object object = getBean(beanDefinition.getBeanName());
                            method.invoke(object, bean);
                        } catch (ReflectiveOperationException e) {
                            e.printStackTrace();
                        }
                    }
                }
            }
        }
    }


    public Object getBean(String beanName){
        return ioc.get(beanName);
    }

    public String[] getBeanDefinitionNames(){
        return beanNames.toArray(new String[0]);
    }

    public Integer getBeanDefinitionCount(){
        return beanNames.size();
    }



}
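
The autowireObject method only handles @Autowired fields that also carry @Qualifier. Spring's own @Autowired resolves candidates by type, so here is a minimal sketch of by-type lookup over a bean map like ioc, assuming the rule "exactly one assignable bean, otherwise fail" (ByTypeResolver is hypothetical and far simpler than Spring's real resolution):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

class ByTypeResolver {

    // Returns the single bean whose object is an instance of fieldType; fails on zero or several matches
    static Object resolve(Map<String, Object> ioc, Class<?> fieldType) {
        List<String> candidates = new ArrayList<>();
        for (Map.Entry<String, Object> entry : ioc.entrySet()) {
            if (fieldType.isInstance(entry.getValue())) {
                candidates.add(entry.getKey());
            }
        }
        if (candidates.size() != 1) {
            throw new IllegalStateException("Expected exactly one bean of type "
                    + fieldType.getName() + " but found " + candidates);
        }
        return ioc.get(candidates.get(0));
    }
}
```

Inside autowireObject this could act as the fallback when qualifier is null, by calling resolve(ioc, declaredField.getType()) and passing the result to the setter.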

Test class

public class Test {
    public static void main(String[] args) {
        MyAnnotationConfigApplicationContext applicationContext=new MyAnnotationConfigApplicationContext("com.entity");
        System.out.println(applicationContext.getBeanDefinitionCount());
        String[] beanDefinitionNames = applicationContext.getBeanDefinitionNames();
        for (String beanDefinitionName : beanDefinitionNames) {
            System.out.println(beanDefinitionName);
            System.out.println(applicationContext.getBean(beanDefinitionName));
        }
    }
}

Output:

(Screenshot of the console output)

Final words:

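I've only recently started writing blog posts. They are mainly study notes sharing what I've learned, so they may not be very clear. If anything is hard to follow, or if you think something in my code or understanding is wrong, please point it out in the comments; I'll reply to each one when I have time. Thanks, everyone. I'll keep working to improve!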