Unified HTTP Management with JDK Dynamic Proxies (Modeled on Controllers)


Approach

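The project I recently worked on does not use a distributed RPC framework. Services talk to each other over HTTP, and nginx load balancing distributes the requests across instances to achieve distributed communication.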

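Because inter-service calls go over HTTP, the code ended up with a large number of RestTemplate calls. To keep them from becoming unmanageable as development continues, I built a simple RestInterface on top of JDK dynamic proxies that manages all HTTP requests in a way similar to SpringMVC.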

Step 1: Create a few annotations

GET and POST, which at a glance look a lot like Spring MVC's PostMapping and GetMapping:

package com.tydic.config.restconfig;

import java.lang.annotation.*;

/**
 * @author Gmw
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@Inherited
@Documented
public @interface GET {

    String path() default "";
}

package com.tydic.config.restconfig;

import java.lang.annotation.*;

/**
 * @author Gmw
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@Inherited
@Documented
public @interface POST {

    String path() default "";
}

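These two annotations go on the methods of the RestInterface interface and state whether each method is called with POST or GET. (@Inherited has no effect on them: it only applies to annotations placed on classes, so it could be dropped.)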
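Because both annotations use RetentionPolicy.RUNTIME, the proxy can read them back through reflection at call time. A minimal sketch, assuming trimmed-down copies of GET and POST and a hypothetical DemoApi interface, that resolves each method's HTTP verb and path roughly the way the handler does:

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

// Trimmed-down copies of the article's annotations; RUNTIME retention makes them visible to reflection
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@interface GET {
    String path() default "";
}

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@interface POST {
    String path() default "";
}

// Hypothetical interface used only for this demo
interface DemoApi {
    @GET(path = "/users")
    String listUsers();

    @POST(path = "orders/submit")
    String submitOrder(String body);
}

class AnnotationReader {
    /** Returns e.g. "GET /users" for the named method of the given interface. */
    static String describe(Class<?> iface, String methodName) {
        for (Method m : iface.getMethods()) {
            if (!m.getName().equals(methodName)) {
                continue;
            }
            GET get = m.getAnnotation(GET.class);
            if (get != null) {
                return "GET " + normalize(get.path());
            }
            POST post = m.getAnnotation(POST.class);
            if (post != null) {
                return "POST " + normalize(post.path());
            }
            throw new IllegalStateException("No HTTP method annotation on " + methodName);
        }
        throw new IllegalArgumentException("No such method: " + methodName);
    }

    // Same rule as the article's findUrl: make sure the path starts with "/"
    private static String normalize(String path) {
        return path.startsWith("/") ? path : "/" + path;
    }

    public static void main(String[] args) {
        System.out.println(describe(DemoApi.class, "listUsers"));   // GET /users
        System.out.println(describe(DemoApi.class, "submitOrder")); // POST /orders/submit
    }
}
```

If the retention were left at the default (CLASS), getAnnotation would return null and every method would hit the IllegalStateException branch.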

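Next comes the request-body annotation. I took a shortcut here and reused Spring's @RequestBody directly; it makes little difference, and in fact all of these annotations could be Spring's own.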

Then the annotation for GET request parameters. This time I did not take the shortcut and wrote my own, PARAM:

package com.tydic.config.restconfig;

import java.lang.annotation.*;

/**
 * Query parameter of a GET request (similar to Spring's RequestParam)
 * @author Gmw
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
@Inherited
@Documented
public @interface PARAM {

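    String value();

    /** Whether the argument must be non-null; RestInterfaceImpl checks this before sending */
    boolean required() default true;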
}
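When a proxied method is called, every @PARAM-annotated argument has to be collected into a name-to-value map. A minimal sketch of that step, assuming a copy of PARAM that includes the required() element the full code relies on, and a hypothetical SearchApi interface:

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.LinkedHashMap;
import java.util.Map;

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
@interface PARAM {
    String value();

    boolean required() default true;
}

// Hypothetical interface used only for this demo
interface SearchApi {
    String search(@PARAM("keyword") String keyword,
                  @PARAM(value = "page", required = false) Integer page);
}

class ParamCollector {
    /** Maps each @PARAM-annotated argument to its name; null optional values are skipped. */
    static Map<String, String> collect(Method method, Object[] args) {
        Map<String, String> params = new LinkedHashMap<>();
        Parameter[] parameters = method.getParameters();
        for (int i = 0; i < parameters.length; i++) {
            PARAM anno = parameters[i].getAnnotation(PARAM.class);
            if (anno == null) {
                continue; // not a query parameter (e.g. a request body)
            }
            Object value = args[i];
            if (value == null) {
                if (anno.required()) {
                    throw new IllegalArgumentException("Request parameter [" + anno.value() + "] must not be null");
                }
                continue;
            }
            params.put(anno.value(), value.toString());
        }
        return params;
    }

    // SearchApi declares exactly one method
    static Method searchMethod() {
        return SearchApi.class.getMethods()[0];
    }

    public static void main(String[] args) {
        System.out.println(collect(searchMethod(), new Object[]{"java", 2}));    // {keyword=java, page=2}
        System.out.println(collect(searchMethod(), new Object[]{"java", null})); // {keyword=java}
    }
}
```

The parameter name comes from the annotation value rather than from reflection, so this works without compiling with the -parameters flag.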

Step 2: The HTTP client interface RestInterface

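package com.tydic.business;

import com.alibaba.fastjson.JSONObject;
import com.tydic.config.restconfig.POST;
import org.springframework.web.bind.annotation.RequestBody;

/**
 * Central place for the HTTP requests made in the code base
 * <p>
 * <b>Note:</b>
 * GET requests can only carry query parameters; POST requests can only carry a request body (v1.0)
 * </p>
 *
 * @author Gmw
 */
public interface RestInterface {

    /**
     * Query user information by work-order ID
     *
     * @param parameter request body
     * @return response body
     */
    @POST(path = "/wm/api/sysmgr/users/info")
    JSONObject queryUserInfo(@RequestBody JSONObject parameter);

    /**
     * Query order updates
     *
     * @param parameter request body
     * @return response body
     */
    @POST(path = "/wm/api/worksheets/receipts/order/updates")
    JSONObject queryOrderUpdates(@RequestBody JSONObject parameter);

    /**
     * Submit an order
     *
     * @param parameter request body
     * @return response body
     */
    @POST(path = "/wm/api/orders/submit")
    JSONObject orderSubmit(@RequestBody JSONObject parameter);
}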

The interface itself is simple: between the annotations and the thorough comments, it is clear what each method calls.

Step 3: Implement the dynamic proxy

Implementing the dynamic proxy is simple; all it takes is a class that implements InvocationHandler:

    private static class RestInterfaceImpl implements InvocationHandler {

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            // ...
        }
    }

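Of the three parameters, the handler mainly uses method and args: the interface method currently being invoked on the proxy and the actual arguments passed to it.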

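Any HTTP interface shaped like RestInterface can use this RestInterfaceImpl as its proxy implementation, because the proxy logic is the same for every REST interface: parse the annotations, assemble the request parameters, send the HTTP request and return the result.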
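To see these moving parts without Spring or a network, here is a minimal runnable sketch built only on the JDK's Proxy.newProxyInstance and InvocationHandler. It assumes a hypothetical EchoApi interface and a hypothetical DescribingHandler that, instead of sending a request, returns a string describing the request it would send (URL prefix + annotated path + arguments):

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.Arrays;

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@interface POST {
    String path() default "";
}

// Hypothetical interface used only for this demo
interface EchoApi {
    @POST(path = "/wm/api/orders/submit")
    String orderSubmit(String body);
}

// Stand-in for RestInterfaceImpl: describes the request instead of sending it
class DescribingHandler implements InvocationHandler {
    private final String prefix;

    DescribingHandler(String prefix) {
        this.prefix = prefix;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) {
        POST post = method.getAnnotation(POST.class);
        if (post == null) {
            throw new IllegalStateException("No HTTP method annotation on " + method.getName());
        }
        // args is null when the proxied method takes no arguments
        String body = args == null ? "[]" : Arrays.toString(args);
        return "POST " + prefix + post.path() + " body=" + body;
    }

    static EchoApi create(String prefix) {
        return (EchoApi) Proxy.newProxyInstance(
                EchoApi.class.getClassLoader(),
                new Class<?>[]{EchoApi.class},
                new DescribingHandler(prefix));
    }

    public static void main(String[] args) {
        EchoApi api = create("http://127.0.0.1:80");
        // prints: POST http://127.0.0.1:80/wm/api/orders/submit body=[{"id":1}]
        System.out.println(api.orderSubmit("{\"id\":1}"));
    }
}
```

Swapping the string-building line for an actual RestTemplate call is, in essence, what RestInterfaceImpl does.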

Next, a Spring configuration class creates the proxy around this handler and registers it as a Spring bean:

@Configuration
public class RestInterfaceConfig {

    private final RestInterfaceImpl proxyImpl;

    public RestInterfaceConfig(
            @Value("${nginx.protocol}") String protocol,
            @Value("${nginx.host}") String host,
            @Value("${nginx.port}") String port,
            LogServiceImpl logService) {
        proxyImpl = new RestInterfaceImpl(protocol, host, port, logService);
    }

    @Bean(name = "restInterface")
    public RestInterface createRestInterface() {
        return (RestInterface) Proxy.newProxyInstance(
                // Class loader used to define the proxy class; the interface's own loader is fine
                RestInterface.class.getClassLoader(),
                // Interfaces the proxy class implements
                new Class[]{RestInterface.class},
                // Handler that receives every method call made on the proxy
                proxyImpl);
    }
}

If you need beans for several REST interfaces, just add one more factory method per interface:

    @Bean(name = "anotherRestInterface")
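    public AnotherRestInterface createAnotherRestInterface() {
        return (AnotherRestInterface) Proxy.newProxyInstance(
                // Class loader used to define the proxy class; the interface's own loader is fine
                AnotherRestInterface.class.getClassLoader(),
                // Interfaces the proxy class implements
                new Class[]{AnotherRestInterface.class},
                // The same shared handler instance
                proxyImpl);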
    }

/// ...

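RestInterfaceImpl is thread-safe: apart from two final fields set in its constructor it keeps no state, and it obtains a RestTemplate through BaseUtils.newRestTemplateInstance() on every call. Multiple beans can therefore share the same handler instance.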
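Sharing works because the handler reads everything it needs from the Method and its arguments rather than from fields tied to one interface. A minimal sketch, assuming two hypothetical interfaces UserApi and OrderApi and a handler that only reports which interface method was called. It also answers toString, hashCode and equals locally: the JDK routes those three calls on a proxy to invoke as well (with Object as the declaring class), and they carry no HTTP annotation.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

// Hypothetical interfaces used only for this demo
interface UserApi {
    String queryUserInfo(String body);
}

interface OrderApi {
    String orderSubmit(String body);
}

// Stateless handler: everything comes from the Method and args, so one instance can back many proxies
class SharedHandler implements InvocationHandler {
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) {
        // toString/hashCode/equals on the proxy also arrive here, declared by Object
        if (method.getDeclaringClass() == Object.class) {
            switch (method.getName()) {
                case "equals":
                    return proxy == args[0];
                case "hashCode":
                    return System.identityHashCode(proxy);
                default:
                    return "Proxy for " + proxy.getClass().getInterfaces()[0].getSimpleName();
            }
        }
        return method.getDeclaringClass().getSimpleName() + "." + method.getName();
    }

    static <T> T proxyOf(Class<T> iface, InvocationHandler handler) {
        return iface.cast(Proxy.newProxyInstance(iface.getClassLoader(), new Class<?>[]{iface}, handler));
    }

    public static void main(String[] args) {
        SharedHandler handler = new SharedHandler();
        UserApi users = proxyOf(UserApi.class, handler);
        OrderApi orders = proxyOf(OrderApi.class, handler);
        System.out.println(users.queryUserInfo("{}")); // UserApi.queryUserInfo
        System.out.println(orders.orderSubmit("{}"));  // OrderApi.orderSubmit
        System.out.println(users);                     // Proxy for UserApi
    }
}
```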

Full code

Here is the complete RestInterfaceConfig:

package com.tydic.config.restconfig;

import com.alibaba.fastjson.JSONObject;
import com.tydic.business.RestInterface;
import com.tydic.constant.Constants;
import com.tydic.log.LogServiceImpl;
import com.tydic.utils.BaseUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.*;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.lang.reflect.Proxy;
import java.util.*;
import java.util.stream.Collectors;

/**
 * @author Gmw
 */
@Slf4j
@Configuration
public class RestInterfaceConfig {

    private final RestInterfaceImpl restInterface;

    public RestInterfaceConfig(
            @Value("${nginx.protocol}") String protocol,
            @Value("${nginx.host}") String host,
            @Value("${nginx.port}") String port,
            LogServiceImpl logService) {
        restInterface = new RestInterfaceImpl(protocol, host, port, logService);
    }

    @Bean(name = "restInterface")
    public RestInterface createRestInterface() {
        return (RestInterface) Proxy.newProxyInstance(
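                // Class loader used to define the proxy class; the interface's own loader is fine
                RestInterface.class.getClassLoader(),
                // Interfaces the proxy class implements
                new Class[]{RestInterface.class},
                // Handler that receives every method call made on the proxy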
                restInterface);
    }

    private static class RestInterfaceImpl implements InvocationHandler {

        private final String restPrefix;

        private final LogServiceImpl logService;

        public RestInterfaceImpl(String protocol, String host, String port, LogServiceImpl logService) {
            this.logService = logService;
            restPrefix = String.format("%s://%s:%s", protocol, host, port);
        }

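        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            // toString/hashCode/equals on the proxy also arrive here and have no HTTP annotation
            if (method.getDeclaringClass() == Object.class) {
                switch (method.getName()) {
                    case "equals":
                        return proxy == args[0];
                    case "hashCode":
                        return System.identityHashCode(proxy);
                    default:
                        return "RestInterfaceProxy(" + restPrefix + ")";
                }
            }
            try {
                // URL, HTTP method and common headers come from the method's annotations
                RequestMetadata requestMetadata = getRequestMedataFromMethod(method, restPrefix);
                // Split the arguments into request body and query parameters
                ParamWrapper paramWrapper = getParamFromMethodAndArgs(method.getParameters(), args);
                urlComplete(requestMetadata, paramWrapper.reqParams);
                RestTemplate restTemplate = BaseUtils.newRestTemplateInstance();
                ResponseEntity<String> response = restInvoke(restTemplate, requestMetadata, paramWrapper);
                // Deserialize the JSON response into the method's declared return type
                return JSONObject.parseObject(response.getBody(), method.getReturnType());
            } catch (Exception e) {
                logService.writeErrorLog(e, method.getDeclaringClass(), method, args, Constants.KeyName.restInterface);
                throw e;
            }
        }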

        private void urlComplete(RequestMetadata requestMetadata, Map<String, String> reqParams) {
            if (reqParams != null && !reqParams.isEmpty()) {
                String url = requestMetadata.getUrl() + ("?" + reqParams
                        .keySet().stream().map(key -> String.format("%s={%s}", key, key)).collect(Collectors.joining("&")));
                requestMetadata.setUrl(url);
            }
        }

        private ResponseEntity<String> restInvoke(RestTemplate rest,
                                                  RequestMetadata request,
                                                  ParamWrapper params) {
            log.debug("About to send request: ");
            log.debug("request: {}", request);
            log.debug("params: {}", params);
            if ("get".equals(request.getMethod())) {
                // exchange (unlike getForEntity) also sends the common headers;
                // reqParams fills the {key} placeholders appended by urlComplete
                HttpEntity<Void> getEntity = new HttpEntity<>(request.getHeaders());
                return rest.exchange(request.url, HttpMethod.GET, getEntity, String.class, params.reqParams);
            }
            HttpEntity<Object> requestEntity = new HttpEntity<>(params.requestBody, request.getHeaders());
            return rest.exchange(request.url, HttpMethod.POST, requestEntity, String.class);
        }

        private ParamWrapper getParamFromMethodAndArgs(Parameter[] parameters, Object[] args) {
            ParamWrapper wrapper = new ParamWrapper();
            for (int index = 0; index < parameters.length; index++) {
                Parameter parameter = parameters[index];
                Object parameterValue = args[index];
                if (parameter.isAnnotationPresent(RequestBody.class)) {
                    wrapper.setRequestBodyType(parameter.getType());
                    if (parameter.getAnnotation(RequestBody.class).required() && parameterValue == null) {
                        // The request body must not be null
                        throw new IllegalArgumentException("Request body must not be null");
                    }
                    wrapper.setRequestBody(parameterValue);
                } else if (parameter.isAnnotationPresent(PARAM.class)) {
                    PARAM paramAnno = parameter.getAnnotation(PARAM.class);
                    if (paramAnno.required() && parameterValue == null) {
                        // A required request parameter must not be null
                        throw new IllegalArgumentException("Request parameter [" + paramAnno.value() + "] must not be null");
                    }
                    wrapper.setReqParams(paramAnno.value(), parameterValue);
                }
            }
            return wrapper;
        }

        private RequestMetadata getRequestMedataFromMethod(Method method, String prefix) {
            RequestMetadata metadata = new RequestMetadata();
            List<Annotation> methodAnnotations = Arrays.asList(method.getAnnotations());
            metadata.setUrl(findUrl(methodAnnotations, prefix));
            metadata.setMethod(findMethod(methodAnnotations));
            metadata.setHeaders(generateCommonHttpHeaders());
            return metadata;
        }

        private String findMethod(List<Annotation> methodAnnotations) {
            Annotation reqPathAnno = methodAnnotations
                    .stream().filter(anno -> anno instanceof POST || anno instanceof GET)
                    .findFirst().orElseThrow(() -> new RuntimeException("No HTTP method definition (GET/POST) found"));
            return reqPathAnno instanceof POST ? "post" : "get";
        }

        private String findUrl(List<Annotation> methodAnnotations, String prefix) {
            Annotation reqPathAnno = methodAnnotations
                    .stream().filter(anno -> anno instanceof POST || anno instanceof GET)
                    .findFirst().orElseThrow(() -> new RuntimeException("No HTTP path definition (GET/POST) found"));
            String path;
            if (reqPathAnno instanceof POST) {
                path = ((POST) reqPathAnno).path();
            } else {
                path = ((GET) reqPathAnno).path();
            }
            return prefix + (path.startsWith("/") ? path : "/" + path);
        }

        private HttpHeaders generateCommonHttpHeaders() {
            HttpHeaders headers = new HttpHeaders();
            headers.setContentType(MediaType.APPLICATION_JSON_UTF8);
            headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8));
            Optional.ofNullable(RequestContextHolder.getRequestAttributes())
                    .filter(requestAttributes -> ServletRequestAttributes.class.isAssignableFrom(requestAttributes.getClass()))
                    .map(requestAttributes -> ((ServletRequestAttributes) requestAttributes).getRequest())
                    .ifPresent(req -> {
                        headers.set(Constants.CustomHeader.WM_TENANT_ID, req.getHeader(Constants.CustomHeader.WM_TENANT_ID));
                        headers.set(Constants.CustomHeader.WM_APP_ID, req.getHeader(Constants.CustomHeader.WM_APP_ID));
                        headers.set(Constants.CustomHeader.WM_TOUCH_POINT, req.getHeader(Constants.CustomHeader.WM_TOUCH_POINT));
                    });
            return headers;
        }

        private static class RequestMetadata {
            private String url;
            private String method;
            private HttpHeaders headers;

            public String getUrl() {
                return url;
            }

            public void setUrl(String url) {
                this.url = url;
            }

            public String getMethod() {
                return method;
            }

            public void setMethod(String method) {
                this.method = method;
            }

            public HttpHeaders getHeaders() {
                return headers;
            }

            public void setHeaders(HttpHeaders headers) {
                this.headers = headers;
            }

            @Override
            public String toString() {
                return new ToStringBuilder(this)
                        .append("url", url)
                        .append("method", method)
                        .append("headers", headers)
                        .toString();
            }
        }

        private static class ParamWrapper {
            private Map<String, String> reqParams = new HashMap<>();
            private Class<?> requestBodyType;
            private Object requestBody;

            public Map<String, String> getReqParams() {
                return reqParams;
            }

            public void setReqParams(Map<String, String> reqParams) {
                this.reqParams = reqParams;
            }

            public Class<?> getRequestBodyType() {
                return requestBodyType;
            }

            public void setRequestBodyType(Class<?> requestBodyType) {
                this.requestBodyType = requestBodyType;
            }

            public Object getRequestBody() {
                return requestBody;
            }

            public void setRequestBody(Object requestBody) {
                this.requestBody = requestBody;
            }

            public void setReqParams(String paramName, Object paramValue) {
                if (paramValue != null) {
                    reqParams.put(paramName, paramValue.toString());
                }
            }

            @Override
            public String toString() {
                return new ToStringBuilder(this)
                        .append("reqParams", reqParams)
                        .append("requestBodyType", requestBodyType)
                        .append("requestBody", requestBody)
                        .toString();
            }
        }
    }

}
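Note that urlComplete never splices values into the URL itself: it appends a template such as ?keyword={keyword}&page={page} and leaves it to RestTemplate to expand the {...} variables from the parameter map, which also encodes the values. The sketch below contrasts that template with a plain-JDK alternative that substitutes and encodes the values directly. The class QueryUrls and both method names are hypothetical; it needs Java 10+ for URLEncoder.encode(String, Charset), and because URLEncoder produces form encoding (a space becomes +), its output may differ in detail from what RestTemplate generates.

```java
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;

class QueryUrls {
    /** Same shape urlComplete produces: base?k1={k1}&k2={k2}, left for RestTemplate to expand. */
    static String template(String base, Map<String, String> params) {
        if (params == null || params.isEmpty()) {
            return base;
        }
        return base + "?" + params.keySet().stream()
                .map(key -> String.format("%s={%s}", key, key))
                .collect(Collectors.joining("&"));
    }

    /** Plain-JDK alternative: substitute the values directly, form-encoding each key and value. */
    static String expand(String base, Map<String, String> params) {
        if (params == null || params.isEmpty()) {
            return base;
        }
        return base + "?" + params.entrySet().stream()
                .map(e -> encode(e.getKey()) + "=" + encode(e.getValue()))
                .collect(Collectors.joining("&"));
    }

    private static String encode(String s) {
        return URLEncoder.encode(s, StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        Map<String, String> params = new LinkedHashMap<>();
        params.put("keyword", "a b&c");
        params.put("page", "2");
        System.out.println(template("http://host:80/api/search", params)); // http://host:80/api/search?keyword={keyword}&page={page}
        System.out.println(expand("http://host:80/api/search", params));   // http://host:80/api/search?keyword=a+b%26c&page=2
    }
}
```

Leaving expansion to RestTemplate, as the article does, keeps a raw & or = inside a value from being read as a separator.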

Summary

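That gives us unified HTTP management. This first version only supports JSON serialization and hard-codes quite a lot, such as which headers are forwarded. It also supports only GET and POST: only POST can carry a request body, and only GET can carry query parameters. I plan to extend it when I find the time.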