-原文链接:blog.csdn.net/qq_47614329…
流程原理: 定义一切面,通过@Prevent注解作为切入点、在该切面的前置通知获取该方法的所有入参并将其Base64编码,将入参Base64编码+完整方法名+(ip或者用户等)作为redis的key,入参作为reids的value,@Prevent的value作为redis的expire,存入redis; 每次进来这个切面根据入参Base64编码+完整方法名判断redis值是否存在,存在则拦截防刷,不存在则允许调用;
一、定义注解
package com.*.web.annotation;
import java.lang.annotation.*;
/**
* 接口防刷注解
*/
@Documented
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface Prevent {
/**
* 限制的时间值(秒)默认60s
*/
long value() default 60;
/**
* 限制规定时间内访问次数,默认只能访问一次
*/
long times() default 1;
/**
* 提示
*/
String message() default "";
/**
* 策略
*/
PreventStrategy strategy() default PreventStrategy.DEFAULT;
}
二、防刷策略
public enum PreventStrategy {
DEFAULT
}
三、定义切面
package com.zzyt.web.aop;
import com.*.web.annotation.Prevent;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import com.*.web.constants.PreventStrategy;
/**
* 防刷切面实现类
*/
@Aspect
@Component
public class PreventAop {
@Resource
private RedisTemplate<String, Long> redisTemplate;
final String redisKey = "PREVENT_METHOD_NAME:";
/**
* 切入点
*/
@Pointcut("@annotation(com.zzyt.web.annotation.Prevent)")
public void pointcut() {}
/**
* 处理前
*/
@Before("pointcut()")
public void joinPoint(JoinPoint joinPoint) throws Exception {
// 获取调用者ip
RequestAttributes requestAttributes = RequestContextHolder.currentRequestAttributes();
HttpServletRequest httpServletRequest = ((ServletRequestAttributes) requestAttributes).getRequest();
// String userIP = IpUtils.getUserIP(httpServletRequest);
String userIP = httpServletRequest.getRemoteHost();
// 获取调用接口方法名
MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
Method method = joinPoint.getTarget().getClass().getMethod(
methodSignature.getName(),
methodSignature.getParameterTypes()); // 获取该接口方法
String methodFullName = method.getDeclaringClass().getName() + method.getName(); // 获取到方法名
Prevent preventAnnotation = method.getAnnotation(Prevent.class); // 获取该接口上的prevent注解(为了使用该注解内的参数)
// 执行对应策略
entrance(preventAnnotation, userIP, methodFullName);
}
/**
* 通过prevent注册判断执行策略
* @param prevent 该接口的prevent注解对象
* @param userIP 访问该接口的用户ip
* @param methodFullName 该接口方法名
*/
private void entrance(Prevent prevent, String userIP, String methodFullName) throws Exception {
PreventStrategy strategy = prevent.strategy(); // 获取校验策略
if (Objects.requireNonNull(strategy) == PreventStrategy.DEFAULT) { // 默认就是default策略,执行default策略方法
defaultHandle(userIP, prevent, methodFullName);
} else {
throw new WrongDataException("无效的策略");
}
}
/**
* Default测试执行方法
* @param userIP 访问该接口的用户ip
* @param prevent 该接口的prevent注解对象
* @param methodFullName 该接口方法名
*/
private void defaultHandle(String userIP, Prevent prevent, String methodFullName) throws Exception {
String base64StrIP = toBase64String(userIP); // 加密用户ip(避免ip存在一些特殊字符作为redis的key不合法)
long expire = prevent.value(); // 获取访问限制时间
long times = prevent.times(); // 获取访问限制次数
// 限制特定时间内访问特定次数
long count = redisTemplate.opsForValue().increment(
redisKey + base64StrIP + ":" + methodFullName, 1); // 访问次数+1
if (count == 1) { // 如果访问次数为1,则重置访问限制时间(即redis超时时间)
redisTemplate.expire(
redisKey + base64StrIP + ":" + methodFullName,
expire,
TimeUnit.SECONDS);
}
if (count > times) { // 如果访问次数超出访问限制次数,则禁止访问
// 如果有限制信息则使用限制信息,没有则使用默认限制信息
String errorMessage =
!StringUtils.isEmpty(prevent.message()) ? prevent.message() : expire + "秒内不允许重复请求";
throw new WrongDataException( errorMessage);
}
}
/**
* 对象转换为base64字符串
* @param obj 对象值
* @return base64字符串
*/
private String toBase64String(String obj) throws Exception {
if (StringUtils.isEmpty(obj)) {
return null;
}
Base64.Encoder encoder = Base64.getEncoder();
byte[] bytes = obj.getBytes(StandardCharsets.UTF_8);
return encoder.encodeToString(bytes);
}
}
四、测试
package com.*.web.controller;
@RestController
@RequestMapping("/warningNotification")
public class TestController {
@Prevent(value = 60,times = 5)
@PostMapping("/test")
public ResponseData test(HttpServletRequest request) {
return ResponseData.success();
}
}