Preventing Duplicate API Submissions with HandlerInterceptor in a Spring Boot Project


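Background

When the frontend calls backend APIs, especially ones that modify data, a repeated submission (for example, a user double-clicking a submit button) can corrupt the data. Critical APIs must still be made idempotent; this approach only filters out duplicate submissions in some common scenarios.

Approach

Use a Spring Boot HandlerInterceptor to intercept handler methods, and use Redis to reject a repeated submission of the same request within a short time window.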
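The core idea is Redis's atomic "set if absent, with expiry" (`SET key value NX EX seconds`, which `setIfAbsent(key, value, timeout, unit)` issues in Spring Data Redis 2.1 and later): the first request in a window creates the key and proceeds, and any request with the same key before it expires is rejected. Below is a minimal in-memory sketch of just those semantics, using a hypothetical `DedupWindow` class with an explicit timestamp so it can run without Redis. It only works inside one JVM and is not a replacement for the shared Redis key.

```java
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical in-memory stand-in for Redis "SET key value NX EX ttl".
// Single-JVM only; multiple application instances need the shared Redis key.
class DedupWindow {
    // key -> expiry timestamp in milliseconds
    private final ConcurrentHashMap<String, Long> expiries = new ConcurrentHashMap<>();

    /**
     * Returns true if this call "won" the key (first submission in the window),
     * false if an earlier submission still holds it.
     */
    boolean tryAcquire(String key, long ttlMillis, long nowMillis) {
        boolean[] acquired = {false};
        // compute() is atomic per key, much like SET NX on a single Redis key
        expiries.compute(key, (k, expiry) -> {
            if (expiry == null || expiry <= nowMillis) {
                acquired[0] = true;
                return nowMillis + ttlMillis; // create the key with a fresh TTL
            }
            return expiry; // still held: leave it untouched
        });
        return acquired[0];
    }
}

class DedupWindowDemo {
    public static void main(String[] args) {
        DedupWindow window = new DedupWindow();
        System.out.println(window.tryAcquire("PDC-abc", 1000, 0));    // true: first submit
        System.out.println(window.tryAcquire("PDC-abc", 1000, 500));  // false: within 1 s
        System.out.println(window.tryAcquire("PDC-abc", 1000, 1000)); // true: key expired
    }
}
```

The `expireTime` attribute of the annotation plays the role of `ttlMillis` here (in seconds), and the key prefix mirrors the interceptor's `PRE_KEY`.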

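Code

Custom annotation

Put this custom annotation on every handler method that should be protected against duplicate submission; the interceptor only checks annotated methods.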

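import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface PreventDoubleCommits {

  /**
   * Unique key used to decide whether a submission is a duplicate.
   * If left blank, the interceptor derives the key from the request URI and body.
   * A fixed non-blank value is shared by all callers of the annotated method.
   *
   * @return the duplicate-detection key
   */
  String commitKey() default "";

  /**
   * Expiry time in seconds, default 1.
   * 1 means a given key may be submitted once per second,
   * 2 means once every 2 seconds, and so on.
   *
   * @return the deduplication window in seconds
   */
  int expireTime() default 1;

}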
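The interceptor decides per handler method by reading this annotation reflectively (`handler.getMethod().getAnnotation(...)`), which is why `RetentionPolicy.RUNTIME` is required. The plain-Java sketch below reproduces that lookup, with a hypothetical `OrderApi` class standing in for a controller: an annotated method yields its attribute values, and an unannotated one yields `null` and is therefore not checked.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

// Same shape as the annotation above, repeated so this sketch compiles on its own.
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@interface PreventDoubleCommits {
    String commitKey() default "";
    int expireTime() default 1;
}

// Hypothetical handler class; in a real project these would be @PostMapping/@GetMapping methods.
class OrderApi {
    @PreventDoubleCommits(expireTime = 3)
    public void submitOrder() { }

    public void listOrders() { }
}

class AnnotationLookupDemo {
    // Mirrors the interceptor's lookup; null means "not protected", so the request passes.
    static PreventDoubleCommits lookup(Class<?> type, String methodName) {
        try {
            Method method = type.getMethod(methodName);
            return method.getAnnotation(PreventDoubleCommits.class);
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException("no public method " + methodName, e);
        }
    }

    public static void main(String[] args) {
        PreventDoubleCommits submit = lookup(OrderApi.class, "submitOrder");
        System.out.println(submit.expireTime());                  // 3
        System.out.println(submit.commitKey().isEmpty());         // true: key derived from request
        System.out.println(lookup(OrderApi.class, "listOrders")); // null
    }
}
```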

Implementing HandlerInterceptor

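// Assumed dependencies: Lombok, fastjson (1.x), Hutool (MD5), commons-lang3 (StringUtils).
// Use javax.servlet.* on Spring Boot 2.x and jakarta.servlet.* on Spring Boot 3.x.
import cn.hutool.crypto.digest.MD5;
import com.alibaba.fastjson.JSON;
import java.io.IOException;
import java.lang.reflect.Method;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.util.StreamUtils;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerInterceptor;

@Slf4j
public class PreventDoubleCommitsInterceptor implements HandlerInterceptor {

  private static final String PRE_KEY = "XXXXX-PDC-";

  /** JSON body returned when a duplicate submission is rejected (built once). */
  private static final String RESPONSE_STRING;

  static {
    RespBody respBody = new RespBody<>();
    respBody.setErrorInfoInterface(
        new CustomErrorInfo("-1", "Too fast! Please do not submit again.", "000000"));
    RESPONSE_STRING = JSON.toJSONString(respBody);
  }

  private final RedisTemplate<String, String> stringRedisTemplate;

  public PreventDoubleCommitsInterceptor(RedisTemplate<String, String> stringRedisTemplate) {
    this.stringRedisTemplate = stringRedisTemplate;
  }

  @Override
  public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
      throws Exception {
    // Only controller methods (HandlerMethod) can carry the annotation;
    // static resources and other handler types are passed through.
    if (handler instanceof HandlerMethod) {
      log.info("preventDoubleCommitsInterceptor into");
      return passesDoubleCommitCheck(request, response, (HandlerMethod) handler);
    }
    return HandlerInterceptor.super.preHandle(request, response, handler);
  }

  /** Returns true if the request may proceed, false if it was rejected as a duplicate. */
  private boolean passesDoubleCommitCheck(HttpServletRequest request,
      HttpServletResponse response, HandlerMethod handler) throws IOException {
    Method method = handler.getMethod();
    PreventDoubleCommits doubleCommits = method.getAnnotation(PreventDoubleCommits.class);
    if (Objects.isNull(doubleCommits)) {
      return true; // not annotated: no check
    }
    if (isDuplicate(request, doubleCommits)) {
      log.info("preventDoubleCommitsInterceptor result: repeat submission");
      response.setContentType("application/json;charset=UTF-8");
      response.getOutputStream().write(RESPONSE_STRING.getBytes(StandardCharsets.UTF_8));
      return false;
    }
    return true;
  }

  private boolean isDuplicate(HttpServletRequest request, PreventDoubleCommits doubleCommits)
      throws IOException {
    String uniqueKey = doubleCommits.commitKey();
    if (StringUtils.isBlank(uniqueKey)) {
      uniqueKey = createUniqueKey(request);
    }
    uniqueKey = PRE_KEY + uniqueKey;
    // SET key "-1" NX EX expireTime: only the first request in the window creates the key.
    // A null result (e.g. inside a pipeline or transaction) is treated as a duplicate.
    return !Boolean.TRUE.equals(stringRedisTemplate.opsForValue()
        .setIfAbsent(uniqueKey, "-1", doubleCommits.expireTime(), TimeUnit.SECONDS));
  }

  private String createUniqueKey(HttpServletRequest request) throws IOException {
    String requestUrl = request.getRequestURI();
    // Reading the body here only works because RequestWrapper caches it.
    byte[] bytes = StreamUtils.copyToByteArray(request.getInputStream());
    // getCharacterEncoding() may return null; new String(bytes, (String) null) would throw.
    String encoding = request.getCharacterEncoding();
    Charset charset = encoding == null ? StandardCharsets.UTF_8 : Charset.forName(encoding);
    String jsonString = new String(bytes, charset);
    return MD5.create().digestHex16(requestUrl + jsonString);
  }
}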
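When no `commitKey` is given, the key is an MD5 digest of the request URI plus the raw body, so two identical requests map to the same Redis key. A dependency-free sketch of that derivation with `java.security.MessageDigest` follows. `CommitKeys.uniqueKey` is a hypothetical name; it returns the full 32-character hex digest, whereas the Hutool call `MD5.create().digestHex16(...)` produces a shorter 16-character form, and it falls back to UTF-8 when the request declares no encoding.

```java
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

// Hypothetical stdlib version of createUniqueKey(): hex MD5 of (URI + body).
class CommitKeys {
    static String uniqueKey(String requestUri, byte[] body, String encoding) {
        // Fall back to UTF-8 when the request declares no character encoding
        Charset charset = encoding == null ? StandardCharsets.UTF_8 : Charset.forName(encoding);
        String payload = requestUri + new String(body, charset);
        try {
            MessageDigest md5 = MessageDigest.getInstance("MD5");
            byte[] digest = md5.digest(payload.getBytes(StandardCharsets.UTF_8));
            StringBuilder hex = new StringBuilder(digest.length * 2);
            for (byte b : digest) {
                hex.append(String.format("%02x", b)); // two lowercase hex digits per byte
            }
            return hex.toString(); // 32 characters; the caller adds the PRE_KEY prefix
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("MD5 must be available on every Java platform", e);
        }
    }
}

class CommitKeysDemo {
    public static void main(String[] args) {
        byte[] body = "{\"orderId\":1}".getBytes(StandardCharsets.UTF_8);
        String first = CommitKeys.uniqueKey("/orders", body, null);
        String again = CommitKeys.uniqueKey("/orders", body, "UTF-8");
        System.out.println(first.equals(again)); // true: same URI and body, same key
        System.out.println(first.length());      // 32
    }
}
```

Because the key contains only the URI and body, two different users sending an identical body to the same URL inside the window would also collide; including a user or session identifier in the digest input would avoid that.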

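Configuring WebMvcConfig

Implement WebMvcConfigurer rather than extending WebMvcConfigurationSupport. Spring Boot's WebMvcAutoConfiguration backs off as soon as a WebMvcConfigurationSupport bean exists, so extending it silently disables Boot's default MVC setup. The configuration below registers the interceptor and replaces the DispatcherServlet with one that wraps every request in RequestWrapper, so the body can still be read after the interceptor has consumed it.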

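// Use javax.servlet.* on Spring Boot 2.x and jakarta.servlet.* on Spring Boot 3.x.
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.web.servlet.DispatcherServletAutoConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

@Configuration
public class WebMvcConfig implements WebMvcConfigurer {

  @Autowired
  private RedisTemplate<String, String> stringRedisTemplate;

  @Override
  public void addInterceptors(InterceptorRegistry registry) {
    // Applied to every path; only methods annotated with @PreventDoubleCommits are checked.
    registry.addInterceptor(new PreventDoubleCommitsInterceptor(stringRedisTemplate))
        .addPathPatterns("/**");
  }

  // Registered under the default bean name so it replaces Boot's auto-configured DispatcherServlet.
  @Bean(name = DispatcherServletAutoConfiguration.DEFAULT_DISPATCHER_SERVLET_BEAN_NAME)
  public DispatcherServlet dispatcherServlet() {
    return new DispatcherServlet() {
      @Override
      protected void doDispatch(HttpServletRequest request, HttpServletResponse response)
          throws Exception {
        // Cache the body so both the interceptor and @RequestBody binding can read it.
        // Every body, including large uploads, is buffered in memory.
        super.doDispatch(new RequestWrapper(request), response);
      }
    };
  }
}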

Wrapping the request

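// Use javax.servlet.* on Spring Boot 2.x and jakarta.servlet.* on Spring Boot 3.x.
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import org.springframework.util.StreamUtils;

public class RequestWrapper extends HttpServletRequestWrapper {

  /** The request body, read once from the container's stream and cached. */
  private final byte[] body;

  public RequestWrapper(HttpServletRequest request) throws IOException {
    super(request);
    body = StreamUtils.copyToByteArray(request.getInputStream());
  }

  /** Returns a fresh stream over the cached body, so it can be read any number of times. */
  @Override
  public ServletInputStream getInputStream() throws IOException {
    ByteArrayInputStream bodyStream = new ByteArrayInputStream(body);
    return new ServletInputStream() {

      @Override
      public int read() throws IOException {
        return bodyStream.read();
      }

      @Override
      public boolean isFinished() {
        // true once every cached byte has been consumed
        return bodyStream.available() == 0;
      }

      @Override
      public boolean isReady() {
        return true; // the data is already in memory
      }

      @Override
      public void setReadListener(ReadListener readListener) {
        // Non-blocking (async) reads are not supported by this wrapper.
      }
    };
  }

  @Override
  public BufferedReader getReader() throws IOException {
    // Decode with the request's declared encoding instead of the platform default.
    String encoding = getCharacterEncoding();
    Charset charset = encoding == null ? StandardCharsets.UTF_8 : Charset.forName(encoding);
    return new BufferedReader(new InputStreamReader(getInputStream(), charset));
  }
}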

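Notes

  • The request must be wrapped in an HttpServletRequestWrapper subclass that caches the body. A servlet request's input stream can only be read once, so without the wrapper the interceptor would consume the body and @RequestBody binding in the controller would fail.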
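The wrapper works because it drains the container's one-shot stream exactly once into a byte array and then returns a new `ByteArrayInputStream` over that array on every `getInputStream()` call. The plain-Java sketch below isolates that idea without the Servlet API; `CachedBody` is a hypothetical class, and `readAllBytes()` assumes Java 9 or later.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;

// Hypothetical plain-Java analog of RequestWrapper: drain the one-shot stream once,
// then hand out a fresh stream over the cached bytes on every call.
class CachedBody {
    private final byte[] body;

    CachedBody(InputStream original) {
        this.body = readAll(original); // consumes the original stream exactly once
    }

    // Each call returns an independent stream positioned at the start of the body.
    InputStream newStream() {
        return new ByteArrayInputStream(body);
    }

    String readAsString() {
        return new String(readAll(newStream()), StandardCharsets.UTF_8);
    }

    private static byte[] readAll(InputStream in) {
        try {
            return in.readAllBytes();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}

class CachedBodyDemo {
    public static void main(String[] args) throws IOException {
        byte[] json = "{\"orderId\":1}".getBytes(StandardCharsets.UTF_8);

        // A raw stream is one-shot: the second read sees nothing.
        InputStream raw = new ByteArrayInputStream(json);
        System.out.println(new String(raw.readAllBytes(), StandardCharsets.UTF_8)); // {"orderId":1}
        System.out.println(raw.readAllBytes().length);                             // 0

        // Once cached, the body can be read by the interceptor and again by the controller.
        CachedBody cached = new CachedBody(new ByteArrayInputStream(json));
        System.out.println(cached.readAsString()); // {"orderId":1}
        System.out.println(cached.readAsString()); // {"orderId":1}
    }
}
```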

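Possible improvements

  • Instead of hashing the entire body, the interceptor could build the key from specific parameters the user designates in commitKey, evaluated with Spring Expression Language (SpEL). Note that SpEL templates use #{...}; ${...} is property-placeholder syntax.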
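Following that idea, `commitKey` could hold a template naming the parameters that identify a submission, for example `order:#{orderId}`. A real implementation could pass it to Spring's `SpelExpressionParser` with a `TemplateParserContext` (whose default delimiters are `#{` and `}`); the sketch below deliberately swaps in a tiny regex-based resolver instead, so it runs without Spring. `KeyTemplate.resolve` and the parameter map (assumed to come from the parsed JSON body) are hypothetical.

```java
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical, dependency-free stand-in for SpEL templates (NOT SpEL itself):
// replaces each #{name} in commitKey with that parameter's value.
class KeyTemplate {
    private static final Pattern PLACEHOLDER = Pattern.compile("#\\{(\\w+)\\}");

    static String resolve(String template, Map<String, ?> params) {
        Matcher m = PLACEHOLDER.matcher(template);
        StringBuffer out = new StringBuffer();
        while (m.find()) {
            Object value = params.get(m.group(1));
            if (value == null) {
                // Fail loudly rather than silently building a key that collides across requests
                throw new IllegalArgumentException("missing parameter: " + m.group(1));
            }
            // quoteReplacement keeps '$' or '\' in values from being treated as group references
            m.appendReplacement(out, Matcher.quoteReplacement(String.valueOf(value)));
        }
        m.appendTail(out);
        return out.toString();
    }
}

class KeyTemplateDemo {
    public static void main(String[] args) {
        Map<String, Object> params = Map.of("orderId", 42, "userId", "u1");
        System.out.println(KeyTemplate.resolve("order:#{orderId}:#{userId}", params)); // order:42:u1
    }
}
```

Because `preHandle` runs before Spring resolves the controller method's arguments, the parameter map would have to be built from the cached request body or query string, not from the method arguments themselves.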