背景介绍
最近在使用swagger knife4j这个在线api文档工具的过程中,遇到了两个问题。
- 接口根据不同的入参返回不同类型对象,希望在api文档中展示出所有响应结果的文档
- 接口的请求参数对象中,部分属性是由后端填充的,所以希望在接口api文档中把这些入参隐藏掉。
为了解决上面提到的两个问题,我的第一反应是去网上查资料,但找了一圈也没有找到现成的解决方法。于是我参考knife4j的部分扩展方式自行实现了一套方案,并在这里分享出来,供有类似需求的小伙伴参考。
对于第2个问题,knife4j自带的@ApiOperationSupport注解其实也能隐藏请求参数,但它有一个缺陷(或者说是bug):一旦给方法加上该注解,如果请求参数中存在嵌套对象,文档中生成的请求示例就会缺少部分参数。
实现思路
借助springfox的扩展机制,利用javassist动态创建类,替换原接口响应结果或请求参数所对应的类。用到的扩展点如下:
// 替换响应结果中的类
springfox.documentation.spi.service.OperationBuilderPlugin
// 替换请求参数中的类
springfox.documentation.spi.service.ParameterBuilderPlugin
// 将javassist动态生成的类解析为Model,以便在knife4j页面中可以正常展示api文档
springfox.documentation.spi.service.OperationModelsProviderPlugin
利用以下两组注解声明接口的响应类型,同时兼容swagger v3(io.swagger.v3.oas.annotations)与v2(io.swagger.annotations)注解:
io.swagger.v3.oas.annotations.responses.ApiResponse
io.swagger.v3.oas.annotations.responses.ApiResponses
io.swagger.annotations.ApiResponses
io.swagger.annotations.ApiResponse
利用自定义注解@ApiIgnoreInParam声明需要隐藏的请求参数信息
效果展示
展示多种响应结果时,只能将状态码定义为类似200-1、200-2这样的形式,因为如果所有响应的状态码都定义为200,文档中只会展示其中一种结果,也就达不到目的了。这算是该方案的一个缺陷,但总归实现了将所有响应类型的属性文档都展示出来,更方便对接的同学开发使用。
代码实现
工具类
import java.util.Set;
import java.util.function.Consumer;
import javassist.ClassClassPath;
import javassist.ClassPool;
import javassist.CtClass;
import javassist.CtField;
import javassist.LoaderClassPath;
import javassist.Modifier;
import javassist.NotFoundException;
import javassist.bytecode.AnnotationsAttribute;
import javassist.bytecode.ClassFile;
import javassist.bytecode.ConstPool;
import javassist.bytecode.FieldInfo;
import javassist.bytecode.annotation.Annotation;
import javassist.bytecode.annotation.AnnotationMemberValue;
import javassist.bytecode.annotation.ArrayMemberValue;
import javassist.bytecode.annotation.BooleanMemberValue;
import javassist.bytecode.annotation.ByteMemberValue;
import javassist.bytecode.annotation.CharMemberValue;
import javassist.bytecode.annotation.ClassMemberValue;
import javassist.bytecode.annotation.DoubleMemberValue;
import javassist.bytecode.annotation.EnumMemberValue;
import javassist.bytecode.annotation.FloatMemberValue;
import javassist.bytecode.annotation.IntegerMemberValue;
import javassist.bytecode.annotation.LongMemberValue;
import javassist.bytecode.annotation.MemberValueVisitor;
import javassist.bytecode.annotation.ShortMemberValue;
import javassist.bytecode.annotation.StringMemberValue;
import lombok.extern.slf4j.Slf4j;
/**
* javassist工具类
*/
@Slf4j
public class JavassistUtil {
private static volatile ClassPool classPool;
/**通过追加当前类加载器,解决springboot项目打包之后javassist提示找不到类的问题
* @return
*/
public static ClassPool getClassPool() {
if(classPool!=null) {
return classPool;
}
synchronized (JavassistUtil.class) {
if(classPool==null) {
ClassPool cp = ClassPool.getDefault();
cp.appendClassPath(new LoaderClassPath(JavassistUtil.class.getClassLoader()));
classPool = cp;
}
}
return classPool;
}
/**获取属性类型
* @param propetyType
* @return
*/
public static CtClass getFieldType(Class<?> propetyType) {
CtClass fieldType= null;
try{
if (!propetyType.isAssignableFrom(Void.class)){
fieldType=classPool.get(propetyType.getName());
}else{
fieldType=classPool.get(String.class.getName());
}
}catch (NotFoundException e){
//抛异常
ClassClassPath path=new ClassClassPath(propetyType);
classPool.insertClassPath(path);
try {
fieldType=classPool.get(propetyType.getName());
} catch (NotFoundException e1) {
log.error(e1.getMessage(),e1);
//can't find
}
}
return fieldType;
}
/**获取属性的指定注解
* @param originClass
* @param fieldName
* @param annType
* @return
* @throws Exception
*/
public static Annotation getFieldAnnotation(Class<?> originClass, String fieldName, Class<? extends java.lang.annotation.Annotation> annType) throws Exception {
CtClass sourceCtClass = JavassistUtil.getClassPool().get(originClass.getName());
CtField oriCf = sourceCtClass.getField(fieldName);
FieldInfo oriFi = oriCf.getFieldInfo();
AnnotationsAttribute annAttr = (AnnotationsAttribute) oriFi.getAttribute(AnnotationsAttribute.visibleTag);
return annAttr.getAnnotation(annType.getName());
}
/**获取类的指定注解
* @param originClass
* @param annType
* @return
* @throws Exception
*/
public static Annotation getClassAnnotation(Class<?> originClass, Class<? extends java.lang.annotation.Annotation> annType) throws Exception {
CtClass sourceCtClass = JavassistUtil.getClassPool().get(originClass.getName());
ClassFile classFile = sourceCtClass.getClassFile();
AnnotationsAttribute annAttr = (AnnotationsAttribute) classFile.getAttribute(AnnotationsAttribute.visibleTag);
return annAttr.getAnnotation(annType.getName());
}
/**向目标类中插入属性
* @param fieldName 属性名称
* @param fieldType 属性类型
* @param targetCtClass 目标类
* @param originAnns 原始注解
* @throws Exception
*/
public static void addField(String fieldName, Class<?> fieldType, CtClass targetCtClass, Annotation... originAnns) throws Exception {
CtField field = new CtField(getFieldType(fieldType), fieldName, targetCtClass);
field.setModifiers(Modifier.PUBLIC);
addFieldAnnotation(field, originAnns);
targetCtClass.addField(field);
}
/**给属性添加注解
* @param field
* @param originAnns
*/
public static void addFieldAnnotation(CtField field, Annotation... originAnns) {
if(originAnns!=null && originAnns.length>0) {
ConstPool targetCp = field.getFieldInfo().getConstPool();
AnnotationsAttribute attr = new AnnotationsAttribute(targetCp, AnnotationsAttribute.visibleTag);
for(Annotation originAnn : originAnns) {
attr.addAnnotation(copyAnnotation(originAnn, targetCp));
}
field.getFieldInfo().addAttribute(attr);
}
}
/** 向类添加注解
* @param targetCtClass
* @param modifier
* @param originAnns
*/
public static void addClassAnnotation(CtClass targetCtClass, Consumer<Annotation> modifier, Annotation... originAnns) {
if(originAnns==null || originAnns.length==0) {
return;
}
AnnotationsAttribute attr = new AnnotationsAttribute(targetCtClass.getClassFile().getConstPool(), AnnotationsAttribute.visibleTag);
for (Annotation oriAnnotation : originAnns) {
Annotation copyAnnotation = JavassistUtil.copyAnnotation(oriAnnotation, targetCtClass);
modifier.accept(copyAnnotation);
attr.addAnnotation(copyAnnotation);
}
targetCtClass.getClassFile().addAttribute(attr);
}
/**复制注解信息
* @param originAnn
* @param targetCtClass
* @return
*/
public static Annotation copyAnnotation(Annotation originAnn, CtClass targetCtClass) {
ConstPool targetCp = targetCtClass.getClassFile().getConstPool();
return copyAnnotation(originAnn, targetCp);
}
/**复制注解信息
* @param originAnn
* @param targetCp
* @return
*/
public static Annotation copyAnnotation(Annotation originAnn, ConstPool targetCp) {
Annotation ann = new Annotation(originAnn.getTypeName(), targetCp);
Set<String> memberNames = originAnn.getMemberNames();
for (String memberName : memberNames) {
originAnn.getMemberValue(memberName).accept(getVisitor(targetCp, ann, memberName));
}
return ann;
}
private static MemberValueVisitor getVisitor(ConstPool cp, Annotation ann, String memberName) {
return new MemberValueVisitor() {
@Override
public void visitStringMemberValue(StringMemberValue node) {
ann.addMemberValue(memberName, new StringMemberValue(node.getValue(), cp));
}
@Override
public void visitShortMemberValue(ShortMemberValue node) {
ann.addMemberValue(memberName, new ShortMemberValue(node.getValue(), cp));
}
@Override
public void visitLongMemberValue(LongMemberValue node) {
ann.addMemberValue(memberName, new LongMemberValue(node.getValue(), cp));
}
@Override
public void visitIntegerMemberValue(IntegerMemberValue node) {
ann.addMemberValue(memberName, new IntegerMemberValue(cp, node.getValue()));
}
@Override
public void visitFloatMemberValue(FloatMemberValue node) {
ann.addMemberValue(memberName, new FloatMemberValue(node.getValue(), cp));
}
@Override
public void visitEnumMemberValue(EnumMemberValue node) {
EnumMemberValue emv = new EnumMemberValue(cp);
emv.setType(node.getType());
emv.setValue(node.getValue());
ann.addMemberValue(memberName, emv);
}
@Override
public void visitDoubleMemberValue(DoubleMemberValue node) {
ann.addMemberValue(memberName, new DoubleMemberValue(node.getValue(), cp));
}
@Override
public void visitClassMemberValue(ClassMemberValue node) {
ann.addMemberValue(memberName, new ClassMemberValue(node.getValue(), cp));
}
@Override
public void visitCharMemberValue(CharMemberValue node) {
ann.addMemberValue(memberName, new CharMemberValue(node.getValue(), cp));
}
@Override
public void visitByteMemberValue(ByteMemberValue node) {
ann.addMemberValue(memberName, new ByteMemberValue(node.getValue(), cp));
}
@Override
public void visitBooleanMemberValue(BooleanMemberValue node) {
ann.addMemberValue(memberName, new BooleanMemberValue(node.getValue(), cp));
}
@Override
public void visitArrayMemberValue(ArrayMemberValue node) {
ArrayMemberValue amv = new ArrayMemberValue(cp);
amv.setValue(node.getValue());
ann.addMemberValue(memberName, amv);
}
@Override
public void visitAnnotationMemberValue(AnnotationMemberValue node) {
ann.addMemberValue(memberName, new AnnotationMemberValue(node.getValue(), cp));
}
};
}
}
展示多种响应结果
统一响应类
@Data
@AllArgsConstructor
@NoArgsConstructor
public class CommonResult<T> {
private Integer code;
private String msg;
private T data;
}
替换文档响应结果的类型
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Function;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.core.annotation.Order;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Component;
import com.fasterxml.classmate.ResolvedType;
import com.fasterxml.classmate.TypeResolver;
import com.github.xiaoymin.knife4j.spring.util.ByteUtils;
import cn.hutool.core.collection.CollUtil;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import springfox.documentation.builders.RepresentationBuilder;
import springfox.documentation.builders.ResponseBuilder;
import springfox.documentation.schema.ModelSpecification;
import springfox.documentation.schema.ModelSpecificationProvider;
import springfox.documentation.schema.ResolvedTypes;
import springfox.documentation.schema.TypeNameExtractor;
import springfox.documentation.schema.plugins.SchemaPluginsManager;
import springfox.documentation.schema.property.PackageNames;
import springfox.documentation.service.Response;
import springfox.documentation.spi.DocumentationType;
import springfox.documentation.spi.schema.ViewProviderPlugin;
import springfox.documentation.spi.schema.contexts.ModelContext;
import springfox.documentation.spi.service.OperationBuilderPlugin;
import springfox.documentation.spi.service.contexts.OperationContext;
import springfox.documentation.swagger.common.SwaggerPluginSupport;
/**
 * Operation plugin that replaces documented responses with dynamically generated
 * {@code CommonResult<payload>} wrapper models, one per response declared via
 * swagger v3 {@code @ApiResponse(s)} or, if none exist, swagger v2 {@code @ApiResponse(s)}.
 */
@ConditionalOnProperty(prefix = "swagger", name = "enable", havingValue = "true")
@Component
@Order(SwaggerPluginSupport.OAS_PLUGIN_ORDER+1)
public class ApiResponseSchemaModelReader implements OperationBuilderPlugin {
    private final TypeNameExtractor typeNameExtractor;
    private final SchemaPluginsManager pluginsManager;
    private final ModelSpecificationProvider modelProvider;
    @Autowired
    private TypeResolver typeResolver;

    @Autowired
    public ApiResponseSchemaModelReader(TypeNameExtractor typeNameExtractor, SchemaPluginsManager pluginsManager,
            @Qualifier("cachedModels") ModelSpecificationProvider modelProvider) {
        this.typeNameExtractor = typeNameExtractor;
        this.pluginsManager = pluginsManager;
        this.modelProvider = modelProvider;
    }

    @Override
    public boolean supports(DocumentationType delimiter) {
        return true;
    }

    /** v3 annotations win; v2 annotations are consulted only when no v3 response annotation exists. */
    @Override
    public void apply(OperationContext operationContext) {
        if (!applyV3(operationContext)) {
            applyV2(operationContext);
        }
    }

    /**
     * Handles swagger v3 response annotations.
     *
     * @return {@code true} if at least one v3 {@code @ApiResponse} was found
     */
    private boolean applyV3(OperationContext context) {
        List<ApiResponse> declared = new ArrayList<>();
        Optional<ApiResponses> container = context.findAnnotation(ApiResponses.class);
        if (container.isPresent()) {
            declared.addAll(Arrays.asList(container.get().value()));
        } else {
            context.findAnnotation(ApiResponse.class).ifPresent(declared::add);
        }
        declared.forEach(single -> replaceResponse(context, single));
        return !declared.isEmpty();
    }

    /**
     * Handles swagger v2 response annotations.
     *
     * @return {@code true} if at least one v2 {@code @ApiResponse} was found
     */
    private boolean applyV2(OperationContext context) {
        List<io.swagger.annotations.ApiResponse> declared = new ArrayList<>();
        Optional<io.swagger.annotations.ApiResponses> container = context.findAnnotation(io.swagger.annotations.ApiResponses.class);
        if (container.isPresent()) {
            declared.addAll(Arrays.asList(container.get().value()));
        } else {
            context.findAnnotation(io.swagger.annotations.ApiResponse.class).ifPresent(declared::add);
        }
        declared.forEach(single -> replace(context, single.response(), String.valueOf(single.code()), single.message()));
        return !declared.isEmpty();
    }

    /** Replaces the response once per {@code @Content} schema declared on a v3 {@code @ApiResponse}. */
    private void replaceResponse(OperationContext operationContext, ApiResponse apiResponse) {
        for (Content content : apiResponse.content()) {
            Class<?> payload = content.schema().implementation();
            replace(operationContext, payload, apiResponse.responseCode(), apiResponse.description());
        }
    }

    /**
     * Registers a response whose model is the generated {@code CommonResult} wrapper for {@code target}.
     * Nothing happens for {@code null}/{@code Void} payloads, payloads that already are {@code CommonResult},
     * or when the wrapper class cannot be loaded.
     */
    private void replace(OperationContext operationContext, Class<?> target, String code, String description) {
        boolean ignorable = target == null || target == Void.class || CommonResult.class.isAssignableFrom(target);
        if (ignorable) {
            return;
        }
        // The generated class name must be unique, otherwise one wrapper would overwrite another.
        String wrapperName = CommonResult.class.getName() + target.getSimpleName();
        // Load the generated wrapper class through knife4j's ByteUtils helper.
        Class<?> wrapperClass = ByteUtils.load(wrapperName);
        if (wrapperClass == null) {
            return;
        }
        ResolvedType returnType = operationContext.alternateFor(typeResolver.resolve(wrapperClass));
        if (ResolvedTypes.isVoid(returnType)) {
            return;
        }
        ModelContext modelContext = modelContext(operationContext, returnType);
        modelProvider.modelSpecificationsFor(modelContext)
                .ifPresent(spec -> registerResponse(operationContext, modelContext, returnType, spec, code, description));
    }

    /** Builds the springfox {@link Response} that references the wrapper model and attaches it to the operation. */
    private void registerResponse(OperationContext operationContext, ModelContext modelContext, ResolvedType returnType,
            ModelSpecification spec, String code, String description) {
        String typeName = typeNameExtractor.typeName(ModelContext.fromParent(modelContext, returnType));
        ResponseBuilder responseBuilder = new ResponseBuilder();
        responseBuilder.representation(MediaType.ALL).apply(representation ->
                representation.model(model ->
                        model.name(spec.getName()).referenceModel(reference ->
                                reference.key(key -> key
                                        .qualifiedModelName(q -> q.namespace(PackageNames.safeGetPackageName(returnType)).name(typeName))
                                        .viewDiscriminator(modelContext.getView().orElse(null))
                                        .validationGroupDiscriminators(modelContext.getValidationGroups())
                                        .isResponse(modelContext.isReturnType())
                                        .build()))));
        Response response = responseBuilder.code(code).description(description).build();
        operationContext.operationBuilder().responses(Arrays.asList(response));
    }

    /** Registers {@code returnType} as a return model of the operation, using the documentation type's view provider. */
    private ModelContext modelContext(OperationContext operationContext, ResolvedType returnType) {
        ViewProviderPlugin viewProvider =
                pluginsManager.viewProvider(operationContext.getDocumentationContext().getDocumentationType());
        return operationContext.operationModelsBuilder().addReturn(returnType, viewProvider.viewFor(operationContext));
    }
}
替换文档响应结果中的model
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import com.fasterxml.classmate.ResolvedType;
import com.fasterxml.classmate.TypeResolver;
import cn.hutool.core.collection.CollUtil;
import io.swagger.annotations.ApiModelProperty;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import javassist.CannotCompileException;
import javassist.ClassPool;
import javassist.CtClass;
import javassist.CtField;
import javassist.Modifier;
import javassist.NotFoundException;
import javassist.bytecode.AnnotationsAttribute;
import javassist.bytecode.ConstPool;
import javassist.bytecode.annotation.Annotation;
import javassist.bytecode.annotation.BooleanMemberValue;
import javassist.bytecode.annotation.StringMemberValue;
import lombok.extern.slf4j.Slf4j;
import springfox.documentation.spi.DocumentationType;
import springfox.documentation.spi.service.OperationModelsProviderPlugin;
import springfox.documentation.spi.service.contexts.RequestMappingContext;
/**
 * Operation-models plugin that generates a {@code CommonResult}-shaped wrapper class for every payload
 * type named in the response annotations (v3 first, v2 as fallback) and registers it as a return model,
 * so the documentation UI can render those wrappers.
 */
@Slf4j
@ConditionalOnProperty(prefix = "swagger", name = "enable", havingValue = "true")
@Component
@Order(Ordered.HIGHEST_PRECEDENCE+14)
public class ApiResponseSchemaModelsProvider implements OperationModelsProviderPlugin {
    @Autowired
    private TypeResolver typeResolver;
    /** Resolved wrapper model per payload class, so each wrapper is generated at most once per plugin instance. */
    private Map<Class<?>, ResolvedType> cache = new HashMap<>();

    @Override
    public boolean supports(DocumentationType delimiter) {
        return true;
    }

    @Override
    public void apply(RequestMappingContext context) {
        Set<Class<?>> modelClazzes = findV3SchemaModels(context);
        if(CollUtil.isEmpty(modelClazzes)) {
            modelClazzes = findSchemaModels(context);
        }
        for (Class<?> clazz : modelClazzes) {
            if(clazz==null || clazz==Void.class || CommonResult.class.isAssignableFrom(clazz)) {
                continue;
            }
            ResolvedType modelType = cache.get(clazz);
            if(modelType==null) {
                Class<?> dynamicModelClass = createDynamicModelClass(clazz);
                // Fall back to the raw payload type when the wrapper could not be generated.
                modelType=context.alternateFor(typeResolver.resolve(dynamicModelClass==null ?clazz : dynamicModelClass));
                cache.put(clazz, modelType);
            }
            context.operationModelsBuilder().addReturn(modelType);
        }
    }

    /** Collects {@code @Schema(implementation)} classes from swagger v3 {@code @ApiResponse(s)}, in declaration order. */
    private Set<Class<?>> findV3SchemaModels(RequestMappingContext context) {
        Optional<ApiResponses> apiResponses = context.findAnnotation(ApiResponses.class);
        List<ApiResponse> list = new ArrayList<>();
        if(apiResponses.isPresent()) {
            CollUtil.addAll(list, apiResponses.get().value());
        }else {
            Optional<ApiResponse> apiResponse = context.findAnnotation(ApiResponse.class);
            if(apiResponse.isPresent()) {
                list.add(apiResponse.get());
            }
        }
        if(list.isEmpty()) {
            return Collections.emptySet();
        }
        Set<Class<?>> modelClazzes = new LinkedHashSet<>();
        for (ApiResponse apiResponse : list) {
            for (Content content : apiResponse.content()) {
                Schema schema = content.schema();
                modelClazzes.add(schema.implementation());
            }
        }
        return modelClazzes;
    }

    /** Collects {@code response} classes from swagger v2 {@code @ApiResponse(s)}, in declaration order. */
    private Set<Class<?>> findSchemaModels(RequestMappingContext context) {
        Optional<io.swagger.annotations.ApiResponses> apiResponses = context.findAnnotation(io.swagger.annotations.ApiResponses.class);
        List<io.swagger.annotations.ApiResponse> list = new ArrayList<>();
        if(apiResponses.isPresent()) {
            CollUtil.addAll(list, apiResponses.get().value());
        }else {
            Optional<io.swagger.annotations.ApiResponse> apiResponse = context.findAnnotation(io.swagger.annotations.ApiResponse.class);
            if(apiResponse.isPresent()) {
                list.add(apiResponse.get());
            }
        }
        if(list.isEmpty()) {
            return Collections.emptySet();
        }
        Set<Class<?>> modelClazzes = new LinkedHashSet<>();
        for (io.swagger.annotations.ApiResponse apiResponse : list) {
            modelClazzes.add(apiResponse.response());
        }
        return modelClazzes;
    }

    /**
     * Generates a class named {@code CommonResult.class.getName() + target.getSimpleName()} with public
     * {@code code}/{@code msg}/{@code data} fields, where {@code data} has type {@code target}.
     * <p>NOTE(review): payload classes sharing a simple name map to the same generated class name.
     *
     * @param target payload class documented as the {@code data} field
     * @return the generated class, or {@code null} if generation failed (the error is logged)
     */
    public static Class<?> createDynamicModelClass(Class<?> target){
        String clazzName= CommonResult.class.getName() + target.getSimpleName();
        ClassPool classPool = JavassistUtil.getClassPool();
        try {
            // Remove any CtClass previously registered under this name before defining it again.
            CtClass stale = classPool.getCtClass(clazzName);
            if (stale != null) {
                stale.detach();
            }
        } catch (NotFoundException ignored) {
            // Nothing has been generated under this name yet.
        }
        CtClass ctClass=classPool.makeClass(clazzName);
        try{
            // Integer (not int) to match CommonResult.code.
            ctClass.addField(createField("code", Integer.class, "响应码", ctClass));
            ctClass.addField(createField("msg", String.class, "响应消息", ctClass));
            ctClass.addField(createField("data", target, "数据", ctClass));
            return ctClass.toClass();
        }catch (Throwable e){
            // Pass the throwable itself so the stack trace is not lost.
            log.error("创建响应Model类失败: {}", clazzName, e);
        }
        return null;
    }

    /**
     * Creates a public field annotated with {@code @ApiModelProperty(value = remark, required = false)}.
     *
     * @param name    field name
     * @param type    field type
     * @param remark  description shown in the documentation
     * @param ctClass class that will own the field
     * @return the new (not yet added) field
     */
    private static CtField createField(String name, Class<?> type, String remark, CtClass ctClass) throws NotFoundException, CannotCompileException {
        CtField field=new CtField(JavassistUtil.getFieldType(type),name,ctClass);
        field.setModifiers(Modifier.PUBLIC);
        ConstPool constPool=ctClass.getClassFile().getConstPool();
        AnnotationsAttribute attr = new AnnotationsAttribute(constPool, AnnotationsAttribute.visibleTag);
        Annotation ann = new Annotation(ApiModelProperty.class.getName(), constPool);
        ann.addMemberValue("value", new StringMemberValue(remark, constPool));
        ann.addMemberValue("required", new BooleanMemberValue(false,constPool));
        attr.addAnnotation(ann);
        field.getFieldInfo().addAttribute(attr);
        return field;
    }
}
隐藏输入参数
自定义注解
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Hides request-parameter properties from the generated API documentation.
 * <p>With a concrete {@link #value()} type, the listed property names of that type are removed from
 * the documented request model. With the default {@code Void.class}, the names are matched against
 * method parameter names and matching parameters are hidden entirely.
 * <p>The annotation is repeatable; several occurrences are gathered into {@link ApiIgnoreInParams}.
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
@java.lang.annotation.Repeatable(ApiIgnoreInParam.ApiIgnoreInParams.class)
public @interface ApiIgnoreInParam {
    /**
     * Type whose properties are hidden; {@code Void.class} means "method parameter names".
     *
     * @return the targeted type
     */
    Class<?> value() default Void.class;

    /**
     * Property (or parameter) names to hide.
     *
     * @return names to hide
     */
    String[] names();

    /**
     * Container for several {@link ApiIgnoreInParam} declarations on one method. It can be written
     * explicitly, or the compiler creates it when {@code @ApiIgnoreInParam} is repeated.
     */
    @Target(ElementType.METHOD)
    @Retention(RetentionPolicy.RUNTIME)
    @Documented
    public @interface ApiIgnoreInParams {
        ApiIgnoreInParam[] value();
    }
}
替换请求参数类
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.collections4.keyvalue.MultiKey;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import com.fasterxml.classmate.ResolvedType;
import com.fasterxml.classmate.TypeResolver;
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.util.ReflectUtil;
import cn.hutool.json.JSONUtil;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import javassist.CtClass;
import javassist.bytecode.annotation.Annotation;
import javassist.bytecode.annotation.StringMemberValue;
import lombok.extern.slf4j.Slf4j;
import springfox.documentation.service.ResolvedMethodParameter;
import springfox.documentation.spi.DocumentationType;
import springfox.documentation.spi.service.OperationModelsProviderPlugin;
import springfox.documentation.spi.service.contexts.RequestMappingContext;
/**
 * Operation-models plugin that, for request parameters whose type is targeted by {@code @ApiIgnoreInParam(s)},
 * generates a copy of the parameter class (and of its nested {@code @ApiModel} types) without the ignored
 * {@code @ApiModelProperty} fields, and registers it as an input model.
 * <p>Generated types are cached in a static map keyed by (class name, sorted ignored names).
 * NOTE(review): the key does not include ignores declared for nested types.
 */
@Slf4j
@ConditionalOnProperty(prefix = "swagger", name = "enable", havingValue = "true")
@Component
@Order(Ordered.HIGHEST_PRECEDENCE + 15)
public class ApiIgnoreSchemaModelsProvider implements OperationModelsProviderPlugin {
    /** Generated input models keyed by (original class name, JSON of the sorted ignored names). */
    private static Map<MultiKey<String>, ResolvedType> cache = new HashMap<>();
    /** Sequence appended to generated {@code @ApiModel} names to keep them unique. */
    private static AtomicInteger num = new AtomicInteger();
    @Autowired
    private TypeResolver typeResolver;

    /**
     * Looks up the generated model for {@code rawType} with the given ignored names.
     *
     * @return the cached resolved type, or {@code null} if none was generated
     */
    static ResolvedType getCachedResolvedType(Class<?> rawType, Set<String> names) {
        return cache.get(cacheKey(rawType, names));
    }

    /** Builds the cache key; names are sorted so the key does not depend on declaration order. */
    private static MultiKey<String> cacheKey(Class<?> rawType, Set<String> names) {
        List<String> sortedNames = new ArrayList<>(names);
        Collections.sort(sortedNames);
        return new MultiKey<String>(rawType.getName(), JSONUtil.toJsonStr(sortedNames));
    }

    @Override
    public boolean supports(DocumentationType delimiter) {
        return true;
    }

    @Override
    public void apply(RequestMappingContext context) {
        Map<Class, Set<String>> ignores = getIgnoreFieldNames(context);
        if (ignores.isEmpty()) {
            return;
        }
        for (ResolvedMethodParameter methodParameter : context.getParameters()) {
            Class<?> originClass = methodParameter.getParameterType().getErasedType();
            Set<String> nestedNames = ignores.get(originClass);
            // Each of the following checks only concerns the current parameter, so move on to the
            // next one instead of aborting the whole operation (the original used `return`).
            if (nestedNames == null) {
                continue;
            }
            MultiKey<String> key = cacheKey(originClass, nestedNames);
            if (cache.containsKey(key)) {
                continue;
            }
            // Skip when no documented property would remain after ignoring the requested names.
            Field[] fields = ReflectUtil.getFields(originClass, f -> f.isAnnotationPresent(ApiModelProperty.class) && !nestedNames.contains(f.getName()));
            if (fields == null || fields.length == 0) {
                continue;
            }
            Map<Class, List<Field>> validFieldMap = new HashMap<>();
            filterField(originClass, ignores, validFieldMap, new HashSet<>());
            // Create the model class without the ignored properties.
            Class<?> modelClass = createModelClass(getDynClassName(originClass), originClass, validFieldMap);
            if (modelClass == null) {
                continue;
            }
            ResolvedType resolvedType = context.alternateFor(typeResolver.resolve(modelClass));
            context.operationModelsBuilder().addInputParam(resolvedType);
            cache.put(key, resolvedType);
        }
    }

    /**
     * Generates a class named {@code targetClassName} holding the kept fields of {@code originClass}.
     *
     * @return the generated class, or {@code null} if {@code originClass} has no kept fields or generation failed
     */
    private Class<?> createModelClass(String targetClassName, Class<?> originClass, Map<Class, List<Field>> validFieldMap) {
        return createModelClass(targetClassName, originClass, validFieldMap, new HashMap<>(), new HashSet<>());
    }

    /**
     * Recursive worker for {@link #createModelClass(String, Class, Map)}.
     *
     * @param created  nested types already generated during this call tree (reused, not regenerated)
     * @param building types currently being generated; used to stop on self-referencing models
     */
    private Class<?> createModelClass(String targetClassName, Class<?> originClass, Map<Class, List<Field>> validFieldMap,
            Map<Class<?>, Class<?>> created, Set<Class<?>> building) {
        List<Field> fields = validFieldMap.get(originClass);
        if (fields == null) {
            return null;
        }
        building.add(originClass);
        try {
            CtClass targetCtClass = JavassistUtil.getClassPool().makeClass(targetClassName);
            // Copy @ApiModel only when it is declared on the class itself (it may be absent or inherited).
            if (originClass.getDeclaredAnnotation(ApiModel.class) != null) {
                Annotation oriAnnotation = JavassistUtil.getClassAnnotation(originClass, ApiModel.class);
                if (oriAnnotation != null) {
                    JavassistUtil.addClassAnnotation(targetCtClass, ca -> {
                        // Give the copy a unique model name so documentation does not resolve to the original model.
                        StringMemberValue smv = (StringMemberValue) ca.getMemberValue("value");
                        if (smv != null) {
                            smv.setValue(smv.getValue() + "_" + num.incrementAndGet());
                        }
                    }, oriAnnotation);
                }
            }
            for (Field field : fields) {
                Class<?> ft = resolveFieldType(field.getType(), validFieldMap, created, building);
                if (ft == null) {
                    continue;
                }
                Annotation originAnn = JavassistUtil.getFieldAnnotation(originClass, field.getName(), ApiModelProperty.class);
                JavassistUtil.addField(field.getName(), ft, targetCtClass, originAnn);
            }
            return targetCtClass.toClass();
        } catch (Exception e) {
            log.error("动态隐藏接口参数出错", e);
            return null;
        } finally {
            building.remove(originClass);
        }
    }

    /**
     * Chooses the type of a generated field.
     * Types without kept-field information are used unchanged. Nested model types are replaced by their
     * generated copy. On a cycle, the original type is kept, because regenerating it would never terminate.
     *
     * @return the type to use, or {@code null} to drop the field when the nested copy could not be generated
     */
    private Class<?> resolveFieldType(Class<?> ft, Map<Class, List<Field>> validFieldMap,
            Map<Class<?>, Class<?>> created, Set<Class<?>> building) {
        if (!validFieldMap.containsKey(ft)) {
            return ft;
        }
        Class<?> done = created.get(ft);
        if (done != null) {
            return done;
        }
        if (building.contains(ft)) {
            return ft;
        }
        Class<?> dynamic = createModelClass(getDynClassName(ft), ft, validFieldMap, created, building);
        if (dynamic != null) {
            created.put(ft, dynamic);
        }
        return dynamic;
    }

    /** Returns a unique class name derived from {@code clazz} (random UUID suffix). */
    private String getDynClassName(Class<?> clazz) {
        return clazz.getName() + "_" + UUID.randomUUID().toString().replace('-', '_');
    }

    /**
     * Records, for {@code originClass} and every reachable nested {@code @ApiModel} type, the
     * {@code @ApiModelProperty} fields that are not ignored.
     */
    private void filterField(Class<?> originClass, Map<Class, Set<String>> ignores, Map<Class, List<Field>> validFieldMap, Set<Class> processed) {
        if (!processed.add(originClass)) {
            return;
        }
        Field[] fields = ReflectUtil.getFields(originClass, f -> f.isAnnotationPresent(ApiModelProperty.class));
        if (fields == null || fields.length == 0) {
            return;
        }
        List<Field> validFields = new ArrayList<>(fields.length);
        validFieldMap.put(originClass, validFields);
        Set<String> names = ignores.get(originClass);
        for (Field field : fields) {
            if (names != null && names.contains(field.getName())) {
                continue;
            }
            validFields.add(field);
            Class<?> ft = field.getType();
            if (ft.isAnnotationPresent(ApiModel.class)) {
                filterField(ft, ignores, validFieldMap, processed);
            }
        }
    }

    /**
     * Merges every {@code @ApiIgnoreInParam} on the handler (direct or inside {@code @ApiIgnoreInParams})
     * into a map from target type to non-blank ignored names.
     */
    private Map<Class, Set<String>> getIgnoreFieldNames(RequestMappingContext context) {
        List<ApiIgnoreInParam> inParams = new ArrayList<>();
        context.findAnnotation(ApiIgnoreInParams.class).ifPresent(ps -> inParams.addAll(Arrays.asList(ps.value())));
        context.findAnnotation(ApiIgnoreInParam.class).ifPresent(inParams::add);
        Map<Class, Set<String>> map = new HashMap<>(inParams.size(), 1);
        for (ApiIgnoreInParam apiIgnoreInParam : inParams) {
            Set<String> names = map.computeIfAbsent(apiIgnoreInParam.value(), k -> new HashSet<>());
            for (String name : apiIgnoreInParam.names()) {
                if (CharSequenceUtil.isNotBlank(name)) {
                    names.add(name);
                }
            }
        }
        return map;
    }
}
替换文档中的请求参数Model
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import com.fasterxml.classmate.ResolvedType;
import cn.hutool.core.text.CharSequenceUtil;
import javassist.ClassPool;
import springfox.documentation.schema.ModelSpecification;
import springfox.documentation.schema.ModelSpecificationProvider;
import springfox.documentation.schema.plugins.SchemaPluginsManager;
import springfox.documentation.schema.property.ModelSpecificationFactory;
import springfox.documentation.service.ResolvedMethodParameter;
import springfox.documentation.spi.DocumentationType;
import springfox.documentation.spi.schema.ViewProviderPlugin;
import springfox.documentation.spi.schema.contexts.ModelContext;
import springfox.documentation.spi.service.ParameterBuilderPlugin;
import springfox.documentation.spi.service.contexts.OperationContext;
import springfox.documentation.spi.service.contexts.ParameterContext;
/**
 * Parameter plugin applying {@code @ApiIgnoreInParam(s)} to individual request parameters.
 * <ul>
 *   <li>Names declared with the default {@code Void.class} type hide method parameters with that name.</li>
 *   <li>For a parameter whose type has declared names, the parameter's content model is swapped for the
 *       generated model cached by {@code ApiIgnoreSchemaModelsProvider}, if one exists.</li>
 * </ul>
 */
@SuppressWarnings("deprecation")
@Component
@Order(Ordered.HIGHEST_PRECEDENCE + 101)
public class ApiIgnoreParameterBuilderPlugin implements ParameterBuilderPlugin {
    static final ClassPool classPool = ClassPool.getDefault();
    @Autowired
    private SchemaPluginsManager pluginsManager;
    @Autowired
    private ModelSpecificationFactory models;
    @Autowired
    @Qualifier("cachedModels")
    private ModelSpecificationProvider modelProvider;

    @Override
    public boolean supports(DocumentationType delimiter) {
        return true;
    }

    @Override
    public void apply(ParameterContext parameterContext) {
        Map<Class, Set<String>> ignores = getIgnoreFieldNames(parameterContext);
        if (ignores.isEmpty()) {
            return;
        }
        ResolvedMethodParameter methodParameter = parameterContext.resolvedMethodParameter();
        Optional<String> defaultName = methodParameter.defaultName();
        // Names declared without an explicit type apply to method parameter names.
        Set<String> globalNames = ignores.get(Void.class);
        boolean hiddenByName = defaultName.filter(n -> globalNames != null && globalNames.contains(n)).isPresent();
        if (hiddenByName) {
            parameterContext.parameterBuilder().hidden(true);
            return;
        }
        Class<?> originClass = methodParameter.getParameterType().getErasedType();
        Set<String> nestedNames = ignores.get(originClass);
        ResolvedType generated = nestedNames == null
                ? null
                : ApiIgnoreSchemaModelsProvider.getCachedResolvedType(originClass, nestedNames);
        if (generated == null) {
            return;
        }
        // Point the parameter at the generated model instead of the original one.
        ResolvedType parameterType = parameterContext.alternateFor(generated);
        ModelContext modelContext = modelContext(parameterContext, methodParameter, parameterType);
        parameterContext.requestParameterBuilder().contentModel(models.create(modelContext, parameterType));
    }

    /** Merges the operation's {@code @ApiIgnoreInParam} declarations into target type → non-blank names. */
    private Map<Class, Set<String>> getIgnoreFieldNames(ParameterContext parameterContext) {
        OperationContext operationContext = parameterContext.getOperationContext();
        List<ApiIgnoreInParam> declarations = new ArrayList<>();
        operationContext.findAnnotation(ApiIgnoreInParams.class).ifPresent(ps -> declarations.addAll(Arrays.asList(ps.value())));
        operationContext.findAnnotation(ApiIgnoreInParam.class).ifPresent(declarations::add);
        Map<Class, Set<String>> byType = new HashMap<>(declarations.size(), 1);
        for (ApiIgnoreInParam declaration : declarations) {
            Set<String> names = byType.computeIfAbsent(declaration.value(), k -> new HashSet<>());
            Arrays.stream(declaration.names()).filter(CharSequenceUtil::isNotBlank).forEach(names::add);
        }
        return byType;
    }

    /** Registers {@code parameterType} as an input model of the operation and returns its model context. */
    private ModelContext modelContext(ParameterContext context, ResolvedMethodParameter methodParameter, ResolvedType parameterType) {
        DocumentationType documentationType = context.getDocumentationContext().getDocumentationType();
        ViewProviderPlugin viewProvider = pluginsManager.viewProvider(documentationType);
        return context.getOperationContext()
                .operationModelsBuilder()
                .addInputParam(parameterType, viewProvider.viewFor(methodParameter), new HashSet<>());
    }
}
应用示例
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import io.swagger.annotations.Api;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
/**
 * Demo endpoints: documenting several response payloads, and hiding request-body properties.
 */
@Api(value = "演示接口", tags = "演示接口")
@RestController
@RequestMapping("/demo")
public class DemoController {

    /**
     * Documents two alternative payloads. Each one uses its own pseudo status code
     * ("200-1", "200-2"); if both used the same code, only one of them would be shown.
     */
    @Operation(summary = "多种类型响应结果", description = "多种类型响应结果")
    @ApiResponses({
            @ApiResponse(responseCode = "200-1", description = "响应结果详情1",
                    content = @Content(schema = @Schema(implementation = ResponseResult1.class))),
            @ApiResponse(responseCode = "200-2", description = "响应结果详情2",
                    content = @Content(schema = @Schema(implementation = ResponseResult2.class)))
    })
    @GetMapping("/multiTypeResponse")
    public CommonResult<Object> multiTypeResponse() {
        return CommonResult.success(null);
    }

    // Hides param1, param2 and param3 of RequestParams in the documented request body.
    // To hide properties of several types at once, use the container annotation, e.g.:
    //   @ApiIgnoreInParams({
    //       @ApiIgnoreInParam(value = RequestParams.class, names = {"param1"}),
    //       @ApiIgnoreInParam(value = Nested.class, names = {"field1"}),
    //   })
    // which hides RequestParams.param1 and Nested.field1, regardless of which type nests the other.
    @ApiIgnoreInParam(value = RequestParams.class, names = {"param1", "param2", "param3"})
    @Operation(summary = "隐藏指定参数", description = "隐藏指定参数")
    @PostMapping("/hideRequestParams")
    public CommonResult<Boolean> hideRequestParams(@RequestBody RequestParams param) {
        return CommonResult.success(true);
    }
}
/**
 * First sample payload documented by the multi-response demo.
 */
@ApiModel("响应结果1")
public class ResponseResult1 {
    /** Identifier. */
    @Getter
    @Setter
    @ApiModelProperty("id")
    private String id;

    /** Display name. */
    @Getter
    @Setter
    @ApiModelProperty("名称")
    private String name;
}
/**
 * Second sample payload documented by the multi-response demo.
 */
@ApiModel("响应结果2")
public class ResponseResult2 {
    /** Person's name. */
    @Getter
    @Setter
    @ApiModelProperty("姓名")
    private String name;

    /** Person's age. */
    @Getter
    @Setter
    @ApiModelProperty("年龄")
    private int age;
}
/**
 * Sample request body; the demo hides param1..param3 from its documentation.
 */
@ApiModel("请求参数")
public class RequestParams {
    @Getter
    @Setter
    @ApiModelProperty("参数1")
    private String param1;

    @Getter
    @Setter
    @ApiModelProperty("参数2")
    private String param2;

    @Getter
    @Setter
    @ApiModelProperty("参数3")
    private String param3;

    @Getter
    @Setter
    @ApiModelProperty("姓名")
    private String name;

    @Getter
    @Setter
    @ApiModelProperty("地址")
    private String address;
}