Too good to resist: extending Swagger so the docs automatically list every enum value


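This post follows on from the previous article, 《一站式解决使用枚举的各种痛点》 ("A one-stop fix for the pain points of using enums"). That article ended with a problem: when writing API docs with Swagger, you have to tell front-end developers which values an enum type can take. Every time a value is added, you not only change the code, you also have to find where that enum is used and update the Swagger docs by hand. I (Xiao Hei) find that annoying. So is there a way to have the Swagger framework list all the enum values for us automatically?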

In this post, Xiao Hei walks through a solution.

First, let's look at the result to get a feel for it.

[Screenshot: swagger]

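Note that the course types shown here were not listed by hand; the Swagger framework listed them for us automatically. The corresponding code is as follows: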

[Screenshot: code]

So how is this done?

Here is a brief description of the implementation:

1. Define a custom SwaggerDisplayEnum annotation. It has two attributes. What are they for? I won't say just yet; keep reading and it should become clear.

@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface SwaggerDisplayEnum {
    String index() default "index";

    String name() default "name";

}
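
As a quick check that the annotation is visible at runtime, here is a minimal sketch that reads its two attribute values back using only the JDK. The annotation is redeclared so the snippet compiles on its own. The enums `WithDefaults` and `Custom` and the class `AnnotationLookupDemo` are hypothetical names made up for this illustration. The plugins later in the article use Spring's `AnnotationUtils.findAnnotation` rather than the plain `getAnnotation` call shown here.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@interface SwaggerDisplayEnum {
    String index() default "index";

    String name() default "name";
}

// Hypothetical enum that relies on the default attribute values.
@SwaggerDisplayEnum
enum WithDefaults { A }

// Hypothetical enum that overrides both attributes.
@SwaggerDisplayEnum(index = "type", name = "desc")
enum Custom { B }

class AnnotationLookupDemo {

    // Returns "index/name" as declared on the type, or null when the type is not annotated.
    static String attributes(Class<?> type) {
        SwaggerDisplayEnum annotation = type.getAnnotation(SwaggerDisplayEnum.class);
        return annotation == null ? null : annotation.index() + "/" + annotation.name();
    }

    public static void main(String[] args) {
        System.out.println(attributes(WithDefaults.class));
        System.out.println(attributes(Custom.class));
        System.out.println(attributes(String.class));
    }
}
```

Because the retention policy is RUNTIME, the attribute values survive compilation and can be read reflectively. A type without the annotation simply yields null.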

2. Mark our custom enum class with the @SwaggerDisplayEnum annotation.

@Getter
@AllArgsConstructor
@SwaggerDisplayEnum(index = "type", name = "desc")
public enum CourseType {

    /**
     * Picture and text
     */
    PICTURE(102, "Picture and text"),
    /**
     * Audio
     */
    AUDIO(103, "Audio"),
    /**
     * Video
     */
    VIDEO(104, "Video"),
    /**
     * External link
     */
    URL(105, "External link"),
    ;

    @JsonValue
    private final int type;
    private final String desc;

    private static final Map<Integer, CourseType> mappings;

    static {
        Map<Integer, CourseType> temp = new HashMap<>();
        for (CourseType courseType : values()) {
            temp.put(courseType.type, courseType);
        }
        mappings = Collections.unmodifiableMap(temp);
    }

    @EnumConvertMethod
    @JsonCreator(mode = JsonCreator.Mode.DELEGATING)
    @Nullable
    public static CourseType resolve(int index) {
        return mappings.get(index);
    }

}
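
For readers without Lombok or Jackson on the classpath, the lookup part of this enum can be tried with the JDK alone. The sketch below is a simplified copy, not the author's class: the annotations are dropped, the constructor and getters are written out by hand (roughly what @AllArgsConstructor and @Getter would generate), and only two constants are kept.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

enum CourseType {
    AUDIO(103, "Audio"),
    VIDEO(104, "Video");

    private final int type;
    private final String desc;

    // Index from numeric code to constant, built once when the enum class is initialized.
    private static final Map<Integer, CourseType> MAPPINGS;

    static {
        Map<Integer, CourseType> temp = new HashMap<>();
        for (CourseType courseType : values()) {
            temp.put(courseType.type, courseType);
        }
        MAPPINGS = Collections.unmodifiableMap(temp);
    }

    CourseType(int type, String desc) {
        this.type = type;
        this.desc = desc;
    }

    int getType() {
        return type;
    }

    String getDesc() {
        return desc;
    }

    // Returns the constant for a numeric code, or null when the code is unknown.
    static CourseType resolve(int index) {
        return MAPPINGS.get(index);
    }
}
```

With this in place, `CourseType.resolve(103)` returns `AUDIO`, and an unknown code such as `999` returns null, which is why the original method is marked @Nullable.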

3. Implement the ModelPropertyBuilderPlugin interface to extend Swagger so that every enum value is listed in the docs.

public class EnumModelPropertyBuilderPlugin implements ModelPropertyBuilderPlugin {

    @Override
    public void apply(ModelPropertyContext context) {
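        Optional<BeanPropertyDefinition> optional = context.getBeanPropertyDefinition();
        if (!optional.isPresent()) {
            return;
        }

        // A property exposed only through a getter or setter has no backing field; skip it.
        AnnotatedField field = optional.get().getField();
        if (field == null) {
            return;
        }

        final Class<?> fieldType = field.getRawType();

        addDescForEnum(context, fieldType);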
    }

    @Override
    public boolean supports(DocumentationType delimiter) {
        return true;
    }

    private void addDescForEnum(ModelPropertyContext context, Class<?> fieldType) {
        if (Enum.class.isAssignableFrom(fieldType)) {
            SwaggerDisplayEnum annotation = AnnotationUtils.findAnnotation(fieldType, SwaggerDisplayEnum.class);
            if (annotation != null) {
                String index = annotation.index();
                String name = annotation.name();

                Object[] enumConstants = fieldType.getEnumConstants();

                List<String> displayValues =
                        Arrays.stream(enumConstants)
                                .filter(Objects::nonNull)
                                .map(item -> {
                                    Class<?> currentClass = item.getClass();

                                    Field indexField = ReflectionUtils.findField(currentClass, index);
                                    ReflectionUtils.makeAccessible(indexField);
                                    Object value = ReflectionUtils.getField(indexField, item);

                                    Field descField = ReflectionUtils.findField(currentClass, name);
                                    ReflectionUtils.makeAccessible(descField);
                                    Object desc = ReflectionUtils.getField(descField, item);
                                    return value + ":" + desc;

                                }).collect(Collectors.toList());


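                ModelPropertyBuilder builder = context.getBuilder();
                // Read the description set so far reflectively, then append the enum values to it.
                Field descField = ReflectionUtils.findField(builder.getClass(), "description");
                ReflectionUtils.makeAccessible(descField);
                Object current = ReflectionUtils.getField(descField, builder);
                // Avoid a literal "null" prefix when the property has no description yet.
                String joinText = (current == null ? "" : current + " ")
                        + "(" + String.join("; ", displayValues) + ")";

                builder.description(joinText).type(context.getResolver().resolve(Integer.class));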
            }
        }

    }
}
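
The core of this plugin is independent of Swagger: read two private fields from every enum constant and append the resulting pairs to a description. A minimal JDK-only sketch of that step follows. The class `EnumDescriptionDemo` and its methods are hypothetical names, the two-constant enum is a stand-in for the real one, and plain `java.lang.reflect` replaces Spring's ReflectionUtils.

```java
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;

enum CourseType {
    PICTURE(102, "Picture"),
    AUDIO(103, "Audio");

    private final int type;
    private final String desc;

    CourseType(int type, String desc) {
        this.type = type;
        this.desc = desc;
    }
}

class EnumDescriptionDemo {

    // Reads a (possibly private) field by name from one enum constant.
    static Object read(Enum<?> constant, String fieldName) {
        try {
            // getDeclaringClass() also works for constants that have their own class body.
            Field field = constant.getDeclaringClass().getDeclaredField(fieldName);
            field.setAccessible(true);
            return field.get(constant);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("Cannot read field " + fieldName, e);
        }
    }

    // Builds "description (value:label; value:label)"; a null description is left out.
    static String describe(String description, Class<? extends Enum<?>> enumType,
                           String indexField, String nameField) {
        List<String> pairs = new ArrayList<>();
        for (Enum<?> constant : enumType.getEnumConstants()) {
            pairs.add(read(constant, indexField) + ":" + read(constant, nameField));
        }
        String prefix = description == null ? "" : description + " ";
        return prefix + "(" + String.join("; ", pairs) + ")";
    }

    public static void main(String[] args) {
        System.out.println(describe("Course type", CourseType.class, "type", "desc"));
    }
}
```

The two field names passed to `describe` play the role of the annotation's `index` and `name` attributes: they tell the code which fields hold the value and the label.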

4. Implement the ParameterBuilderPlugin and OperationBuilderPlugin interfaces to list all the values an enum parameter can take.

public class EnumParameterBuilderPlugin implements ParameterBuilderPlugin, OperationBuilderPlugin {

    private static final Joiner joiner = Joiner.on(",");

    @Override
    public void apply(ParameterContext context) {
        Class<?> type = context.resolvedMethodParameter().getParameterType().getErasedType();
        if (Enum.class.isAssignableFrom(type)) {
            SwaggerDisplayEnum annotation = AnnotationUtils.findAnnotation(type, SwaggerDisplayEnum.class);
            if (annotation != null) {

                String index = annotation.index();
                Object[] enumConstants = type.getEnumConstants();
                List<String> displayValues = Arrays.stream(enumConstants).filter(Objects::nonNull).map(item -> {
                    Class<?> currentClass = item.getClass();

                    // Only the index value is needed here: each one becomes an allowable value.
                    Field indexField = ReflectionUtils.findField(currentClass, index);
                    ReflectionUtils.makeAccessible(indexField);
                    Object value = ReflectionUtils.getField(indexField, item);
                    return String.valueOf(value);

                }).collect(Collectors.toList());

                ParameterBuilder parameterBuilder = context.parameterBuilder();
                AllowableListValues values = new AllowableListValues(displayValues, "LIST");
                parameterBuilder.allowableValues(values);
            }
        }
    }


    @Override
    public boolean supports(DocumentationType delimiter) {
        return true;
    }

    @Override
    public void apply(OperationContext context) {
        List<ResolvedMethodParameter> parameters = context.getParameters();
        parameters.forEach(parameter -> {
            ResolvedType parameterType = parameter.getParameterType();
            Class<?> clazz = parameterType.getErasedType();
            if (Enum.class.isAssignableFrom(clazz)) {
                SwaggerDisplayEnum annotation = AnnotationUtils.findAnnotation(clazz, SwaggerDisplayEnum.class);
                if (annotation != null) {
                    String index = annotation.index();
                    String name = annotation.name();
                    Object[] enumConstants = clazz.getEnumConstants();

                    List<String> displayValues = Arrays.stream(enumConstants).filter(Objects::nonNull).map(item -> {
                        Class<?> currentClass = item.getClass();

                        Field indexField = ReflectionUtils.findField(currentClass, index);
                        ReflectionUtils.makeAccessible(indexField);
                        Object value = ReflectionUtils.getField(indexField, item);

                        Field descField = ReflectionUtils.findField(currentClass, name);
                        ReflectionUtils.makeAccessible(descField);
                        Object desc = ReflectionUtils.getField(descField, item);
                        return value + ":" + desc;

                    }).collect(Collectors.toList());

                    // Update only the parameter being processed. Replaying a map of every
                    // parameter seen so far would append the values to earlier ones again.
                    String parameterName = parameter.defaultName().or("");

                    OperationBuilder operationBuilder = context.operationBuilder();
                    Field parametersField = ReflectionUtils.findField(operationBuilder.getClass(), "parameters");
                    ReflectionUtils.makeAccessible(parametersField);
                    @SuppressWarnings("unchecked")
                    List<Parameter> list = (List<Parameter>) ReflectionUtils.getField(parametersField, operationBuilder);

                    for (Parameter currentParameter : list) {
                        if (StringUtils.equals(currentParameter.getName(), parameterName)) {
                            Field description = ReflectionUtils.findField(currentParameter.getClass(), "description");
                            ReflectionUtils.makeAccessible(description);
                            Object field = ReflectionUtils.getField(description, currentParameter);
                            ReflectionUtils.setField(description, currentParameter, field + " , " + joiner.join(displayValues));
                            break;
                        }
                    }
                }
            }
        });
    }
}
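
The same reflective lookup is written out three times across the two plugins, differing only in how each value and label pair is formatted. One possible refactoring, sketched here with the JDK only, moves the lookup into a single helper that takes the formatter as an argument. `EnumDisplayValues` and `collect` are hypothetical names that do not appear in the original source, and the small enum is a stand-in for the real one.

```java
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;

enum CourseType {
    PICTURE(102, "Picture"),
    AUDIO(103, "Audio");

    private final int type;
    private final String desc;

    CourseType(int type, String desc) {
        this.type = type;
        this.desc = desc;
    }
}

class EnumDisplayValues {

    // Reads the two named fields from every constant and formats each pair with the formatter.
    static List<String> collect(Class<? extends Enum<?>> enumType, String indexField, String nameField,
                                BiFunction<Object, Object, String> formatter) {
        List<String> result = new ArrayList<>();
        try {
            Field index = enumType.getDeclaredField(indexField);
            Field name = enumType.getDeclaredField(nameField);
            index.setAccessible(true);
            name.setAccessible(true);
            for (Enum<?> constant : enumType.getEnumConstants()) {
                result.add(formatter.apply(index.get(constant), name.get(constant)));
            }
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("Cannot read enum fields of " + enumType.getName(), e);
        }
        return result;
    }

    public static void main(String[] args) {
        // Allowable values for a parameter: the index value only.
        System.out.println(collect(CourseType.class, "type", "desc", (value, label) -> String.valueOf(value)));
        // Text for a description: "value:label" pairs.
        System.out.println(collect(CourseType.class, "type", "desc", (value, label) -> value + ":" + label));
    }
}
```

With a helper like this, each plugin method would only choose the formatter: the bare value for the allowable-values list, and "value:label" for the descriptions.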

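This post is on the dry side, and I wasn't sure how best to explain it, so I have simply included the source code. If anything is still unclear after reading, leave me a comment and I will answer each one. Thanks for reading!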

The related source code has been uploaded to GitHub: github.com/shenjianeng…