定义限流注解
/**
 * Marks a method for Redis-backed rate limiting.
 *
 * <p>At most {@link #count()} invocations are allowed within a window of {@link #period()}
 * seconds; for example {@code count = 5, period = 10} permits five calls every ten seconds.
 *
 * <p>NOTE(review): the limiting aspect selects join points with an {@code @annotation(...)}
 * pointcut and reads this annotation from the invoked {@code Method}, so placing it on a type
 * appears to have no effect — confirm before relying on TYPE-level usage.
 */
@Documented
@Inherited
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.TYPE})
public @interface RedisLimitAnnotation {

    /** Extra discriminator appended to the Redis counter key; empty by default. */
    String key() default "";

    /** Maximum number of invocations permitted within one {@link #period()} window. */
    int count();

    /** Length of the limiting window in seconds (e.g. 5 means "at most count calls per 5 seconds"). */
    int period();
}
配置类
/**
 * Spring bean definitions for the Redis-backed rate limiter.
 *
 * <p>Registers the Lua script used by the limiting aspect and a {@link RedisTemplate} with
 * String keys and JSON values.
 */
@Component
public class RedisConfiguration {

    /**
     * Loads {@code limit.lua} from the classpath root as a script whose reply is a number.
     *
     * @return the rate-limit script definition
     */
    @Bean
    public DefaultRedisScript<Number> redisLuaScript() {
        DefaultRedisScript<Number> redisScript = new DefaultRedisScript<>();
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("limit.lua")));
        redisScript.setResultType(Number.class);
        return redisScript;
    }

    /**
     * Builds a template that stores keys and hash keys as plain strings and values and hash
     * values as JSON.
     *
     * <p>NOTE(review): {@code RedisTemplate#execute(RedisScript, List, Object...)} serializes the
     * script arguments with this template's value serializer, so the JSON serializer set here
     * determines how the limiter's count/period reach the Lua script — confirm before changing it.
     *
     * @param redisConnectionFactory connection factory supplied by Spring
     * @return the initialized template
     */
    @Bean("redisTemplate")
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
        RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(redisConnectionFactory);

        // Parameterized (not raw) so the compiler can check the serializer's element type.
        Jackson2JsonRedisSerializer<Object> jsonSerializer = new Jackson2JsonRedisSerializer<>(Object.class);
        StringRedisSerializer stringSerializer = new StringRedisSerializer();

        // Values are serialized as JSON.
        redisTemplate.setValueSerializer(jsonSerializer);
        redisTemplate.setHashValueSerializer(jsonSerializer);
        // Keys are serialized as plain strings.
        redisTemplate.setKeySerializer(stringSerializer);
        redisTemplate.setHashKeySerializer(stringSerializer);

        redisTemplate.afterPropertiesSet();
        return redisTemplate;
    }
}
AOP拦截
/**
 * Around-advice that enforces {@link RedisLimitAnnotation} on annotated methods.
 *
 * <p>For each call it builds a Redis key from the client IP, the declaring class, the method name
 * and the annotation's {@code key}. It then runs the rate-limit Lua script with {@code count} and
 * {@code period}, and proceeds only when the script reports a count in {@code 1..count}.
 * Otherwise it throws.
 */
@Aspect
@Configuration
public class LimitRestAspect {

    private static final Logger logger = LoggerFactory.getLogger(LimitRestAspect.class);

    @Autowired
    private RedisTemplate<String, Object> redisTemplate;

    @Autowired
    private DefaultRedisScript<Number> redisluaScript;

    /** Matches every method annotated with {@link RedisLimitAnnotation}. */
    @Pointcut(value = "@annotation(cn.ccb.demo.redis.annoation.RedisLimitAnnotation)")
    public void rateLimit() {
    }

    /**
     * Applies the rate limit to the intercepted call.
     *
     * @param joinPoint the intercepted method invocation
     * @return the intercepted method's result when the call is within the limit
     * @throws RuntimeException when the limit for the current window is exhausted
     * @throws Throwable whatever the intercepted method itself throws
     */
    @Around("rateLimit()")
    public Object interceptor(ProceedingJoinPoint joinPoint) throws Throwable {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        RedisLimitAnnotation rateLimit = method.getAnnotation(RedisLimitAnnotation.class);
        if (rateLimit == null) {
            // Annotation not visible on the resolved Method object: run without limiting.
            return joinPoint.proceed();
        }

        List<String> keys = Collections.singletonList(buildLimitKey(method, rateLimit));
        // The script returns the request count in the current window, or 0 when rejected.
        Number number = redisTemplate.execute(
                redisluaScript,
                keys,
                rateLimit.count(),
                rateLimit.period()
        );
        if (number != null && number.intValue() != 0 && number.intValue() <= rateLimit.count()) {
            logger.info("限流时间段内访问了第:{} 次", number);
            return joinPoint.proceed();
        }
        throw new RuntimeException("访问频率过快,被限流了");
    }

    /** Builds the per-client, per-method counter key: {@code ip-class-method-key}. */
    private static String buildLimitKey(Method method, RedisLimitAnnotation rateLimit) {
        return currentClientIp()
                + "-" + method.getDeclaringClass().getName()
                + "-" + method.getName()
                + "-" + rateLimit.key();
    }

    /**
     * Returns the client IP of the current servlet request, or {@code ""} when the advised method
     * is not running on a thread bound to a servlet request (previously this case threw an NPE).
     */
    private static String currentClientIp() {
        Object attributes = RequestContextHolder.getRequestAttributes();
        if (attributes instanceof ServletRequestAttributes) {
            return getIpAddr(((ServletRequestAttributes) attributes).getRequest());
        }
        return "";
    }

    /**
     * Resolves the client IP address of a request.
     *
     * <p>The proxy headers are checked in order, then the socket address is used. When the value
     * is a comma-separated proxy chain, the first entry is taken.
     *
     * <p>NOTE(review): {@code x-forwarded-for}, {@code Proxy-Client-IP} and
     * {@code WL-Proxy-Client-IP} are client-supplied unless a trusted proxy overwrites them. A
     * client reaching this service directly can send a different value on every request and get
     * a fresh rate-limit counter each time. Only trust these headers behind a sanitizing proxy.
     *
     * @param request the current HTTP request
     * @return the resolved address, possibly {@code null} if the container reports none, or
     *     {@code ""} if resolution failed
     */
    private static String getIpAddr(HttpServletRequest request) {
        try {
            String ipAddress = request.getHeader("x-forwarded-for");
            if (isUnknown(ipAddress)) {
                ipAddress = request.getHeader("Proxy-Client-IP");
            }
            if (isUnknown(ipAddress)) {
                ipAddress = request.getHeader("WL-Proxy-Client-IP");
            }
            if (isUnknown(ipAddress)) {
                ipAddress = request.getRemoteAddr();
            }
            if (ipAddress != null) {
                // Split on any comma regardless of length: a two-entry IPv4 chain such as
                // "1.2.3.4,5.6.7.8" is only 15 characters long.
                int comma = ipAddress.indexOf(',');
                if (comma > 0) {
                    ipAddress = ipAddress.substring(0, comma);
                }
                ipAddress = ipAddress.trim();
            }
            return ipAddress;
        } catch (Exception e) {
            // Best effort: fall back to an empty IP rather than failing the request.
            return "";
        }
    }

    /** True when a header value is absent, empty, or the literal placeholder "unknown". */
    private static boolean isUnknown(String ipAddress) {
        return ipAddress == null || ipAddress.isEmpty() || "unknown".equalsIgnoreCase(ipAddress);
    }
}
Lua 脚本
在 resources 目录(即类路径根目录 src/main/resources)下新建 limit.lua,配置类通过 ClassPathResource("limit.lua") 加载它。
local key = "rate.limit:" .. KEYS[1]
local limit = tonumber(ARGV[1])
local current = tonumber(redis.call('get', key) or "0")
if current + 1 > limit then
return 0
else
-- 没有超阈值,将当前访问数量+1,并设置过期时间(可根据自己的业务情况调整)
redis.call("INCRBY", key,"1")
redis.call("expire", key,ARGV[2])
return current + 1
end
测试访问
/**
 * Demo endpoint for exercising the Redis rate limiter.
 */
@RestController
public class RedisController {

    /**
     * Returns a fixed response. The limiter allows at most one call per 10 seconds
     * ({@code count = 1, period = 10}).
     *
     * @return the literal {@code "success"}
     */
    @RedisLimitAnnotation(count = 1, period = 10, key = "queryFromRedis")
    @GetMapping("/redis/limit")
    public String queryFromRedis() {
        final String body = "success";
        return body;
    }
}