SpringBoot2 集成redis实现对api的限流保护

83 阅读2分钟

背景:最近项目有个需求,需要对注册到平台的api进行保护,当时第一反应借助SpringCloud进行限流,但是注册到平台的api是不固定的,没办法提前预知,也就意味着没办法提前配置好。这是想到利用自定义注解+拦截器+redis来实现对api的访问保护。接下来贴出我们的做法,大家参考一下。

spring:
  redis:
  ## redis集群配置的密码
    password: root
    cluster:
    ## 集群环境下支持的最大的重定向数
      max-redirects: 3
    ## redis集群节点
      nodes: 127.0.0.1:7001,127.0.0.1:7002,127.0.0.2:7001,127.0.0.2:7002,127.0.0.3:7001,127.0.0.3:7002
    lettuce:
      pool:
        max-idle: 10 # 连接池中的最大空闲连接
        max-wait: 500 # 连接池最大阻塞等待时间(使用负值表示没有限制)
        max-active: 8 # 连接池最大连接数(使用负值表示没有限制)
        min-idle: 0 # 连接池中的最小空闲连接
@Configuration
public class RedisConfig {

    @Bean
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory connectionFactory) {
        RedisTemplate<String, Object> template = new RedisTemplate<>();
        template.setConnectionFactory(connectionFactory);
        template.setKeySerializer(new StringRedisSerializer());
        template.setValueSerializer(new GenericJackson2JsonRedisSerializer());
        return template;
    }
}
@Aspect
@Component
@Slf4j
public class RateLimterAspect {

    @Resource
    private RedisTemplate redisTemplate;
    private DefaultRedisScript<Long> getRedisScript;
    @Pointcut("@annotation(com.szah.aop.annotation.RateLimiter)")
    public void rateLimiter(){}
    @Around("@annotation(rateLimiter)")
    public Object around(ProceedingJoinPoint proceedingJoinPoint, RateLimiter rateLimiter) throws Throwable {
        if(log.isDebugEnabled()){
            log.debug("RateLimiterAspect【分布式限流处理器】开始执行限流操作");
        }
        Signature signature = proceedingJoinPoint.getSignature();
        if(!(signature instanceof MethodSignature)){
            throw new IllegalAccessException("the Annotation @RateLimiter must used on method");
        }
        //限流模块key
        String limitKey = rateLimiter.key();
        Preconditions.checkNotNull(limitKey);
        //限流阈值
        long limitTimes = rateLimiter.limit();
        //限流时间
        long expireTime = rateLimiter.expire();
        if(log.isDebugEnabled()){
            log.debug("RateLimiterAspect【分布式限流处理器】参数值-limitTimes={},limitTimeout={}",limitTimes,expireTime);
        }
        // 限流提示语
        String message = rateLimiter.message();
        if(StringUtils.isBlank(message)){
            message = "false";
        }
        // 执行lua脚本
        List<String> keyList = new ArrayList<>();
        //设置key值为注解中的值
        keyList.add(limitKey);
        //调用脚本并执行
        Long result =(Long)redisTemplate.execute(getRedisScript, keyList, expireTime, limitTimes);
        if(Objects.nonNull(result) && result == 0){
            String msg = "由于超时单位时间=" +expireTime+ "-允许的请求次数=" + limitTimes+"[触发限流]";
            //log.debug(msg);
            throw new ReteLimitingException(message);
        }
        if(log.isDebugEnabled()){
            log.debug("RateLimiterAspect【分布式限流处理器】限流执行结果-result={},请求【正常】响应,",result);
        }
        return proceedingJoinPoint.proceed();
    }
    @PostConstruct
    public void init(){
        getRedisScript = new DefaultRedisScript<>();
        getRedisScript.setResultType(Long.class);
        getRedisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("script/rateLimiter.lua")));
        log.info("RateScript【分布式限流处理器】脚本加载完成");
    }
}
--获取KEY
local key1 = KEYS[1]

local val = redis.call('incr', key1)
local ttl = redis.call('ttl', key1)

--获取ARGV内的参数并打印
local expire = ARGV[1]
local times = ARGV[2]

redis.log(redis.LOG_DEBUG,tostring(times))
redis.log(redis.LOG_DEBUG,tostring(expire))

redis.log(redis.LOG_NOTICE, "incr "..key1.." "..val);
if val == 1 then
    redis.call('expire', key1, tonumber(expire))
else
    if ttl == -1 then
        redis.call('expire', key1, tonumber(expire))
    end
end

if val > tonumber(times) then
    return 0
end

return 1

最后还有一个自定义注解

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimiter {

    /**
     * 限流key
     */
    String key() default "rate:limiter";
    /**
     * 单位时间限制通过请求数
     */
    long limit() default 10L;
    /**
     * 过期时间,单位秒
     */
    long expire() default 1L;
    /**
     * 返回值
     */
    String message() default "false";
}

接下来展示用法,通过下面代码可以发现该api一秒钟只能接收一次请求

@GetMapping("/test")
@RateLimiter(key = "ratedemo:testLateLimiter", limit = 1, expire = 1, message = "稍后再试")
public String test() throws Exception{
    System.out.println(log);
    return "test";
}