Implementing a Simple IoC Container


This is day 6 of my participation in the November writing challenge; event details: 2021's Last Writing Challenge.

Overall idea: scan classes for annotations, then use reflection to instantiate them and store the instances in a Map.
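
As a minimal sketch of that idea, assuming a hypothetical @Component marker in place of @XComponent and a hard-coded list of candidate classes instead of package scanning (List.of needs Java 9+):

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class CoreIdeaDemo {

    // Hypothetical marker annotation standing in for @XComponent
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.TYPE)
    @interface Component {}

    @Component
    public static class Greeter {
        public String greet() { return "hello"; }
    }

    // No annotation, so it is skipped
    public static class Helper {}

    // Instantiates every annotated class and stores it under its simple name with a lower-case first letter
    static Map<String, Object> buildContainer(List<Class<?>> candidates) {
        Map<String, Object> beans = new HashMap<>();
        for (Class<?> clazz : candidates) {
            if (!clazz.isAnnotationPresent(Component.class)) {
                continue;
            }
            try {
                Object instance = clazz.getDeclaredConstructor().newInstance(); // needs a no-arg constructor
                String simpleName = clazz.getSimpleName();
                beans.put(Character.toLowerCase(simpleName.charAt(0)) + simpleName.substring(1), instance);
            } catch (ReflectiveOperationException e) {
                throw new IllegalStateException("Cannot instantiate " + clazz.getName(), e);
            }
        }
        return beans;
    }

    public static void main(String[] args) {
        Map<String, Object> beans = buildContainer(List.of(Greeter.class, Helper.class));
        System.out.println(beans.keySet()); // prints [greeter]
        System.out.println(((Greeter) beans.get("greeter")).greet()); // prints hello
    }
}
```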

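Implementation:

The details are explained in the code comments.

First, define the types that hold a bean's basic metadata: an interface and a generic implementation.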

public interface XBeanDefinition {

    String SINGLETON = "singleton"; // singleton scope
    String PROTOTYPE = "prototype"; // prototype scope

    Class<?> getBeanClass();

    boolean isSingleton();

    boolean isPrototype();

    String getInitMethodName();

}

import java.util.Objects;

public class GenericBeanDefinition implements XBeanDefinition{
    // Name of the init method (optional)
    private String initMethodName;
    // The bean class
    private Class<?> clazz;
    // Singleton scope by default
    private String scope = XBeanDefinition.SINGLETON;

    public void setInitMethodName(String initMethodName) {
        this.initMethodName = initMethodName;
    }

    public void setScope(String scope) {
        this.scope = scope;
    }

    public Class<?> getClazz() {
        return clazz;
    }

    public void setClazz(Class<?> clazz) {
        this.clazz = clazz;
    }

    @Override
    public Class<?> getBeanClass() {
        return this.clazz;
    }

    @Override
    public boolean isSingleton() {
        return Objects.equals(this.scope,XBeanDefinition.SINGLETON);
    }

    @Override
    public boolean isPrototype() {
        return Objects.equals(this.scope,XBeanDefinition.PROTOTYPE);
    }

    @Override
    public String getInitMethodName() {
        return this.initMethodName;
    }
}
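
A variant sketch, assuming an enum (the name Scope is hypothetical) in place of the two String constants. With Strings, a misspelled scope value would make both isSingleton() and isPrototype() return false; with an enum it simply does not compile:

```java
import java.util.Objects;

public class ScopedDefinitionDemo {

    // Hypothetical enum replacing the "singleton"/"prototype" String constants
    enum Scope { SINGLETON, PROTOTYPE }

    // Immutable bean definition: class, scope and an optional init method name
    static final class BeanDefinition {
        private final Class<?> beanClass;
        private final Scope scope;
        private final String initMethodName; // may be null

        BeanDefinition(Class<?> beanClass, Scope scope, String initMethodName) {
            this.beanClass = Objects.requireNonNull(beanClass, "beanClass");
            this.scope = Objects.requireNonNull(scope, "scope");
            this.initMethodName = initMethodName;
        }

        Class<?> getBeanClass() { return beanClass; }
        boolean isSingleton() { return scope == Scope.SINGLETON; }
        boolean isPrototype() { return scope == Scope.PROTOTYPE; }
        String getInitMethodName() { return initMethodName; }
    }

    public static void main(String[] args) {
        BeanDefinition def = new BeanDefinition(StringBuilder.class, Scope.PROTOTYPE, null);
        System.out.println(def.isSingleton()); // prints false
        System.out.println(def.isPrototype()); // prints true
    }
}
```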

Next, a simple bean factory. Its interface defines just one method, getBean.

public interface XBeanFactory {

    Object getBean(String beanName);
}
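
Every caller of getBean has to cast the result, as the tests at the end do. A sketch of a typed convenience overload, similar in spirit to Spring's BeanFactory.getBean(String, Class); the interface name TypedBeanFactory is hypothetical, and a Map stands in for a real factory:

```java
import java.util.HashMap;
import java.util.Map;

public class TypedLookupDemo {

    // Hypothetical interface: the original single getBean method plus a typed default overload
    interface TypedBeanFactory {
        Object getBean(String beanName);

        default <T> T getBean(String beanName, Class<T> requiredType) {
            Object bean = getBean(beanName);
            if (bean == null) {
                throw new IllegalArgumentException("No bean named [" + beanName + "]");
            }
            if (!requiredType.isInstance(bean)) {
                // fail here with a clear message instead of at the caller's cast
                throw new ClassCastException("Bean [" + beanName + "] is a "
                        + bean.getClass().getName() + ", not a " + requiredType.getName());
            }
            return requiredType.cast(bean);
        }
    }

    public static void main(String[] args) {
        Map<String, Object> beans = new HashMap<>();
        beans.put("greeting", "hello");
        TypedBeanFactory factory = beans::get; // a plain map lookup is enough for the demo
        String greeting = factory.getBean("greeting", String.class); // no cast at the call site
        System.out.println(greeting); // prints hello
    }
}
```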

Then an interface dedicated to bean registration, i.e. the operations on bean definitions.

public interface XBeanDefinitionRegistry {
    // Register a bean definition with the factory
    void registryBeanDefinition(String beanName, XBeanDefinition beanDefinition) throws Exception;
    // Look up a bean definition by name
    XBeanDefinition getBeanDefinition(String beanName);
    // Check whether a bean definition with this name exists
    boolean containBeanDefinition(String beanName);
}

The default factory implementation:

import java.lang.reflect.Method;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

public class DefaultBeanFactory implements XBeanFactory,XBeanDefinitionRegistry{

    public Map<String,Object> beanMap = new ConcurrentHashMap<>(); // plays the role of Spring's singleton pool
    private Map<String,XBeanDefinition> beanDefinitionMap = new ConcurrentHashMap<>(); // bean definitions by name

    @Override
    public void registryBeanDefinition(String beanName, XBeanDefinition beanDefinition) throws Exception {
        Objects.requireNonNull(beanName,"beanName must not be null");
        Objects.requireNonNull(beanDefinition,"beanDefinition must not be null");
        // putIfAbsent checks and inserts in one atomic step
        XBeanDefinition existing = beanDefinitionMap.putIfAbsent(beanName,beanDefinition);
        if(existing != null)
        {
            throw new Exception("A bean definition named [" + beanName + "] already exists: " + existing);
        }
    }

    @Override
    public XBeanDefinition getBeanDefinition(String beanName) {
        return beanDefinitionMap.get(beanName);
    }

    @Override
    public boolean containBeanDefinition(String beanName) {
        return beanDefinitionMap.containsKey(beanName);
    }

    /**
     * Returns the bean with the given name.
     * @param beanName bean name
     * @return the bean instance
     */
    @Override
    public Object getBean(String beanName) {
        try {
            return doGetBean(beanName);
        } catch (Exception e) {
            throw new IllegalStateException("Failed to get bean [" + beanName + "]", e);
        }
    }
    private Object doGetBean(String beanName) throws Exception
    {
        Objects.requireNonNull(beanName,"beanName must not be null");
        Object instance = beanMap.get(beanName);
        if(instance != null) // return the cached singleton if there is one
        {
            return instance;
        }

        // Otherwise look up the bean definition
        XBeanDefinition beanDefinition = beanDefinitionMap.get(beanName);
        if(beanDefinition == null)
        {
            throw new Exception("No bean definition named [" + beanName + "]");
        }
        Class<?> clazz = beanDefinition.getBeanClass();
        // Create the instance through the no-arg constructor (Class.newInstance is deprecated)
        instance = clazz.getDeclaredConstructor().newInstance();
        String methodName = beanDefinition.getInitMethodName();
        if(methodName != null && !methodName.isEmpty())
        {
            // Call the no-arg init method reflectively
            Method method = clazz.getMethod(methodName);
            method.invoke(instance);
        }
        if(beanDefinition.isSingleton())
        {
            beanMap.put(beanName,instance); // cache singletons; prototypes are created on every call
        }
        return instance;
    }
}
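
A self-contained sketch of the scope handling in doGetBean: singletons are cached and returned again, prototypes are created fresh on every call. MiniFactory is a hypothetical stand-in for DefaultBeanFactory; it also uses putIfAbsent so that two threads racing to create the same singleton still end up sharing one instance:

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class ScopeDemo {

    // Hypothetical minimal factory: each bean name maps to a class and a singleton flag
    static final class MiniFactory {
        private final Map<String, Class<?>> classes = new ConcurrentHashMap<>();
        private final Map<String, Boolean> singletonFlags = new ConcurrentHashMap<>();
        private final Map<String, Object> singletonPool = new ConcurrentHashMap<>();

        void register(String name, Class<?> clazz, boolean singleton) {
            classes.put(name, clazz);
            singletonFlags.put(name, singleton);
        }

        Object getBean(String name) {
            Object cached = singletonPool.get(name);
            if (cached != null) {
                return cached; // singleton already created
            }
            Class<?> clazz = classes.get(name);
            if (clazz == null) {
                throw new IllegalArgumentException("No bean named [" + name + "]");
            }
            Object instance;
            try {
                instance = clazz.getDeclaredConstructor().newInstance();
            } catch (ReflectiveOperationException e) {
                throw new IllegalStateException("Cannot create bean [" + name + "]", e);
            }
            if (singletonFlags.get(name)) {
                // keep whichever instance was stored first if two threads race here
                Object existing = singletonPool.putIfAbsent(name, instance);
                return existing != null ? existing : instance;
            }
            return instance; // prototype: a new object on every call
        }
    }

    public static void main(String[] args) {
        MiniFactory factory = new MiniFactory();
        factory.register("single", StringBuilder.class, true);
        factory.register("proto", StringBuilder.class, false);
        System.out.println(factory.getBean("single") == factory.getBean("single")); // prints true
        System.out.println(factory.getBean("proto") == factory.getBean("proto"));   // prints false
    }
}
```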

Define a few simple annotations:

import java.lang.annotation.*;

@Target(ElementType.FIELD) // applies to fields
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface XAutowired {
}
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface XBean {
    String value() default ""; // bean name

    String initMethod() default ""; // name of the init method to call
}
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface XComponent {
    String value() default "";
}
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface XConfiguration {
    String value() default "";
}
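
All four annotations use RetentionPolicy.RUNTIME, which is what makes them visible to isAnnotationPresent at run time. The sketch below (annotation names are hypothetical) shows the difference, and also that annotation types themselves carry runtime meta-annotations such as @Retention and @Target, so a filter like "has any declared annotation" matches them too:

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

public class RetentionDemo {

    @Retention(RetentionPolicy.RUNTIME) // stored in the class file and visible to reflection
    @Target(ElementType.TYPE)
    @interface VisibleMarker {}

    @Retention(RetentionPolicy.CLASS)   // stored in the class file but NOT visible to reflection
    @Target(ElementType.TYPE)
    @interface InvisibleMarker {}

    @VisibleMarker
    @InvisibleMarker
    static class Annotated {}

    public static void main(String[] args) {
        System.out.println(Annotated.class.isAnnotationPresent(VisibleMarker.class));   // prints true
        System.out.println(Annotated.class.isAnnotationPresent(InvisibleMarker.class)); // prints false
        // The annotation type itself carries @Retention and @Target
        System.out.println(VisibleMarker.class.getDeclaredAnnotations().length);        // prints 2
    }
}
```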

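Next comes a simple bootstrap class. It reads the package to scan from the configuration file, instantiates the annotated classes via reflection, and stores the instances in the map.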

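import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

public class XSpringApplication {

    // DefaultBeanFactory is the default bean factory
    private static DefaultBeanFactory beanFactory = new DefaultBeanFactory();
    /**
     * Starts the container.
     * @param configFilePath classpath location of the configuration file
     */
    public static void run(String configFilePath)
    {
        // Load the configuration file into memory
        ConfigUtil.loadConfig(configFilePath);
        // Read the package to scan
        String scanPackage = ConfigUtil.contextConfig.getProperty("scan-package");
        if(scanPackage == null || scanPackage.trim().isEmpty())
        {
            throw new IllegalStateException("scan-package is not set in " + configFilePath);
        }
        // Collect the names of the classes under that package
        ScannerUtil.doScanner(scanPackage.trim());

        resolve(ScannerUtil.classNameList);
    }

    // Inspect each class for container annotations and register the beans it declares
    private static void resolve(List<String> classNameList)
    {
        if(classNameList.isEmpty())
        {
            return;
        }
        List<Object> instances = new ArrayList<>();
        try
        {
            // Pass 1: register every annotated class and every bean returned by an @XBean method
            for(String className : classNameList)
            {
                Class<?> clazz = Class.forName(className);
                // Only @XComponent/@XConfiguration classes are managed; checking for "any annotation"
                // would also match the annotation types themselves (they carry @Target/@Retention)
                String beanName;
                if(clazz.isAnnotationPresent(XComponent.class))
                {
                    beanName = clazz.getAnnotation(XComponent.class).value();
                }
                else if(clazz.isAnnotationPresent(XConfiguration.class))
                {
                    beanName = clazz.getAnnotation(XConfiguration.class).value();
                }
                else
                {
                    continue;
                }
                if(beanName.isEmpty()) // no explicit name: simple class name with a lower-case first letter
                {
                    beanName = toLowerFirstCase(clazz.getSimpleName());
                }
                Object instance = clazz.getDeclaredConstructor().newInstance();
                GenericBeanDefinition beanDefinition = new GenericBeanDefinition();
                beanDefinition.setClazz(clazz);
                // Register the definition and put the instance into the singleton pool
                beanFactory.registryBeanDefinition(beanName,beanDefinition);
                beanFactory.beanMap.put(beanName,instance);
                listMethod(clazz.getMethods(),instance); // register beans produced by @XBean methods
                instances.add(instance);
            }
            // Pass 2: inject @XAutowired fields once all beans are registered,
            // so the result does not depend on the order in which classes were scanned
            for(Object instance : instances)
            {
                listFields(instance.getClass().getDeclaredFields(),instance);
            }
        }
        catch (Exception e)
        {
            throw new IllegalStateException("Failed to initialize the container", e);
        }
    }

    public static Object getBean(String beanName)
    {
        return beanFactory.getBean(beanName);
    }

    /**
     * Returns the class name with its first letter in lower case, e.g. User -> user.
     *
     * @param className simple class name
     * @return java.lang.String
     */
    private static String toLowerFirstCase(String className) {
        if(className.isEmpty())
        {
            return className;
        }
        return Character.toLowerCase(className.charAt(0)) + className.substring(1);
    }

    // Registers the bean returned by each @XBean method of a configuration instance
    private static void listMethod(Method[] methods, Object instance) throws Exception
    {
        for(Method method : methods)
        {
            if(!method.isAnnotationPresent(XBean.class))
            {
                continue;
            }
            XBean xBean = method.getAnnotation(XBean.class);
            if(method.getParameterCount() != 0)
            {
                // This simple container cannot resolve method parameters
                throw new Exception("@XBean method " + method.getName() + " must not take parameters");
            }
            // The declared (erased) return type also works for generic return types
            Class<?> returnType = method.getReturnType();
            String beanName = xBean.value().isEmpty()
                    ? toLowerFirstCase(returnType.getSimpleName())
                    : xBean.value();
            GenericBeanDefinition beanDefinition = new GenericBeanDefinition();
            beanDefinition.setClazz(returnType);
            beanDefinition.setInitMethodName(xBean.initMethod());
            beanFactory.registryBeanDefinition(beanName,beanDefinition);
            Object bean = method.invoke(instance); // call the factory method reflectively to get the bean
            if(!xBean.initMethod().isEmpty())
            {
                bean.getClass().getMethod(xBean.initMethod()).invoke(bean); // run the configured init method
            }
            beanFactory.beanMap.put(beanName,bean); // put it straight into the singleton pool
        }
    }

    // Injects beans into the @XAutowired fields of an instance
    private static void listFields(Field[] fields, Object instance) throws Exception
    {
        for(Field field : fields)
        {
            if(!field.isAnnotationPresent(XAutowired.class))
            {
                continue;
            }
            field.setAccessible(true); // also allows injection into private fields
            Class<?> aClass = field.getType(); // class of the dependency to inject
            String beanName = toLowerFirstCase(aClass.getSimpleName());

            // The dependency must be registered in the container
            if(!beanFactory.containBeanDefinition(beanName))
            {
                throw new Exception("No bean named [" + beanName + "] is registered for field " + field.getName());
            }
            // getBean returns the cached instance, or creates one from the bean definition
            Object bean = beanFactory.getBean(beanName);

            field.set(instance,bean); // inject the dependency
        }
    }
}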
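
Injecting a field only works once its dependency is registered. A minimal sketch of doing registration and injection in two separate passes, so that scan order does not matter; @Inject, OrderService and UserRepository are hypothetical names, and beans are looked up by the lower-camel-case simple name of the field type, as listFields does (List.of needs Java 9+):

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class TwoPhaseDemo {

    // Hypothetical stand-in for @XAutowired
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.FIELD)
    @interface Inject {}

    public static class OrderService {
        @Inject
        private UserRepository userRepository; // private, so setAccessible(true) is required

        public UserRepository repository() { return userRepository; }
    }

    public static class UserRepository {}

    static String beanName(Class<?> clazz) {
        String s = clazz.getSimpleName();
        return Character.toLowerCase(s.charAt(0)) + s.substring(1);
    }

    // Pass 1 instantiates everything; pass 2 injects, so scan order does not matter
    static Map<String, Object> start(List<Class<?>> classes) {
        Map<String, Object> beans = new LinkedHashMap<>();
        try {
            for (Class<?> clazz : classes) {
                beans.put(beanName(clazz), clazz.getDeclaredConstructor().newInstance());
            }
            for (Object bean : beans.values()) {
                for (Field field : bean.getClass().getDeclaredFields()) {
                    if (!field.isAnnotationPresent(Inject.class)) {
                        continue;
                    }
                    Object dependency = beans.get(beanName(field.getType()));
                    if (dependency == null) {
                        throw new IllegalStateException("No bean for field " + field.getName());
                    }
                    field.setAccessible(true);
                    field.set(bean, dependency);
                }
            }
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
        return beans;
    }

    public static void main(String[] args) {
        // OrderService is listed before the repository it depends on
        Map<String, Object> beans = start(List.of(OrderService.class, UserRepository.class));
        OrderService service = (OrderService) beans.get("orderService");
        System.out.println(service.repository() == beans.get("userRepository")); // prints true
    }
}
```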

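There are also two utility classes: ConfigUtil reads the configuration file, and ScannerUtil collects the class names under the package to scan.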

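import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.util.Properties;

public class ConfigUtil {
    public static Properties contextConfig = new Properties();

    /**
     * Loads the configuration file into memory.
     * @param configPath classpath location of the configuration file
     */
    public static void loadConfig(String configPath)
    {
        // try-with-resources closes the stream automatically
        try (InputStream inputStream = ConfigUtil.class.getClassLoader().getResourceAsStream(configPath)) {
            if(inputStream == null)
            {
                throw new IllegalArgumentException("Configuration file not found on the classpath: " + configPath);
            }
            contextConfig.load(inputStream);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read " + configPath, e);
        }
    }
}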
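import java.io.File;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;

public class ScannerUtil {

    public static List<String> classNameList = new ArrayList<>();
    /**
     * Recursively collects the fully qualified names of the classes under a package.
     * Only classes in directories on the classpath are found, not classes inside JAR files.
     * @param scanPackage package to scan, e.g. com.swq
     */
    public static void doScanner(String scanPackage)
    {
        // Turn the package name into a classpath path: com.swq -> /com/swq
        URL resourcePath = ScannerUtil.class.getResource("/" + scanPackage.replace('.', '/'));

        if(null == resourcePath)
        {
            return;
        }
        File classPath;
        try {
            classPath = new File(resourcePath.toURI()); // decodes escapes such as %20 in the path
        } catch (URISyntaxException e) {
            throw new IllegalArgumentException("Invalid resource URL: " + resourcePath, e);
        }
        File[] files = classPath.listFiles();
        if(files == null) // not a directory, or it could not be read
        {
            return;
        }
        for(File file : files)
        {
            if(file.isDirectory()) // recurse into sub-packages
            {
                doScanner(scanPackage + "." + file.getName());
            }
            else if(file.getName().endsWith(".class")) // only compiled classes; annotations are checked later
            {
                String fileName = file.getName();
                // Strip only the trailing ".class" so package names such as "classloader" stay intact
                String className = scanPackage + "." + fileName.substring(0, fileName.length() - ".class".length());
                classNameList.add(className);
            }
        }
    }
}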
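
For comparison, a sketch of the same class-name derivation using java.nio.file.Files.walk (Java 9+ for List.of in the test data; Files.walk itself is Java 8). Like doScanner it only handles directories on disk, not JAR files, and it strips just the trailing ".class" suffix so a sub-package such as classloader is left intact. The package name com.swq and the temporary files exist only for the demo:

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class ClassNameScanDemo {

    // Turns every .class file under packageDir into a fully qualified class name
    static List<String> scan(Path packageDir, String basePackage) {
        try (Stream<Path> paths = Files.walk(packageDir)) {
            return paths
                    .filter(Files::isRegularFile)
                    .map(p -> packageDir.relativize(p).toString())
                    .filter(rel -> rel.endsWith(".class"))
                    // strip only the trailing ".class", then turn path separators into dots
                    .map(rel -> rel.substring(0, rel.length() - ".class".length()))
                    .map(rel -> basePackage + "." + rel.replace(packageDir.getFileSystem().getSeparator(), "."))
                    .sorted()
                    .collect(Collectors.toList());
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) throws IOException {
        // Build a throwaway directory tree that mimics compiled classes
        Path root = Files.createTempDirectory("scan-demo");
        Files.createDirectories(root.resolve("classloader"));
        Files.createFile(root.resolve("User.class"));
        Files.createFile(root.resolve("classloader").resolve("Loader.class"));
        Files.createFile(root.resolve("notes.txt"));
        System.out.println(scan(root, "com.swq"));
        // prints [com.swq.User, com.swq.classloader.Loader]
    }
}
```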

Testing:

@XComponent(value = "user")
public class User {

    private String uid;

    private String userName;


    public String getUid() {
        return uid;
    }

    public void setUid(String uid) {
        this.uid = uid;
    }

    @Override
    public String toString() {
        return "User{" +
                "uid='" + uid + '\'' +
                ", userName='" + userName + '\'' +
                '}';
    }

    public String getUserName() {
        return userName;
    }

    public void setUserName(String userName) {
        this.userName = userName;
    }
}

Testing @XComponent:

public class Test {
    public static void main(String[] args) {
        XSpringApplication.run("application.properties");
        User user = (User)XSpringApplication.getBean("user");
        System.out.println(user);

//        ConfigurationTest test = (ConfigurationTest)XSpringApplication.getBean("configurationTest");
//        System.out.println(test.user);
    }
}


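Testing @XBean: first comment out the @XComponent annotation on the User class, otherwise the bean name "user" would be registered twice and registration would fail.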

@XConfiguration
public class ConfigurationTest {

//    @XAutowired
//    public User user;

    @XBean(value = "user")
    public User getUser()
    {
        User user = new User();
        user.setUid("123456");
        user.setUserName("swq");
        return user;
    }

    @XBean
    public Student getStudent()
    {
        return new Student();
    }
}

Testing @XAutowired:

@XConfiguration
public class ConfigurationTest {

    @XAutowired
    public User user;

    @XBean(value = "user")
    public User getUser()
    {
        User user = new User();
        user.setUid("123456");
        user.setUserName("swq");
        return user;
    }

    @XBean
    public Student getStudent()
    {
        return new Student();
    }
}
public class Test {
    public static void main(String[] args) {
        XSpringApplication.run("application.properties");
        User user = (User)XSpringApplication.getBean("user");
        System.out.println(user);
        ConfigurationTest test = (ConfigurationTest)XSpringApplication.getBean("configurationTest");
        System.out.println(test.user);
    }
}

The configuration file (application.properties) only sets the package to scan:

# package to scan
scan-package=com.swq
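
In .properties files only lines starting with # or ! are comments; a trailing // is kept as part of the value, which would break the package lookup. A quick check with java.util.Properties:

```java
import java.io.IOException;
import java.io.StringReader;
import java.util.Properties;

public class PropertiesCommentDemo {

    static Properties parse(String text) {
        Properties props = new Properties();
        try {
            props.load(new StringReader(text));
        } catch (IOException e) {
            // load declares IOException; a StringReader will not raise it in practice
            throw new IllegalStateException(e);
        }
        return props;
    }

    public static void main(String[] args) {
        // "//" is not a comment marker in .properties files; it stays in the value
        Properties inline = parse("scan-package=com.swq  // package to scan");
        System.out.println("[" + inline.getProperty("scan-package") + "]");
        // prints [com.swq  // package to scan]

        // A line starting with '#' is a comment
        Properties proper = parse("# package to scan\nscan-package=com.swq");
        System.out.println("[" + proper.getProperty("scan-package") + "]");
        // prints [com.swq]
    }
}
```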

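Summary:

A simple IoC container relies mainly on reflection. Spring's own IoC implementation handles a great many more details, so understanding it properly still takes plenty of time reading its source code.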