手撸 Spring 简易版 IOC

29 阅读4分钟

手撸 Spring 简易版 IOC

一、核心目标

基于 Spring 6.x 的核心思想,实现最简版 IOC 容器,包含:

  1. 注解扫描(@MyComponent、@MyConfiguration);
  2. BeanDefinition 封装;
  3. 单例 Bean 实例化;
  4. @MyAutowired 依赖注入;
  5. 兼容 Spring 6.x 的核心流程(简化版)。

二、完整实现代码

步骤 1:定义核心注解(模拟 Spring 6.x 注解)

import java.lang.annotation.*;

// 对应 Spring 的 @Component
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface MyComponent {
    String value() default "";
}

// 对应 Spring 的 @Autowired
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface MyAutowired {
}

// 对应 Spring 的 @Configuration + @ComponentScan
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface MyConfiguration {
    String scanPackage() default "";
}

步骤 2:定义 BeanDefinition(核心元信息类)

/** * 模拟 Spring 的 BeanDefinition:存储 Bean 的元信息 */
public class MyBeanDefinition {
    // Bean 对应的类
    private Class<?> beanClass;
    // Bean 作用域(默认单例)
    private String scope = "singleton";
    // 是否懒加载
    private boolean lazyInit = false;

    public MyBeanDefinition(Class<?> beanClass) {
        this.beanClass = beanClass;
    }

    // getter/setter
    public Class<?> getBeanClass() { return beanClass; }
    public void setBeanClass(Class<?> beanClass) { this.beanClass = beanClass; }
    public String getScope() { return scope; }
    public void setScope(String scope) { this.scope = scope; }
    public boolean isLazyInit() { return lazyInit; }
    public void setLazyInit(boolean lazyInit) { this.lazyInit = lazyInit; }
}

步骤 3:实现核心 IOC 容器(MyApplicationContext)

import java.io.File;
import java.lang.reflect.Field;
import java.net.URL;
import java.util.HashMap;
import java.util.Map;

/** * 简易版 Spring 6.x IOC 容器 * 核心流程:扫描→注册 BeanDefinition→实例化 Bean→依赖注入 */
public class MyApplicationContext {
    // 存储 BeanDefinition(对应 Spring 的 beanDefinitionMap)
    private Map<String, MyBeanDefinition> beanDefinitionMap = new HashMap<>();
    // 单例池(对应 Spring 的 singletonObjects)
    private Map<String, Object> singletonObjects = new HashMap<>();
    // 配置类
    private Class<?> configClass;

    // 构造方法:初始化容器
    public MyApplicationContext(Class<?> configClass) {
        this.configClass = configClass;
        // 1. 扫描包(优化:支持中文路径+递归),注册 BeanDefinition
        scanAndRegisterBeanDefinitions();
        // 2. 实例化所有非懒加载的单例 Bean
        instantiateSingletons();
    }

    // 核心优化:扫描包(支持中文路径 + 递归扫描子包)
    private void scanAndRegisterBeanDefinitions() {
        // 获取扫描包路径
        MyConfiguration configAnnotation = configClass.getAnnotation(MyConfiguration.class);
        String scanPackage = configAnnotation.scanPackage();
        if (scanPackage.isEmpty()) {
            throw new RuntimeException("请配置 scanPackage 路径!");
        }

        // 转换包路径为文件路径(com.example.中文包 -> com/example/中文包)
        String packagePath = scanPackage.replace(".", "/");
        ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
        URL resource = classLoader.getResource(packagePath);
        if (resource == null) {
            throw new RuntimeException("扫描包不存在:" + scanPackage);
        }

        try {
            // 核心优化1:解码中文路径(解决中文包名乱码问题)
            String decodedPath = URLDecoder.decode(resource.getFile(), StandardCharsets.UTF_8.name());
            File packageDir = new File(decodedPath);

            // 核心优化2:递归扫描目录(包含所有子包)
            recursiveScan(packageDir, scanPackage);
        } catch (UnsupportedEncodingException e) {
            throw new RuntimeException("中文路径解码失败:" + e.getMessage(), e);
        }
    }

    // 递归扫描目录的核心方法
    private void recursiveScan(File dir, String basePackage) {
        // 1. 校验目录是否存在
        if (!dir.exists() || !dir.isDirectory()) {
            return;
        }

        // 2. 遍历目录下的所有文件/子目录
        File[] files = dir.listFiles();
        if (files == null) {
            return;
        }

        for (File file : files) {
            if (file.isDirectory()) {
                // 若是子目录:递归扫描(更新基础包名,如 com.example -> com.example.子包)
                String subPackage = basePackage + "." + file.getName();
                recursiveScan(file, subPackage);
            } else if (file.getName().endsWith(".class")) {
                // 若是 .class 文件:处理并注册 BeanDefinition
                processClassFile(file, basePackage);
            }
        }
    }

    // 处理 .class 文件,生成并注册 BeanDefinition
    private void processClassFile(File classFile, String basePackage) {
        try {
            // 1. 拼接类全限定名(如 com.example.中文包.UserService)
            String className = getClassName(classFile, basePackage);
            // 2. 加载类
            Class<?> clazz = Class.forName(className);
            // 3. 判断是否有 @MyComponent 注解
            if (clazz.isAnnotationPresent(MyComponent.class)) {
                MyComponent component = clazz.getAnnotation(MyComponent.class);
                // 生成 Bean 名称(优先注解值,否则类名首字母小写)
                String beanName = component.value().isEmpty()
                        ? toLowerCaseFirstChar(clazz.getSimpleName())
                        : component.value();
                // 创建并注册 BeanDefinition
                MyBeanDefinition beanDefinition = new MyBeanDefinition(clazz);
                beanDefinitionMap.put(beanName, beanDefinition);
            }
        } catch (ClassNotFoundException e) {
            throw new RuntimeException("加载类失败:" + e.getMessage(), e);
        }
    }

    // 工具方法:从 .class 文件路径拼接类全限定名(兼容中文路径)
    private String getClassName(File classFile, String basePackage) {
        // 获取文件绝对路径(已解码为 UTF-8)
        String filePath = classFile.getAbsolutePath();
        // 截取包路径部分(如 D:/project/classes/com/example/中文包/UserService.class -> com/example/中文包/UserService)
        String packagePath = basePackage.replace(".", File.separator);
        int startIndex = filePath.indexOf(packagePath);
        int endIndex = filePath.lastIndexOf(".class");

        // 转换为类全限定名(com/example/中文包/UserService -> com.example.中文包.UserService)
        String className = filePath.substring(startIndex, endIndex)
                .replace(File.separator, ".");
        return className;
    }

    // 实例化所有非懒加载的单例 Bean(原有逻辑不变)
    private void instantiateSingletons() {
        for (String beanName : beanDefinitionMap.keySet()) {
            MyBeanDefinition beanDefinition = beanDefinitionMap.get(beanName);
            // 只处理单例 + 非懒加载的 Bean
            if ("singleton".equals(beanDefinition.getScope()) && !beanDefinition.isLazyInit()) {
                // 创建 Bean 实例
                Object bean = createBean(beanName, beanDefinition);
                // 存入单例池
                singletonObjects.put(beanName, bean);
            }
        }
    }

    // 创建 Bean 实例 + 依赖注入(原有逻辑不变)
    private Object createBean(String beanName, MyBeanDefinition beanDefinition) {
        Class<?> beanClass = beanDefinition.getBeanClass();
        try {
            // 1. 通过无参构造创建实例
            Object beanInstance = beanClass.getDeclaredConstructor().newInstance();
            // 2. 处理 @MyAutowired 依赖注入
            populateBean(beanInstance);
            return beanInstance;
        } catch (Exception e) {
            throw new RuntimeException("创建 Bean 失败:" + beanName, e);
        }
    }

    // 依赖注入(原有逻辑不变)
    private void populateBean(Object beanInstance) {
        Class<?> clazz = beanInstance.getClass();
        Field[] fields = clazz.getDeclaredFields();
        for (Field field : fields) {
            if (field.isAnnotationPresent(MyAutowired.class)) {
                field.setAccessible(true);
                Class<?> fieldType = field.getType();
                // 根据类型从单例池找依赖 Bean
                Object dependencyBean = getBeanByType(fieldType);
                if (dependencyBean == null) {
                    throw new RuntimeException("未找到依赖 Bean:" + fieldType.getName());
                }
                try {
                    // 注入依赖
                    field.set(beanInstance, dependencyBean);
                } catch (IllegalAccessException e) {
                    throw new RuntimeException("注入依赖失败:" + field.getName(), e);
                }
            }
        }
    }

    // 对外提供获取 Bean 的方法(按类型)
    public <T> T getBean(Class<T> clazz) {
        for (Map.Entry<String, Object> entry : singletonObjects.entrySet()) {
            if (clazz.isInstance(entry.getValue())) {
                return (T) entry.getValue();
            }
        }
        throw new RuntimeException("未找到 Bean:" + clazz.getName());
    }

    // 对外提供获取 Bean 的方法(按名称)
    public Object getBean(String beanName) {
        Object bean = singletonObjects.get(beanName);
        if (bean == null) {
            throw new RuntimeException("未找到 Bean:" + beanName);
        }
        return bean;
    }

    // 工具方法:根据类型查找 Bean
    private Object getBeanByType(Class<?> type) {
        for (Object bean : singletonObjects.values()) {
            if (type.isInstance(bean)) {
                return bean;
            }
        }
        return null;
    }

    // 工具方法:首字母小写
    private String toLowerCaseFirstChar(String str) {
        if (str == null || str.isEmpty()) return str;
        char[] chars = str.toCharArray();
        chars[0] = Character.toLowerCase(chars[0]);
        return new String(chars);
    }
}

步骤 4:测试代码

1. 编写业务类
// Dao 层
@MyComponent
public class UserDao {
    public void queryUser() {
        System.out.println("Spring 6.x 简易 IOC:查询用户信息");
    }
}

// Service 层(依赖 UserDao)
@MyComponent
public class UserService {
    @MyAutowired
    private UserDao userDao;

    public void getUser() {
        userDao.queryUser();
        System.out.println("Spring 6.x 简易 IOC:UserService 处理请求");
    }
}
2. 编写配置类
@MyConfiguration(scanPackage = "com.example.spring6.ioc.demo")
public class AppConfig {
}
3. 编写测试类
public class MyIocTest {
    public static void main(String[] args) {
        // 1. 创建简易 IOC 容器
        MyApplicationContext context = new MyApplicationContext(AppConfig.class);
        // 2. 获取 Bean
        UserService userService = context.getBean(UserService.class);
        // 3. 调用方法
        userService.getUser();
    }
}

运行结果

Spring 6.x 简易 IOC:查询用户信息
Spring 6.x 简易 IOC:UserService 处理请求

三、简易版 vs Spring 6.x 原生源码对比

简易版 IOC 组件Spring 6.x 原生组件作用
MyBeanDefinitionBeanDefinition存储 Bean 元信息
beanDefinitionMapDefaultListableBeanFactory#beanDefinitionMap存储 BeanDefinition
singletonObjectsDefaultSingletonBeanRegistry#singletonObjects单例池
MyApplicationContextAnnotationConfigApplicationContext核心容器
populateBean()AbstractAutowireCapableBeanFactory#populateBean依赖注入
createBean()AbstractAutowireCapableBeanFactory#createBean创建 Bean 实例

📌 关注我,每天5分钟,带你从 Java 小白变身编程高手!
👉 点赞 + 关注+私信"IOC源码"获取手撸源码,让更多小伙伴一起进步!