How to implement an SPI that supports key-value pairs


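Preface

If you have used the SPI mechanism provided by the JDK, you probably know that it cannot load implementations on demand. I previously wrote an article on how to integrate the JDK-based SPI with Spring to get dependency injection. That solution relies on Spring's dependency injection to load SPI implementations on demand, so it depends on Spring. Today let's look at another approach: writing an SPI ourselves.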

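Implementation approach

The overall idea is much the same as the JDK's SPI. If you are not familiar with the JDK's SPI, see my earlier article, an introduction to the Java SPI mechanism. The difference is that the entries in our configuration file are key-value pairs, for example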

springMysql=com.github.lybgeek.dialect.mysql.SpringMysqlDialect
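To make the key-value format concrete, here is a minimal sketch of how such a line can be read with java.util.Properties, which the loader shown later also relies on. The class name KeyValueLineDemo and the helper parse are hypothetical and exist only for this illustration.

```java
import java.io.IOException;
import java.io.StringReader;
import java.util.Properties;

// Hypothetical demo class, not part of the original project.
class KeyValueLineDemo {

    // Parses text in the key=value format used by the SPI configuration file.
    static Properties parse(String content) {
        Properties properties = new Properties();
        try (StringReader reader = new StringReader(content)) {
            properties.load(reader);
        } catch (IOException e) {
            throw new IllegalStateException("cannot parse SPI configuration", e);
        }
        return properties;
    }

    public static void main(String[] args) {
        Properties properties = parse(
                "springMysql=com.github.lybgeek.dialect.mysql.SpringMysqlDialect\n");
        // The key is the extension name, the value is the implementation class name.
        System.out.println(properties.getProperty("springMysql"));
    }
}
```

The key on the left is what callers later pass in to select an implementation; the value on the right is the class that gets loaded for it.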

Implementation logic

1. Agree on the directory to scan, for example META-INF/services/

 private static final String SERVICE_DIRECTORY = "META-INF/services/";

2. Agree on the name of the file to parse (the fully qualified name of the service interface), for example

com.github.lybgeek.dialect.SpringSqlDialect

3. Agree on the format of the file content, for example

springMysql=com.github.lybgeek.dialect.mysql.SpringMysqlDialect

4. Locate the agreed directory, parse the files, and put the contents into a cache

 /**
     * Load files under SERVICE_DIRECTORY.
     */
    private void loadDirectory(final Map<String, Class<?>> classes) {
        String fileName = SERVICE_DIRECTORY + clazz.getName();
        try {
            ClassLoader classLoader = ExtensionLoader.class.getClassLoader();
            Enumeration<URL> urls = classLoader != null ? classLoader.getResources(fileName)
                    : ClassLoader.getSystemResources(fileName);
            if (urls != null) {
                while (urls.hasMoreElements()) {
                    URL url = urls.nextElement();
                    loadResources(classes, url);
                }
            }
        } catch (IOException t) {
            log.error("load extension class error {}", fileName, t);
        }
    }

    private void loadResources(final Map<String, Class<?>> classes, final URL url) throws IOException {
        try (InputStream inputStream = url.openStream()) {
            Properties properties = new Properties();
            properties.load(inputStream);
            properties.forEach((k, v) -> {
                String name = (String) k;
                String classPath = (String) v;
                if (StringUtils.isNotBlank(name) && StringUtils.isNotBlank(classPath)) {
                    try {
                        loadClass(classes, name, classPath);
                    } catch (ClassNotFoundException e) {
                        throw new IllegalStateException("load extension resources error", e);
                    }
                }
            });
        } catch (IOException e) {
            throw new IllegalStateException("load extension resources error", e);
        }
    }

    private void loadClass(final Map<String, Class<?>> classes,
                           final String name, final String classPath) throws ClassNotFoundException {
        Class<?> subClass = Class.forName(classPath);
        if (!clazz.isAssignableFrom(subClass)) {
            throw new IllegalStateException("load extension resources error," + subClass + " subtype is not of " + clazz);
        }
        Activate annotation = subClass.getAnnotation(Activate.class);
        if (annotation == null) {
            throw new IllegalStateException("load extension resources error," + subClass + " without Activate annotation");
        }
        Class<?> oldClass = classes.get(name);
        if (oldClass == null) {
            classes.put(name, subClass);
        } else if (oldClass != subClass) {
            throw new IllegalStateException("load extension resources error,Duplicate class " + clazz.getName() + " name " + name + " on " + oldClass.getName() + " or " + subClass.getName());
        }
    }
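The loading step above can be sketched in a self-contained form. This is a simplified illustration, not the ExtensionLoader from this article: the class ExtensionClassTable is hypothetical, the configuration is read from a string instead of the classpath, the @Activate check is left out, and JDK collection classes stand in for real extensions.

```java
import java.io.IOException;
import java.io.StringReader;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;

// Hypothetical helper that mirrors loadResources/loadClass in simplified form.
class ExtensionClassTable {

    // Builds a name -> implementation class table for the given service type.
    static Map<String, Class<?>> load(Class<?> serviceType, String content) {
        Properties properties = new Properties();
        try {
            properties.load(new StringReader(content));
        } catch (IOException e) {
            throw new IllegalStateException("load extension resources error", e);
        }
        Map<String, Class<?>> classes = new HashMap<>();
        for (String name : properties.stringPropertyNames()) {
            String className = properties.getProperty(name).trim();
            Class<?> implementation;
            try {
                implementation = Class.forName(className);
            } catch (ClassNotFoundException e) {
                throw new IllegalStateException("load extension resources error", e);
            }
            // Reject classes that do not implement the service type.
            if (!serviceType.isAssignableFrom(implementation)) {
                throw new IllegalStateException(
                        implementation + " is not a subtype of " + serviceType);
            }
            classes.put(name, implementation);
        }
        return classes;
    }

    public static void main(String[] args) {
        Map<String, Class<?>> classes = load(java.util.List.class,
                "array=java.util.ArrayList\nlinked=java.util.LinkedList\n");
        System.out.println(classes.get("array").getName());
    }
}
```

Only the classes are resolved at this point; no instance is created until a caller asks for a specific name, which is what makes on-demand loading possible.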


5. Look up the corresponding instance in the cache by key

    public T getActivate(final String name) {
        if (StringUtils.isBlank(name)) {
            throw new NullPointerException("get Activate name is null");
        }
        // Make sure a holder exists for this name; putIfAbsent keeps the first one under contention.
        Holder<Object> objectHolder = cachedInstances.get(name);
        if (objectHolder == null) {
            cachedInstances.putIfAbsent(name, new Holder<>());
            objectHolder = cachedInstances.get(name);
        }
        Object value = objectHolder.getValue();
        if (value == null) {
            // Double-checked locking: the instance is created at most once per name.
            synchronized (cachedInstances) {
                value = objectHolder.getValue();
                if (value == null) {
                    value = createExtension(name);
                    objectHolder.setValue(value);
                }
            }
        }
        return (T) value;
    }
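The lookup above is a lazy, cached creation that uses double-checked locking on a volatile holder. The following sketch isolates that pattern. LazyInstanceCache and its factory parameter are hypothetical names, and unlike the original it locks on the per-name holder rather than on the whole map, which is just one possible design choice.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

// Hypothetical cache isolating the lazy-creation pattern used by getActivate.
class LazyInstanceCache<T> {

    // Holds one lazily created value; volatile makes the written value visible to other threads.
    private static final class Holder<V> {
        private volatile V value;
    }

    private final Map<String, Holder<T>> holders = new ConcurrentHashMap<>();
    private final Function<String, T> factory;

    LazyInstanceCache(Function<String, T> factory) {
        this.factory = factory;
    }

    T get(String name) {
        // One holder per name; computeIfAbsent creates it atomically.
        Holder<T> holder = holders.computeIfAbsent(name, key -> new Holder<>());
        T value = holder.value;
        if (value == null) {
            // Second check under the lock so the factory runs at most once per name.
            synchronized (holder) {
                value = holder.value;
                if (value == null) {
                    value = factory.apply(name);
                    holder.value = value;
                }
            }
        }
        return value;
    }

    public static void main(String[] args) {
        LazyInstanceCache<Object> cache = new LazyInstanceCache<>(name -> new Object());
        System.out.println(cache.get("mysql") == cache.get("mysql"));
    }
}
```

Locking on the holder means that creating the instance for one name does not block lookups for other names.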

Core code

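import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;

import lombok.extern.slf4j.Slf4j;
// Assumption: StringUtils comes from Apache Commons Lang 3.
import org.apache.commons.lang3.StringUtils;

@Slf4j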
@SuppressWarnings("all")
public final class ExtensionLoader<T> {

    private static final String SERVICE_DIRECTORY = "META-INF/services/";

    private static final Map<Class<?>, ExtensionLoader<?>> LOADERS = new ConcurrentHashMap<>();

    private final Class<T> clazz;

    private final Holder<Map<String, Class<?>>> cachedClasses = new Holder<>();

    private final Map<String, Holder<Object>> cachedInstances = new ConcurrentHashMap<>();

    private final Map<Class<?>, Object> ActivateInstances = new ConcurrentHashMap<>();

    private String cachedDefaultName;

    /**
     * Instantiates a new Extension loader.
     *
     * @param clazz the clazz.
     */
    private ExtensionLoader(final Class<T> clazz) {
        this.clazz = clazz;
        if (clazz != ExtensionFactory.class) {
            ExtensionLoader.getExtensionLoader(ExtensionFactory.class).getExtensionClasses();
        }
    }

    /**
     * Gets extension loader.
     *
     * @param <T>   the type parameter
     * @param clazz the clazz
     * @return the extension loader.
     */
    public static <T> ExtensionLoader<T> getExtensionLoader(final Class<T> clazz) {
        if (clazz == null) {
            throw new NullPointerException("extension clazz is null");
        }
        if (!clazz.isInterface()) {
            throw new IllegalArgumentException("extension clazz (" + clazz + ") is not interface!");
        }
        if (!clazz.isAnnotationPresent(SPI.class)) {
            throw new IllegalArgumentException("extension clazz (" + clazz + ") without @" + SPI.class + " Annotation");
        }
        ExtensionLoader<T> extensionLoader = (ExtensionLoader<T>) LOADERS.get(clazz);
        if (extensionLoader != null) {
            return extensionLoader;
        }
        LOADERS.putIfAbsent(clazz, new ExtensionLoader<>(clazz));
        return (ExtensionLoader<T>) LOADERS.get(clazz);
    }

    /**
     * Gets default Activate.
     *
     * @return the default Activate.
     */
    public T getDefaultActivate() {
        getExtensionClasses();
        if (StringUtils.isBlank(cachedDefaultName)) {
            return null;
        }
        return getActivate(cachedDefaultName);
    }

    /**
     * Gets Activate.
     *
     * @param name the name
     * @return the Activate.
     */
    public T getActivate(final String name) {
        if (StringUtils.isBlank(name)) {
            throw new NullPointerException("get Activate name is null");
        }
        Holder<Object> objectHolder = cachedInstances.get(name);
        if (objectHolder == null) {
            cachedInstances.putIfAbsent(name, new Holder<>());
            objectHolder = cachedInstances.get(name);
        }
        Object value = objectHolder.getValue();
        if (value == null) {
            synchronized (cachedInstances) {
                value = objectHolder.getValue();
                if (value == null) {
                    value = createExtension(name);
                    objectHolder.setValue(value);
                }
            }
        }
        return (T) value;
    }

    public Set<String> getSupportedExtensions() {
        Map<String, Class<?>> clazzes = getExtensionClasses();
        return Collections.unmodifiableSet(new TreeSet<>(clazzes.keySet()));
    }

    @SuppressWarnings("unchecked")
    private T createExtension(final String name) {
        Class<?> aClass = getExtensionClasses().get(name);
        if (aClass == null) {
            throw new IllegalArgumentException("name is error");
        }
        Object o = ActivateInstances.get(aClass);
        if (o == null) {
            try {
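                // Class.newInstance() is deprecated since Java 9; go through the no-arg constructor instead.
                ActivateInstances.putIfAbsent(aClass, aClass.getDeclaredConstructor().newInstance());
                o = ActivateInstances.get(aClass);
            } catch (ReflectiveOperationException e) {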
                throw new IllegalStateException("Extension instance(name: " + name + ", class: "
                        + aClass + ")  could not be instantiated: " + e.getMessage(), e);

            }
        }
        return (T) o;
    }

    /**
     * Gets extension classes.
     *
     * @return the extension classes
     */
    public Map<String, Class<?>> getExtensionClasses() {
        Map<String, Class<?>> classes = cachedClasses.getValue();
        if (classes == null) {
            synchronized (cachedClasses) {
                classes = cachedClasses.getValue();
                if (classes == null) {
                    classes = loadExtensionClass();
                    cachedClasses.setValue(classes);
                }
            }
        }
        return classes;
    }

    private Map<String, Class<?>> loadExtensionClass() {
        SPI annotation = clazz.getAnnotation(SPI.class);
        if (annotation != null) {
            String value = annotation.value();
            if (StringUtils.isNotBlank(value)) {
                cachedDefaultName = value;
            }
        }
        Map<String, Class<?>> classes = new HashMap<>(16);
        loadDirectory(classes);
        return classes;
    }

    /**
     * Load files under SERVICE_DIRECTORY.
     */
    private void loadDirectory(final Map<String, Class<?>> classes) {
        String fileName = SERVICE_DIRECTORY + clazz.getName();
        try {
            ClassLoader classLoader = ExtensionLoader.class.getClassLoader();
            Enumeration<URL> urls = classLoader != null ? classLoader.getResources(fileName)
                    : ClassLoader.getSystemResources(fileName);
            if (urls != null) {
                while (urls.hasMoreElements()) {
                    URL url = urls.nextElement();
                    loadResources(classes, url);
                }
            }
        } catch (IOException t) {
            log.error("load extension class error {}", fileName, t);
        }
    }

    private void loadResources(final Map<String, Class<?>> classes, final URL url) throws IOException {
        try (InputStream inputStream = url.openStream()) {
            Properties properties = new Properties();
            properties.load(inputStream);
            properties.forEach((k, v) -> {
                String name = (String) k;
                String classPath = (String) v;
                if (StringUtils.isNotBlank(name) && StringUtils.isNotBlank(classPath)) {
                    try {
                        loadClass(classes, name, classPath);
                    } catch (ClassNotFoundException e) {
                        throw new IllegalStateException("load extension resources error", e);
                    }
                }
            });
        } catch (IOException e) {
            throw new IllegalStateException("load extension resources error", e);
        }
    }

    private void loadClass(final Map<String, Class<?>> classes,
                           final String name, final String classPath) throws ClassNotFoundException {
        Class<?> subClass = Class.forName(classPath);
        if (!clazz.isAssignableFrom(subClass)) {
            throw new IllegalStateException("load extension resources error," + subClass + " subtype is not of " + clazz);
        }
        Activate annotation = subClass.getAnnotation(Activate.class);
        if (annotation == null) {
            throw new IllegalStateException("load extension resources error," + subClass + " without Activate annotation");
        }
        Class<?> oldClass = classes.get(name);
        if (oldClass == null) {
            classes.put(name, subClass);
        } else if (oldClass != subClass) {
            throw new IllegalStateException("load extension resources error,Duplicate class " + clazz.getName() + " name " + name + " on " + oldClass.getName() + " or " + subClass.getName());
        }
    }




    /**
     * The type Holder.
     *
     * @param <T> the type parameter.
     */
    public static class Holder<T> {

        private volatile T value;

        /**
         * Gets value.
         *
         * @return the value
         */
        public T getValue() {
            return value;
        }

        /**
         * Sets value.
         *
         * @param value the value
         */
        public void setValue(final T value) {
            this.value = value;
        }
    }
}
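The loader refers to two annotations, @SPI and @Activate, whose definitions are not shown in this article. Below is a minimal sketch of what they could look like, inferred only from how the loader uses them (annotation.value() and getAnnotation(Activate.class)). The target setting and the demo types DemoDialect, DemoMysqlDialect and AnnotationDemo are assumptions; runtime retention is needed in any case, because the loader reads the annotations reflectively.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

// Hypothetical demo that reads the annotations back via reflection.
class AnnotationDemo {

    // Returns the default extension name declared on the service interface, or "" if none.
    static String defaultName(Class<?> serviceType) {
        SPI spi = serviceType.getAnnotation(SPI.class);
        return spi == null ? "" : spi.value();
    }

    public static void main(String[] args) {
        System.out.println(defaultName(DemoDialect.class));
        System.out.println(DemoMysqlDialect.class.isAnnotationPresent(Activate.class));
    }
}

// Assumed definition: marks a service interface and names its default extension.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
@interface SPI {
    String value() default "";
}

// Assumed definition: marks a class as a loadable extension.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
@interface Activate {
}

// Hypothetical service interface with "mysql" as its default extension name.
@SPI("mysql")
interface DemoDialect {
    String dialect();
}

// Hypothetical extension marked as loadable.
@Activate
class DemoMysqlDialect implements DemoDialect {
    @Override
    public String dialect() {
        return "mysql";
    }
}
```

With definitions along these lines, the value of @SPI becomes the name used by getDefaultActivate, and a class without @Activate is rejected while loading.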

Usage example

1. Define the service interface

@SPI("mysql")
public interface SqlDialect {

    String dialect();

}

2. Define the concrete implementation classes

@Activate
public class MysqlDialect implements SqlDialect {
    @Override
    public String dialect() {
        return "mysql";
    }


}

@Activate
public class OracleDialect implements SqlDialect {
    @Override
    public String dialect() {
        return "oracle";
    }


}

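3. Create a META-INF/services directory under src/main/resources/ and add a file named after the fully qualified name of the interface

4. Fill in that file with one entry per implementation, in the key=fully qualified class name format described earlier; for example, the key mysql maps to the MysqlDialect class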

5. Load the service class

   SqlDialect sqlDialect = ExtensionLoader.getExtensionLoader(SqlDialect.class).getActivate("mysql");

6. Test

    @Test
    public void testSpi() {
        SqlDialect sqlDialect = ExtensionLoader.getExtensionLoader(SqlDialect.class).getActivate("mysql");
        Assert.assertEquals("mysql", sqlDialect.dialect());
    }

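Summary

If you have used Dubbo's SPI, you will notice that the approach above is essentially a simplified version of it. If you are familiar with the SPI mechanism of the ShenYu gateway, you will notice that the approach above is essentially the same as ShenYu's.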

Demo link

github.com/lyb-geek/sp…