通过 interface 自动生成 Swagger 文档并 mock 数据

91 阅读3分钟

注意：需要把被扫描的包路径改成你自己项目的包名，目前是 com.qbit（ControllerScanner 中的 packageName）。

import com.google.common.base.Charsets;
import com.google.common.base.Joiner;
import com.google.common.base.Predicates;
import com.google.common.collect.Lists;
import lombok.Getter;
import lombok.var;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanDefinitionHolder;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.*;
import org.springframework.cglib.core.ReflectUtils;
import org.springframework.core.GenericTypeResolver;
import org.springframework.core.MethodParameter;
import org.springframework.core.ResolvableType;
import org.springframework.core.io.Resource;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import org.springframework.core.io.support.ResourcePatternResolver;
import org.springframework.core.type.ClassMetadata;
import org.springframework.core.type.classreading.CachingMetadataReaderFactory;
import org.springframework.core.type.classreading.MetadataReader;
import org.springframework.core.type.classreading.MetadataReaderFactory;
import org.springframework.stereotype.Component;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RestController;

import javax.lang.model.SourceVersion;
import javax.tools.*;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.Array;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Parameter;
import java.lang.reflect.Type;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.StringJoiner;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * @author qbit
 */
@Component
public class ControllerRegistry implements BeanDefinitionRegistryPostProcessor {

    /** Compiler shared by all generated mocks; created from the class loader of the first scanned interface. */
    private DynamicCompiler compiler;

    /**
     * Generates, compiles and registers one mock {@code @RestController} bean per controller interface
     * found by {@link ControllerScanner}.
     *
     * <p>The bean name is the simple name of the generated class (see {@link SrcGenerator#implClassName}).
     */
    @Override
    public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry beanDefinitionRegistry) throws BeansException {
        final SrcGenerator srcGenerator = new SrcGenerator();
        new ControllerScanner().controllers()
                .forEach(controllerInterface -> registerMock(beanDefinitionRegistry, srcGenerator, controllerInterface));
    }

    /** Compiles the mock implementation of {@code controllerInterface} and registers it as a bean. */
    private void registerMock(BeanDefinitionRegistry registry, SrcGenerator srcGenerator, Class<?> controllerInterface) {
        if (null == compiler) {
            compiler = new DynamicCompiler(controllerInterface.getClassLoader());
        }
        final String className = srcGenerator.implClassName(controllerInterface);
        final String beanName = StringUtils.substringAfterLast(className, ".");
        final Class<?> implClass = compiler.compile(className, srcGenerator.implSrc(controllerInterface));
        final BeanDefinition beanDefinition = instanceDefinition(implClass);
        // The definition refers to the class by name; postProcessBeanFactory makes the bean factory
        // resolve that name through the compiler's class loader, which is where the class was defined.
        beanDefinition.setBeanClassName(className);
        BeanDefinitionReaderUtils.registerBeanDefinition(new BeanDefinitionHolder(beanDefinition, beanName), registry);
    }

    /**
     * Builds a bean definition whose supplier always returns one eagerly created instance of {@code type}.
     * Generic so the instance supplier is typed without raw {@code Class} usage.
     */
    private static <T> BeanDefinition instanceDefinition(Class<T> type) {
        final T instance = type.cast(ReflectUtils.newInstance(type));
        return BeanDefinitionBuilder.genericBeanDefinition(type, () -> instance).getBeanDefinition();
    }

    /** Makes the bean factory load bean classes through the compiler's loader so generated classes resolve. */
    @Override
    public void postProcessBeanFactory(ConfigurableListableBeanFactory configurableListableBeanFactory) throws BeansException {
        if (null != compiler) {
            configurableListableBeanFactory.setBeanClassLoader(compiler.getClassLoader());
        }
    }

    /**
     * Scans the classpath below a base package for controller interfaces, i.e. interfaces whose
     * class name ends with one of {@link #controllerPostfix}.
     *
     * @author qbit
     */
    public static class ControllerScanner {
        /** Name suffixes that mark an interface as a controller to mock; also used by {@link SrcGenerator}. */
        static final List<String> controllerPostfix =
                Collections.unmodifiableList(Arrays.asList("Controller", "Service", "Listener"));

        /** Package scanned by the no-argument constructor; change it to match your project. */
        private static final String DEFAULT_PACKAGE = "com.qbit";

        /** Base package whose classes (including sub-packages) are scanned. */
        private final String packageName;
        /** Accepts fully qualified class names inside {@link #packageName} that end with a controller suffix. */
        private final Predicate<String> controllerFilter;

        /** Creates a scanner for {@value #DEFAULT_PACKAGE}. */
        public ControllerScanner() {
            this(DEFAULT_PACKAGE);
        }

        /**
         * Creates a scanner for a custom base package.
         *
         * @param packageName base package to scan, e.g. {@code "com.example"}; sub-packages are included
         */
        public ControllerScanner(String packageName) {
            this.packageName = packageName;
            final String packagePrefix = packageName + '.';
            this.controllerFilter = className -> className.startsWith(packagePrefix)
                    && controllerPostfix.stream().anyMatch(className::endsWith);
        }

        /**
         * Lists the controller interfaces found below the base package.
         *
         * <p>Class files are located and loaded through the same loader (the thread context class
         * loader, falling back to this class's loader), so every class that is found can also be
         * loaded. The system class loader used previously cannot see classes packaged inside a
         * Spring Boot executable jar.
         *
         * @return lazily evaluated stream of controller interfaces
         * @throws UncheckedIOException  if the classpath cannot be scanned or a class file cannot be read
         * @throws IllegalStateException if a scanned class cannot be loaded
         */
        Stream<Class<?>> controllers() {
            final ClassLoader classLoader = defaultClassLoader();
            final ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver(classLoader);
            final MetadataReaderFactory metadataReaderFactory = new CachingMetadataReaderFactory(classLoader);
            final Resource[] resources;
            try {
                resources = resolver.getResources("classpath*:" + packageName.replace('.', '/') + "/**/*.class");
            } catch (IOException e) {
                throw new UncheckedIOException("cannot scan package " + packageName, e);
            }
            return Stream.of(resources)
                    .map(resource -> readMetadata(metadataReaderFactory, resource))
                    .map(MetadataReader::getClassMetadata)
                    // Only interfaces can be implemented by the generated mock (annotation types are
                    // interfaces too, so exclude them explicitly).
                    .filter(metadata -> metadata.isInterface() && !metadata.isAnnotation())
                    .map(ClassMetadata::getClassName)
                    .filter(controllerFilter)
                    .map(className -> loadClass(classLoader, className));
        }

        /** Thread context class loader, or this class's loader when none is set. */
        private static ClassLoader defaultClassLoader() {
            final ClassLoader contextLoader = Thread.currentThread().getContextClassLoader();
            return contextLoader != null ? contextLoader : ControllerScanner.class.getClassLoader();
        }

        /** Loads {@code className} or fails with an unchecked exception carrying the cause. */
        private static Class<?> loadClass(ClassLoader classLoader, String className) {
            try {
                return classLoader.loadClass(className);
            } catch (ClassNotFoundException e) {
                throw new IllegalStateException("cannot load scanned class " + className, e);
            }
        }

        /** Reads the class-file metadata of {@code resource} without loading the class. */
        private static MetadataReader readMetadata(MetadataReaderFactory factory, Resource resource) {
            try {
                return factory.getMetadataReader(resource);
            } catch (IOException e) {
                throw new UncheckedIOException("cannot read class metadata from " + resource, e);
            }
        }

        /** Prints the controller interfaces found in the default package (manual check). */
        public static void main(String[] args) {
            new ControllerScanner().controllers().forEach(System.out::println);
        }
    }

    /**
     * @author qbit
     */
    public static class DynamicCompiler {

        private final DynamicClassLoader classLoader;
        @Getter
        private final URLClassLoader parentClassloader;

        public ClassLoader getClassLoader() {
            return classLoader;
        }

        public DynamicCompiler(ClassLoader classLoader) {
            this.parentClassloader = (URLClassLoader) classLoader;
            this.classLoader = new DynamicClassLoader(classLoader);
        }

        public Class<?> compile(String className, String src) {
            JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
            DiagnosticCollector<javax.tools.JavaFileObject> diagnosticCollector = new DiagnosticCollector<>();
            var standardFileManager = compiler.getStandardFileManager(diagnosticCollector, null, null);
            var javaFileManager = new JavaFileManager(standardFileManager);
            var files = Lists.newArrayList(new JavaFileObject(className, src));
            //        var options=Lists.newArrayList("-encoding", Charsets.UTF_8.toString());
            var classpath = Joiner.on(File.pathSeparator).join(Stream.of(getParentClassloader().getURLs()).map(URL::getFile).collect(Collectors.toList()));
            var options = Lists.newArrayList("-encoding", Charsets.UTF_8.toString(), "-classpath", classpath);
            var task = compiler.getTask(null, javaFileManager, diagnosticCollector, options, null, files);
            if (task.call()) {
                return javaFileManager.toClass(classLoader);
            } else {
                String message = "compile " + className + " fail!";
                System.err.println(message);
                System.err.println(src);
                diagnosticCollector.getDiagnostics().forEach(diagnostic -> {
                    System.err.println("Code:" + diagnostic.getCode());
                    System.err.println("Kind:" + diagnostic.getKind());
                    System.err.println("Position:" + diagnostic.getPosition());
                    System.err.println("Start Position:" + diagnostic.getStartPosition());
                    System.err.println("End Position:" + diagnostic.getEndPosition());
                    System.err.println("Source:" + diagnostic.getSource());
                    System.err.println("Message:" + diagnostic.getMessage(null));
                    System.err.println("LineNumber:" + diagnostic.getLineNumber());
                    System.err.println("ColumnNumber:" + diagnostic.getColumnNumber());
                });
                throw new RuntimeException(message);
            }
        }

        private class JavaFileObject extends SimpleJavaFileObject {

            private final CharSequence content;

            protected JavaFileObject(String className, CharSequence content) {
                super(URI.create("string:///" + className.replace('.', '/')
                        + Kind.SOURCE.extension), Kind.SOURCE);
                this.content = content;
            }

            @Override
            public CharSequence getCharContent(boolean ignoreEncodingErrors) throws IOException {
                return content;
            }
        }

        private class JavaFileManager extends ForwardingJavaFileManager<javax.tools.JavaFileManager> {

            private JavaClassObject javaFileObject;

            protected JavaFileManager(javax.tools.JavaFileManager fileManager) {
                super(fileManager);
            }

            @Override
            public javax.tools.JavaFileObject getJavaFileForOutput(Location location, String className, javax.tools.JavaFileObject.Kind kind, FileObject sibling) throws IOException {
                javaFileObject = new JavaClassObject(className, kind);
                return javaFileObject;
            }

            public Class<?> toClass(DynamicClassLoader classLoader) {
                return javaFileObject.toClass(classLoader);
            }
        }

        private class JavaClassObject extends SimpleJavaFileObject {
            private final ByteArrayOutputStream bos = new ByteArrayOutputStream();
            private final String name;

            protected JavaClassObject(String name, Kind kind) {
                super(URI.create("string:///" + name.replace('.', '/')
                        + kind.extension), kind);
                this.name = name;
            }

            @Override
            public OutputStream openOutputStream() throws IOException {
                return bos;
            }

            Class<?> toClass(DynamicClassLoader dynamicClassLoader) {
                return dynamicClassLoader.defineClass(name, bos.toByteArray());
            }
        }

        private class DynamicClassLoader extends URLClassLoader {

            public DynamicClassLoader(ClassLoader parent) {
                super(new URL[0], parent);
            }

            @Override
            public Class<?> findClass(String name) throws ClassNotFoundException {
                return super.findClass(name);
            }

            public Class<?> defineClass(String name, byte[] bytes) {
                return defineClass(name, bytes, 0, bytes.length);
            }
        }
    }

    /**
     * Generates the Java source of a mock {@code @RestController} that implements a controller interface.
     *
     * <p>Every abstract method is implemented to return a {@link MockUtils} value, or a
     * {@code 0}/{@code false} literal for primitive return types.
     *
     * @author qbit
     */
    public static class SrcGenerator {

        /** Package prefix of every generated class; the interface's own package follows it. */
        public static final String PACKAGE_PREFIX = "mock.";

        private static final String NEW_LINE = System.getProperty("line.separator", "\n");

        /** Type-name suffixes turned into a parameter name when no real name is available. */
        private static final String[] PARAMETER_TYPE_POSTFIXES = {"Query", "DTO", "TO", "Command", "Event", "Message"};

        /** Prints the source generated for a sample interface (manual check). */
        public static void main(String[] args) {
            System.out.println(new SrcGenerator().implSrc(EnterprisesController.class));
        }

        /**
         * Builds the complete source file of the mock implementation.
         *
         * @param controllerInterface interface to implement
         * @return Java source of the class named by {@link #implClassName(Class)}
         */
        String implSrc(Class<?> controllerInterface) {
            final String implClassName = implClassName(controllerInterface);
            final StringBuilder src = new StringBuilder();
            srcPackage(src, implClassName);
            src.append('@').append(RestController.class.getName()).append(NEW_LINE);
            src.append("public class ").append(StringUtils.substringAfterLast(implClassName, "."))
                    .append(" implements ").append(toString(controllerInterface)).append("{").append(NEW_LINE);
            for (Method method : controllerInterface.getMethods()) {
                srcMethod(src, method, controllerInterface);
            }
            src.append("}");
            return src.toString();
        }

        /** Appends an implementation of {@code method}; default and static methods already have a body. */
        private void srcMethod(StringBuilder src, Method method, Class<?> controllerInterface) {
            if (method.isDefault() || Modifier.isStatic(method.getModifiers())) {
                return;
            }
            if (method.getReturnType() == List.class) {
                listMethod(src, method, controllerInterface);
                return;
            }
            // Resolve against the scanned interface so a type variable it binds becomes concrete. The same
            // resolved type must be used for the declaration and for the mocked value, or they disagree.
            final Class<?> returnType = GenericTypeResolver.resolveReturnType(method, controllerInterface);
            src.append("public ").append(toString(returnType)).append(" ").append(method.getName()).append("(");
            srcParameters(src, method, controllerInterface);
            src.append("){");
            if (void.class != returnType) {
                src.append("return ").append(mockValue(returnType)).append(";").append(NEW_LINE);
            }
            src.append("}").append(NEW_LINE);
        }

        /** Source expression returned by a generated method whose (non-list) return type is {@code returnType}. */
        private String mockValue(Class<?> returnType) {
            if (returnType.isPrimitive()) {
                // MockUtils.mock may return null, and unboxing null into a primitive return throws.
                // The constant 0 is assignable to every numeric primitive and to char.
                return boolean.class == returnType ? "false" : "0";
            }
            return MockUtils.class.getCanonicalName() + ".mock(" + toString(returnType) + ".class)";
        }

        /** Appends an implementation of a method returning {@link List}, backed by {@code MockUtils.mockList}. */
        private void listMethod(StringBuilder src, Method method, Class<?> controllerInterface) {
            final ResolvableType returnType = ResolvableType.forMethodReturnType(method, controllerInterface);
            final ResolvableType elementType = returnType.getGeneric(0);
            final String elementClass = toString(elementType.resolve(Object.class));
            src.append("public ").append(returnType.toString().replace('$', '.'))
                    .append(" ").append(method.getName()).append("(");
            srcParameters(src, method, controllerInterface);
            src.append("){");
            src.append("return ");
            if (!(elementType.getType() instanceof Class)) {
                // Wildcard, parameterized, type-variable or missing element type: mockList yields a
                // List of the raw class, so convert through a raw List (unchecked) to the declared type.
                src.append("(java.util.List) ");
            }
            src.append(MockUtils.class.getCanonicalName()).append(".mockList(").append(elementClass)
                    .append(".class);").append(NEW_LINE);
            src.append("}").append(NEW_LINE);
        }

        /** Appends the parameter list, giving every parameter a distinct name. */
        private void srcParameters(StringBuilder src, Method method, Class<?> controllerInterface) {
            final Set<String> usedNames = new HashSet<>();
            final int length = method.getParameterCount();
            for (int i = 0; i < length; i++) {
                src.append(NEW_LINE);
                srcParameter(src, new MethodParameter(method, i), controllerInterface, usedNames);
                if (length - 1 > i) {
                    src.append(',');
                }
            }
        }

        /** Appends one parameter: its copied annotations, its resolved type and a name not yet in {@code usedNames}. */
        private void srcParameter(StringBuilder src, MethodParameter methodParameter,
                                  Class<?> controllerInterface, Set<String> usedNames) {
            for (Annotation annotation : methodParameter.getParameterAnnotations()) {
                srcAnnotation(src, annotation);
            }
            // Resolve against the scanned interface so type variables it binds are printed as real types.
            final ResolvableType type = ResolvableType.forMethodParameter(
                    methodParameter.getMethod(), methodParameter.getParameterIndex(), controllerInterface);
            src.append(type.toString().replace('$', '.')).append(" ")
                    .append(uniqueName(getParameterName(methodParameter), usedNames));
        }

        /** Candidate name for a parameter; may repeat across parameters (see {@link #uniqueName}). */
        private String getParameterName(MethodParameter methodParameter) {
            // MethodParameter#getParameterName() stays null without a ParameterNameDiscoverer, so read the
            // name from the reflective Parameter; it is present when the interface was compiled with -parameters.
            final Parameter parameter = methodParameter.getParameter();
            if (parameter.isNamePresent()) {
                return parameter.getName();
            }
            final PathVariable pathVariable = methodParameter.getParameterAnnotation(PathVariable.class);
            if (null != pathVariable) {
                final String variable = StringUtils.defaultIfEmpty(pathVariable.value(), pathVariable.name());
                return isJavaIdentifier(variable) ? variable : "id";
            }
            final ResolvableType type = ResolvableType.forMethodParameter(methodParameter);
            if (parameter.getType() == List.class) {
                // Name a list after its element type (generic 0), e.g. List<FooQuery> -> "querys".
                return getParameterNameByClass(type.getGeneric(0)) + "s";
            }
            return getParameterNameByClass(type);
        }

        /** Derives a name from the raw class name's suffix, or {@code "arg"} when no suffix matches. */
        private String getParameterNameByClass(ResolvableType type) {
            // Use the raw class name so generic arguments (PageQuery<Foo>) do not hide the suffix.
            final String className = type.resolve(Object.class).getName();
            for (String postfix : PARAMETER_TYPE_POSTFIXES) {
                if (className.endsWith(postfix)) {
                    return postfix.toLowerCase();
                }
            }
            return "arg";
        }

        private static boolean isJavaIdentifier(String name) {
            return SourceVersion.isIdentifier(name) && !SourceVersion.isKeyword(name);
        }

        /** Returns {@code candidate}, or {@code candidate2}, {@code candidate3}, ... if already used; records the result. */
        private static String uniqueName(String candidate, Set<String> usedNames) {
            String name = candidate;
            for (int suffix = 2; !usedNames.add(name); suffix++) {
                name = candidate + suffix;
            }
            return name;
        }

        /** Appends {@code annotation} as compilable source followed by a space. */
        private void srcAnnotation(StringBuilder src, Annotation annotation) {
            src.append(annotationSource(annotation)).append(' ');
        }

        /**
         * Renders an annotation as Java source. {@link Annotation#toString()} is not usable: its format
         * is implementation-dependent, and on JDK 8 it prints strings unquoted and classes as
         * {@code class x.Y}.
         */
        private static String annotationSource(Annotation annotation) {
            final Class<? extends Annotation> type = annotation.annotationType();
            final StringJoiner members = new StringJoiner(", ", "(", ")");
            for (Method element : type.getDeclaredMethods()) {
                if (element.isSynthetic() || element.getParameterCount() != 0) {
                    continue;
                }
                final Object value = elementValue(element, annotation);
                // Elements equal to their default are omitted; javac re-applies the default.
                if (!Objects.deepEquals(value, element.getDefaultValue())) {
                    members.add(element.getName() + "=" + literal(value));
                }
            }
            return "@" + typeName(type) + members;
        }

        private static Object elementValue(Method element, Annotation annotation) {
            try {
                return element.invoke(annotation);
            } catch (ReflectiveOperationException e) {
                throw new IllegalStateException("cannot read " + element + " of " + annotation, e);
            }
        }

        /** Renders an annotation element value as a Java constant expression. */
        private static String literal(Object value) {
            if (value instanceof String) {
                return quote((String) value, '"');
            }
            if (value instanceof Character) {
                return quote(value.toString(), '\'');
            }
            if (value instanceof Long) {
                return value + "L";
            }
            if (value instanceof Float) {
                final float number = (Float) value;
                if (Float.isNaN(number)) {
                    return "Float.NaN";
                }
                if (Float.isInfinite(number)) {
                    return number > 0 ? "Float.POSITIVE_INFINITY" : "Float.NEGATIVE_INFINITY";
                }
                return number + "F";
            }
            if (value instanceof Double) {
                final double number = (Double) value;
                if (Double.isNaN(number)) {
                    return "Double.NaN";
                }
                if (Double.isInfinite(number)) {
                    return number > 0 ? "Double.POSITIVE_INFINITY" : "Double.NEGATIVE_INFINITY";
                }
                return number + "D";
            }
            if (value instanceof Class) {
                return typeName((Class<?>) value) + ".class";
            }
            if (value instanceof Enum) {
                final Enum<?> constant = (Enum<?>) value;
                return typeName(constant.getDeclaringClass()) + "." + constant.name();
            }
            if (value instanceof Annotation) {
                return annotationSource((Annotation) value);
            }
            if (value.getClass().isArray()) {
                final StringJoiner items = new StringJoiner(", ", "{", "}");
                final int length = Array.getLength(value);
                for (int i = 0; i < length; i++) {
                    items.add(literal(Array.get(value, i)));
                }
                return items.toString();
            }
            // Integer, Short, Byte and Boolean already print as valid constants.
            return String.valueOf(value);
        }

        /** Quotes {@code text} with {@code delimiter}, escaping backslashes, the delimiter and control characters. */
        private static String quote(String text, char delimiter) {
            final StringBuilder out = new StringBuilder(text.length() + 2).append(delimiter);
            for (int i = 0; i < text.length(); i++) {
                final char c = text.charAt(i);
                switch (c) {
                    case '\\':
                        out.append("\\\\");
                        break;
                    case '\b':
                        out.append("\\b");
                        break;
                    case '\t':
                        out.append("\\t");
                        break;
                    case '\n':
                        out.append("\\n");
                        break;
                    case '\f':
                        out.append("\\f");
                        break;
                    case '\r':
                        out.append("\\r");
                        break;
                    default:
                        if (c == delimiter) {
                            out.append('\\').append(c);
                        } else if (c < 0x20) {
                            // Octal escape: a unicode escape would be translated before lexing.
                            out.append(String.format("\\%03o", (int) c));
                        } else {
                            out.append(c);
                        }
                }
            }
            return out.append(delimiter).toString();
        }

        /** Source-level name of {@code type}: canonical name, or the binary name with '$' replaced. */
        private static String typeName(Class<?> type) {
            final String canonical = type.getCanonicalName();
            return canonical != null ? canonical : type.getName().replace('$', '.');
        }

        /** Type name usable in source: nested-class separators '$' become '.'. */
        private String toString(Type clazz) {
            return clazz.getTypeName().replace('$', '.');
        }

        /** Appends the package declaration of {@code implClassName}. */
        private void srcPackage(StringBuilder src, String implClassName) {
            src.append("package ")
                    .append(StringUtils.substringBeforeLast(implClassName, "."))
                    .append(';')
                    .append(NEW_LINE);
        }

        /**
         * Name of the generated class: {@link #PACKAGE_PREFIX}, the interface's package, and its simple
         * name minus the controller suffix, e.g. {@code com.x.FooController -> mock.com.x.Foo}.
         *
         * @throws IllegalArgumentException if the simple name ends with none of the controller suffixes
         */
        public String implClassName(Class<?> controllerInterface) {
            String interfaceName = controllerInterface.getSimpleName();
            for (String postfix : ControllerScanner.controllerPostfix) {
                if (interfaceName.endsWith(postfix)) {
                    return PACKAGE_PREFIX + controllerInterface.getPackage().getName() + '.' +
                            interfaceName.substring(0, interfaceName.length() - postfix.length());
                }
            }
            throw new IllegalArgumentException("can not generate impl class name for " + controllerInterface);
        }
    }

    /**
     * Supplies the values returned by the generated mock controllers.
     */
    public static class MockUtils {
        /**
         * Returns the mock value for {@code clazz}.
         *
         * @return the default value ({@code 0}/{@code false}) for a primitive type, so code returning
         *         the result as a primitive does not unbox {@code null}; {@code null} otherwise
         */
        public static <T> T mock(Class<T> clazz) {
            if (clazz != null && clazz.isPrimitive() && clazz != void.class) {
                // A new primitive array is zero-filled; Array.get boxes that default value.
                @SuppressWarnings("unchecked")
                final T value = (T) Array.get(Array.newInstance(clazz, 1), 0);
                return value;
            }
            return null;
        }

        /**
         * Returns an empty immutable list, never {@code null}, so list endpoints answer with an empty list.
         */
        public static <T> List<T> mockList(Class<T> clazz) {
            return Collections.emptyList();
        }
    }
}

同样需要修改 Swagger 扫描的包名路径（DocApplication 中 RequestHandlerSelectors.basePackage 的参数）。


import com.github.xiaoymin.knife4j.spring.annotations.EnableKnife4j;
import com.google.common.collect.Lists;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import springfox.documentation.builders.PathSelectors;
import springfox.documentation.builders.RequestHandlerSelectors;
import springfox.documentation.service.ApiInfo;
import springfox.documentation.service.Contact;
import springfox.documentation.spi.DocumentationType;
import springfox.documentation.spring.web.plugins.Docket;
import springfox.documentation.swagger2.annotations.EnableSwagger2WebMvc;

/**
 * Entry point of the documentation service: serves Swagger/Knife4j docs for the mock controllers
 * generated by {@code ControllerRegistry}.
 */
@EnableSwagger2WebMvc
@SpringBootApplication
@Configuration
@EnableKnife4j
public class DocApplication {

    public static void main(String[] args) {
        SpringApplication.run(DocApplication.class, args);
    }

    /**
     * Swagger docket limited to handlers selected by {@code ControllerRegistry.SrcGenerator.PACKAGE_PREFIX}.
     *
     * <p>{@code spring.application.description} and {@code spring.application.version} are not
     * standard Spring Boot properties, so both default to an empty string. Without the default, a
     * missing key fails startup with an unresolvable placeholder.
     *
     * @param title       API title, from {@code spring.application.name}
     * @param description API description (empty if unset)
     * @param version     API version (empty if unset)
     * @return the docket bean
     */
    @Bean
    public Docket docket(
            @Value("${spring.application.name}") String title,
            @Value("${spring.application.description:}") String description,
            @Value("${spring.application.version:}") String version) {
        return new Docket(DocumentationType.SWAGGER_2)
                .apiInfo(new ApiInfo(title,
                        description,
                        version,
                        "",
                        new Contact("Qbit", "", "zhang.jun01@redstarclouds.com"),
                        "", "", Lists.newArrayList()))
                .select()
                .apis(RequestHandlerSelectors.basePackage(ControllerRegistry.SrcGenerator.PACKAGE_PREFIX))
                .paths(PathSelectors.any())
                .build();
    }
}

pom.xml 大致如下：

<artifactId>doc</artifactId>
<dependencies>
    <!-- Presumably the module holding the controller interfaces to document.
         ${project.parent.*} replaces the deprecated ${parent.*} expressions, which Maven warns about. -->
    <dependency>
        <groupId>${project.parent.groupId}</groupId>
        <artifactId>api</artifactId>
        <version>${project.parent.version}</version>
    </dependency>
    <dependency>
        <groupId>com.pig4cloud.plugin</groupId>
        <artifactId>knife4j-spring-ui</artifactId>
        <version>3.0.2</version>
    </dependency>
    <dependency>
        <groupId>com.github.xiaoymin</groupId>
        <artifactId>knife4j-micro-spring-boot-starter</artifactId>
        <version>3.0.2</version>
    </dependency>
    <dependency>
        <groupId>org.apache.commons</groupId>
        <artifactId>commons-lang3</artifactId>
    </dependency>
</dependencies>
<build>
    <resources>
        <!-- NOTE(review): filtering every file also rewrites binary resources (fonts, images);
             narrow the includes if such files are ever added. -->
        <resource>
            <filtering>true</filtering>
            <directory>${basedir}/src/main/resources</directory>
            <includes>
                <include>**/*.*</include>
            </includes>
        </resource>
    </resources>
    <plugins>
        <plugin>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-maven-plugin</artifactId>
            <executions>
                <execution>
                    <goals>
                        <goal>repackage</goal>
                    </goals>
                </execution>
            </executions>
            <configuration>
                <layers>
                    <enabled>true</enabled>
                </layers>
            </configuration>
        </plugin>
        <plugin>
            <groupId>com.spotify</groupId>
            <artifactId>dockerfile-maven-plugin</artifactId>
            <version>1.4.13</version>
            <configuration>
                <repository>${project.parent.artifactId}</repository>
                <buildArgs>
                    <PROJECT_NAME>${project.parent.artifactId}</PROJECT_NAME>
                    <JAR_FILE>target/${project.build.finalName}.jar</JAR_FILE>
                </buildArgs>
                <skip>false</skip>
            </configuration>
        </plugin>
    </plugins>
</build>